Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a MemorySegment Vector scorer - for scoring without copying on-heap #13339

Merged
merged 43 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
1f64bad
Add a MemorySegment Vector scorer - for scoring without copying on-heap
ChrisHegarty May 2, 2024
8313c88
refactoring
ChrisHegarty May 3, 2024
8c6ab61
restore
ChrisHegarty May 3, 2024
89aa9a2
renames and cleanup
ChrisHegarty May 3, 2024
e9c24a0
Merge branch 'main' into msscorer
ChrisHegarty May 3, 2024
ede3dfe
move creation to VectorizationProvider - much nicer!
ChrisHegarty May 3, 2024
86c47b2
Merge remote-tracking branch 'origin/msscorer' into msscorer
ChrisHegarty May 3, 2024
2f6a9e2
fix benchmark
ChrisHegarty May 3, 2024
c6ef6ea
unused import
ChrisHegarty May 3, 2024
7a1faa1
MemorySegmentAccessInput refactor
ChrisHegarty May 3, 2024
d8c76d8
fix benchmark again
ChrisHegarty May 3, 2024
3b2bc63
remove scorer name checking from benchmark
ChrisHegarty May 4, 2024
8b14344
add level of indirection to avoid directly using VectorizationProvider
ChrisHegarty May 4, 2024
653bd26
unwrap only test filter index inputs
ChrisHegarty May 4, 2024
fa6db68
add FilterIndexInputAccess to register test filter classes
ChrisHegarty May 6, 2024
0223c94
Add cosine and max inner product
ChrisHegarty May 6, 2024
ccb1d09
rework into nested classes
ChrisHegarty May 6, 2024
2a1ba05
remove unwanted comment
ChrisHegarty May 9, 2024
5f0c553
add more tests
ChrisHegarty May 9, 2024
7dedf44
fix bug
ChrisHegarty May 9, 2024
330f55a
use as raw scorer in SQ
ChrisHegarty May 9, 2024
ed57037
more test improvements
ChrisHegarty May 9, 2024
05ebc42
unused
ChrisHegarty May 9, 2024
cba8281
more testing
ChrisHegarty May 9, 2024
8b3f3c2
expand test
ChrisHegarty May 10, 2024
17923f2
separate supplier and scorer
ChrisHegarty May 12, 2024
f06baf9
Merge branch 'main' into msscorer
ChrisHegarty May 12, 2024
eca47c3
fix benchmark
ChrisHegarty May 12, 2024
8efed14
include Lucene99 in the name
ChrisHegarty May 12, 2024
9edb423
fix license header
ChrisHegarty May 12, 2024
92dfdb2
Merge branch 'main' into msscorer
ChrisHegarty May 17, 2024
244352e
clean up and more tests
ChrisHegarty May 17, 2024
9742e1e
Merge remote-tracking branch 'upstream/main' into msscorer
ChrisHegarty May 21, 2024
2a7096e
test copies in threads do not interfere with each other
ChrisHegarty May 21, 2024
e018da1
fix compilation
ChrisHegarty May 21, 2024
a743907
static instance
ChrisHegarty May 21, 2024
c8c70ee
new -> get
ChrisHegarty May 21, 2024
b5a3f45
one more INSTANCE
ChrisHegarty May 21, 2024
ad271f3
make private
ChrisHegarty May 21, 2024
c42c9a1
add lucene99
ChrisHegarty May 21, 2024
e6cac8b
fix toString
ChrisHegarty May 21, 2024
d9bba27
Merge remote-tracking branch 'upstream/main' into msscorer
ChrisHegarty May 21, 2024
80229fb
tidy
ChrisHegarty May 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,8 @@ Optimizations

* GITHUB#13392: Replace Map<Long, Object> by primitive LongObjectHashMap. (Bruno Roustant)

* GITHUB#13339: Add a MemorySegment Vector scorer - for scoring without copying on-heap (Chris Hegarty)

Bug Fixes
---------------------

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.benchmark.jmh;

import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;

import java.io.IOException;
import java.nio.file.Files;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.openjdk.jmh.annotations.*;

@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
@State(Scope.Benchmark)
// first iteration is complete garbage, so make sure we really warmup
@Warmup(iterations = 4, time = 1)
// real iterations. not useful to spend tons of time here, better to fork more
@Measurement(iterations = 5, time = 1)
// engage some noise reduction
@Fork(
value = 3,
jvmArgsAppend = {"-Xmx2g", "-Xms2g", "-XX:+AlwaysPreTouch"})
public class VectorScorerBenchmark {

@Param({"1", "128", "207", "256", "300", "512", "702", "1024"})
int size;

Directory dir;
IndexInput in;
RandomAccessVectorValues vectorValues;
byte[] vec1, vec2;
RandomVectorScorer scorer;

@Setup(Level.Iteration)
public void init() throws IOException {
vec1 = new byte[size];
vec2 = new byte[size];
ThreadLocalRandom.current().nextBytes(vec1);
ThreadLocalRandom.current().nextBytes(vec2);

dir = new MMapDirectory(Files.createTempDirectory("VectorScorerBenchmark"));
try (IndexOutput out = dir.createOutput("vector.data", IOContext.DEFAULT)) {
out.writeBytes(vec1, 0, vec1.length);
out.writeBytes(vec2, 0, vec2.length);
}
in = dir.openInput("vector.data", IOContext.DEFAULT);
vectorValues = vectorValues(size, 2, in, DOT_PRODUCT);
scorer =
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
.getRandomVectorScorerSupplier(DOT_PRODUCT, vectorValues)
.scorer(0);
}

@TearDown
public void teardown() throws IOException {
IOUtils.close(dir, in);
}

@Benchmark
public float binaryDotProductDefault() throws IOException {
return scorer.score(1);
}

@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public float binaryDotProductMemSeg() throws IOException {
return scorer.score(1);
}

static RandomAccessVectorValues vectorValues(
int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException {
return new OffHeapByteVectorValues.DenseOffHeapVectorValues(
dims, size, in.slice("test", 0, in.length()), dims, new ThrowingFlatVectorScorer(), sim);
}

static final class ThrowingFlatVectorScorer implements FlatVectorsScorer {

@Override
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues) {
throw new UnsupportedOperationException();
}

@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction similarityFunction,
RandomAccessVectorValues vectorValues,
float[] target) {
throw new UnsupportedOperationException();
}

@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction similarityFunction,
RandomAccessVectorValues vectorValues,
byte[] target) {
throw new UnsupportedOperationException();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
* @lucene.experimental
*/
public class DefaultFlatVectorScorer implements FlatVectorsScorer {

public static final DefaultFlatVectorScorer INSTANCE = new DefaultFlatVectorScorer();

@Override
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.lucene.codecs.hnsw;

import org.apache.lucene.internal.vectorization.VectorizationProvider;

/**
* Utilities for {@link FlatVectorsScorer}.
*
* @lucene.experimental
*/
public final class FlatVectorScorerUtil {

private static final VectorizationProvider IMPL = VectorizationProvider.getInstance();

private FlatVectorScorerUtil() {}

/**
* Returns a FlatVectorsScorer that supports the Lucene99 format. Scorers retrieved through this
* method may be optimized on certain platforms. Otherwise, a DefaultFlatVectorScorer is returned.
*/
public static FlatVectorsScorer getLucene99FlatVectorsScorer() {
return IMPL.getLucene99FlatVectorsScorer();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.MergePolicy;
Expand Down Expand Up @@ -139,7 +139,7 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {

/** The format for storing, reading, merging vectors on disk */
private static final FlatVectorsFormat flatVectorsFormat =
new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer());
new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());

private final int numMergeWorkers;
private final TaskExecutor mergeExec;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.io.IOException;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
Expand Down Expand Up @@ -48,7 +49,7 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
static final String VECTOR_DATA_EXTENSION = "veq";

private static final FlatVectorsFormat rawVectorFormat =
new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer());
new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());

/** The minimum confidence interval */
private static final float MINIMUM_CONFIDENCE_INTERVAL = 0.9f;
Expand Down Expand Up @@ -101,7 +102,8 @@ public Lucene99ScalarQuantizedVectorsFormat(
this.bits = (byte) bits;
this.confidenceInterval = confidenceInterval;
this.compress = compress;
this.flatVectorScorer = new Lucene99ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer());
this.flatVectorScorer =
new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
}

public static float calculateDefaultConfidenceInterval(int vectorDimension) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.lucene.internal.tests;

import org.apache.lucene.store.FilterIndexInput;

/**
* Access to {@link org.apache.lucene.store.FilterIndexInput} internals exposed to the test
* framework.
*
* @lucene.internal
*/
public interface FilterIndexInputAccess {
/** Adds the given test FilterIndexInput class. */
void addTestFilterType(Class<? extends FilterIndexInput> cls);
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.SegmentReader;
import org.apache.lucene.store.FilterIndexInput;

/**
* A set of static methods returning accessors for internal, package-private functionality in
Expand All @@ -48,12 +49,14 @@ public final class TestSecrets {
ensureInitialized.accept(ConcurrentMergeScheduler.class);
ensureInitialized.accept(SegmentReader.class);
ensureInitialized.accept(IndexWriter.class);
ensureInitialized.accept(FilterIndexInput.class);
}

private static IndexPackageAccess indexPackageAccess;
private static ConcurrentMergeSchedulerAccess cmsAccess;
private static SegmentReaderAccess segmentReaderAccess;
private static IndexWriterAccess indexWriterAccess;
private static FilterIndexInputAccess filterIndexInputAccess;

private TestSecrets() {}

Expand Down Expand Up @@ -81,6 +84,12 @@ public static IndexWriterAccess getIndexWriterAccess() {
return Objects.requireNonNull(indexWriterAccess);
}

/** Return the accessor to internal secrets for an {@link FilterIndexInput}. */
public static FilterIndexInputAccess getFilterInputIndexAccess() {
ensureCaller();
return Objects.requireNonNull(filterIndexInputAccess);
}

/** For internal initialization only. */
public static void setIndexWriterAccess(IndexWriterAccess indexWriterAccess) {
ensureNull(TestSecrets.indexWriterAccess);
Expand All @@ -105,6 +114,12 @@ public static void setSegmentReaderAccess(SegmentReaderAccess segmentReaderAcces
TestSecrets.segmentReaderAccess = segmentReaderAccess;
}

/** For internal initialization only. */
public static void setFilterInputIndexAccess(FilterIndexInputAccess filterIndexInputAccess) {
ensureNull(TestSecrets.filterIndexInputAccess);
TestSecrets.filterIndexInputAccess = filterIndexInputAccess;
}

private static void ensureNull(Object ob) {
if (ob != null) {
throw new AssertionError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

package org.apache.lucene.internal.vectorization;

import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;

/** Default provider returning scalar implementations. */
final class DefaultVectorizationProvider extends VectorizationProvider {

Expand All @@ -30,4 +33,9 @@ final class DefaultVectorizationProvider extends VectorizationProvider {
public VectorUtilSupport getVectorUtilSupport() {
return vectorUtilSupport;
}

@Override
public FlatVectorsScorer getLucene99FlatVectorsScorer() {
return DefaultFlatVectorScorer.INSTANCE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.function.Predicate;
import java.util.logging.Logger;
import java.util.stream.Stream;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.util.Constants;
import org.apache.lucene.util.VectorUtil;

Expand Down Expand Up @@ -91,6 +92,9 @@ public static VectorizationProvider getInstance() {
*/
public abstract VectorUtilSupport getVectorUtilSupport();

/** Returns a FlatVectorsScorer that supports the Lucene99 format. */
public abstract FlatVectorsScorer getLucene99FlatVectorsScorer();

// *** Lookup mechanism: ***

private static final Logger LOG = Logger.getLogger(VectorizationProvider.class.getName());
Expand Down Expand Up @@ -177,7 +181,10 @@ private static Optional<Module> lookupVectorModule() {
}

// add all possible callers here as FQCN:
private static final Set<String> VALID_CALLERS = Set.of("org.apache.lucene.util.VectorUtil");
private static final Set<String> VALID_CALLERS =
Set.of(
"org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil",
"org.apache.lucene.util.VectorUtil");

private static void ensureCaller() {
final boolean validCaller =
Expand Down