From 507d791792a6c82e112f880b312ac56588df8b15 Mon Sep 17 00:00:00 2001 From: Chris Hegarty <62058229+ChrisHegarty@users.noreply.github.com> Date: Tue, 21 May 2024 17:34:37 +0100 Subject: [PATCH] Add a MemorySegment Vector scorer - for scoring without copying on-heap (#13339) Add a MemorySegment Vector scorer - for scoring without copying on-heap. The vector scorer loads values directly from the backing memory segment when available. Otherwise, if the vector data spans across segments the scorer copies the vector data on-heap. A benchmark shows ~2x performance improvement of this scorer over the default copy-on-heap scorer. The scorer currently only operates on vectors with an element size of byte. We can evaluate if and how to support floats separately. --- lucene/CHANGES.txt | 2 + .../codecs/hnsw/DefaultFlatVectorScorer.java | 3 + .../codecs/hnsw/FlatVectorScorerUtil.java | 40 ++ .../lucene99/Lucene99HnswVectorsFormat.java | 4 +- .../Lucene99ScalarQuantizedVectorsFormat.java | 6 +- .../tests/FilterIndexInputAccess.java | 31 ++ .../lucene/internal/tests/TestSecrets.java | 15 + .../DefaultVectorizationProvider.java | 8 + .../vectorization/VectorizationProvider.java | 9 +- .../apache/lucene/store/FilterIndexInput.java | 19 + ...Lucene99MemorySegmentByteVectorScorer.java | 151 +++++++ ...MemorySegmentByteVectorScorerSupplier.java | 210 +++++++++ ...ucene99MemorySegmentFlatVectorsScorer.java | 93 ++++ .../PanamaVectorUtilSupport.java | 101 +++-- .../PanamaVectorizationProvider.java | 6 + .../store/MemorySegmentAccessInput.java | 33 ++ .../lucene/store/MemorySegmentIndexInput.java | 31 +- .../codecs/hnsw/TestFlatVectorScorer.java | 16 +- ...estLucene99HnswQuantizedVectorsFormat.java | 13 +- .../TestLucene99HnswVectorsFormat.java | 14 +- .../lucene/index/VectorScorerBenchmark.java | 128 ++++++ .../vectorization/TestVectorScorer.java | 398 ++++++++++++++++++ .../search/BaseKnnVectorQueryTestCase.java | 34 +- .../search/TestKnnByteVectorQueryMMap.java | 36 ++ 
.../tests/store/MockIndexInputWrapper.java | 6 + .../SlowClosingMockIndexInputWrapper.java | 6 + .../SlowOpeningMockIndexInputWrapper.java | 6 + 27 files changed, 1349 insertions(+), 70 deletions(-) create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java create mode 100644 lucene/core/src/java/org/apache/lucene/internal/tests/FilterIndexInputAccess.java create mode 100644 lucene/core/src/java20/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorer.java create mode 100644 lucene/core/src/java20/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorerSupplier.java create mode 100644 lucene/core/src/java20/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java create mode 100644 lucene/core/src/java21/org/apache/lucene/store/MemorySegmentAccessInput.java create mode 100644 lucene/core/src/test/org/apache/lucene/index/VectorScorerBenchmark.java create mode 100644 lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorScorer.java create mode 100644 lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQueryMMap.java diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 6359aa5bc47..69d19a11b67 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -118,6 +118,8 @@ Optimizations * GITHUB#13327: Reduce memory usage of field maps in FieldInfos and BlockTree TermsReader. 
(Bruno Roustant, David Smiley) +* GITHUB#13339: Add a MemorySegment Vector scorer - for scoring without copying on-heap (Chris Hegarty) + Bug Fixes --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java index d89045b17d2..50fbc9851c6 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java @@ -29,6 +29,9 @@ * @lucene.experimental */ public class DefaultFlatVectorScorer implements FlatVectorsScorer { + + public static final DefaultFlatVectorScorer INSTANCE = new DefaultFlatVectorScorer(); + @Override public RandomVectorScorerSupplier getRandomVectorScorerSupplier( VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java new file mode 100644 index 00000000000..808d7b3cc88 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.codecs.hnsw; + +import org.apache.lucene.internal.vectorization.VectorizationProvider; + +/** + * Utilities for {@link FlatVectorsScorer}. + * + * @lucene.experimental + */ +public final class FlatVectorScorerUtil { + + private static final VectorizationProvider IMPL = VectorizationProvider.getInstance(); + + private FlatVectorScorerUtil() {} + + /** + * Returns a FlatVectorsScorer that supports the Lucene99 format. Scorers retrieved through this + * method may be optimized on certain platforms. Otherwise, a DefaultFlatVectorScorer is returned. + */ + public static FlatVectorsScorer getLucene99FlatVectorsScorer() { + return IMPL.getLucene99FlatVectorsScorer(); + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsFormat.java index 8c78a0cb0a0..3238fd1f4ae 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsFormat.java @@ -22,7 +22,7 @@ import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsWriter; -import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; import org.apache.lucene.codecs.lucene90.IndexedDISI; import org.apache.lucene.index.MergePolicy; @@ -139,7 +139,7 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat { /** The format for storing, reading, merging vectors on disk */ private static final FlatVectorsFormat flatVectorsFormat = - new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer()); + new 
Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); private final int numMergeWorkers; private final TaskExecutor mergeExec; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java index c10f87da2a6..26fa791468d 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java @@ -19,6 +19,7 @@ import java.io.IOException; import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; @@ -48,7 +49,7 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat { static final String VECTOR_DATA_EXTENSION = "veq"; private static final FlatVectorsFormat rawVectorFormat = - new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer()); + new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); /** The minimum confidence interval */ private static final float MINIMUM_CONFIDENCE_INTERVAL = 0.9f; @@ -101,7 +102,8 @@ public Lucene99ScalarQuantizedVectorsFormat( this.bits = (byte) bits; this.confidenceInterval = confidenceInterval; this.compress = compress; - this.flatVectorScorer = new Lucene99ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer()); + this.flatVectorScorer = + new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE); } public static float calculateDefaultConfidenceInterval(int vectorDimension) { diff --git a/lucene/core/src/java/org/apache/lucene/internal/tests/FilterIndexInputAccess.java 
b/lucene/core/src/java/org/apache/lucene/internal/tests/FilterIndexInputAccess.java new file mode 100644 index 00000000000..eee40b43610 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/internal/tests/FilterIndexInputAccess.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.internal.tests; + +import org.apache.lucene.store.FilterIndexInput; + +/** + * Access to {@link org.apache.lucene.store.FilterIndexInput} internals exposed to the test + * framework. + * + * @lucene.internal + */ +public interface FilterIndexInputAccess { + /** Adds the given test FilterIndexInput class. 
*/ + void addTestFilterType(Class<?> cls); +} diff --git a/lucene/core/src/java/org/apache/lucene/internal/tests/TestSecrets.java b/lucene/core/src/java/org/apache/lucene/internal/tests/TestSecrets.java index e2d74fc6ae6..cfcf2008c3e 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/tests/TestSecrets.java +++ b/lucene/core/src/java/org/apache/lucene/internal/tests/TestSecrets.java @@ -23,6 +23,7 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.SegmentReader; +import org.apache.lucene.store.FilterIndexInput; /** * A set of static methods returning accessors for internal, package-private functionality in @@ -48,12 +49,14 @@ public final class TestSecrets { ensureInitialized.accept(ConcurrentMergeScheduler.class); ensureInitialized.accept(SegmentReader.class); ensureInitialized.accept(IndexWriter.class); + ensureInitialized.accept(FilterIndexInput.class); } private static IndexPackageAccess indexPackageAccess; private static ConcurrentMergeSchedulerAccess cmsAccess; private static SegmentReaderAccess segmentReaderAccess; private static IndexWriterAccess indexWriterAccess; + private static FilterIndexInputAccess filterIndexInputAccess; private TestSecrets() {} @@ -81,6 +84,12 @@ public static IndexWriterAccess getIndexWriterAccess() { return Objects.requireNonNull(indexWriterAccess); } + /** Return the accessor to internal secrets for a {@link FilterIndexInput}. */ + public static FilterIndexInputAccess getFilterInputIndexAccess() { + ensureCaller(); + return Objects.requireNonNull(filterIndexInputAccess); + } + /** For internal initialization only. */ public static void setIndexWriterAccess(IndexWriterAccess indexWriterAccess) { ensureNull(TestSecrets.indexWriterAccess); @@ -105,6 +114,12 @@ public static void setSegmentReaderAccess(SegmentReaderAccess segmentReaderAcces TestSecrets.segmentReaderAccess = segmentReaderAccess; } + /** For internal initialization only.
*/ + public static void setFilterInputIndexAccess(FilterIndexInputAccess filterIndexInputAccess) { + ensureNull(TestSecrets.filterIndexInputAccess); + TestSecrets.filterIndexInputAccess = filterIndexInputAccess; + } + private static void ensureNull(Object ob) { if (ob != null) { throw new AssertionError( diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java index f3d9aa95fd3..c5193aa23de 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java @@ -17,6 +17,9 @@ package org.apache.lucene.internal.vectorization; +import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; + /** Default provider returning scalar implementations. */ final class DefaultVectorizationProvider extends VectorizationProvider { @@ -30,4 +33,9 @@ final class DefaultVectorizationProvider extends VectorizationProvider { public VectorUtilSupport getVectorUtilSupport() { return vectorUtilSupport; } + + @Override + public FlatVectorsScorer getLucene99FlatVectorsScorer() { + return DefaultFlatVectorScorer.INSTANCE; + } } diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java index 32ee2aa97b3..19a4694b6ec 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java @@ -29,6 +29,7 @@ import java.util.function.Predicate; import java.util.logging.Logger; import java.util.stream.Stream; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import 
org.apache.lucene.util.Constants; import org.apache.lucene.util.VectorUtil; @@ -93,6 +94,9 @@ public static VectorizationProvider getInstance() { */ public abstract VectorUtilSupport getVectorUtilSupport(); + /** Returns a FlatVectorsScorer that supports the Lucene99 format. */ + public abstract FlatVectorsScorer getLucene99FlatVectorsScorer(); + // *** Lookup mechanism: *** private static final Logger LOG = Logger.getLogger(VectorizationProvider.class.getName()); @@ -199,7 +203,10 @@ private static boolean isAffectedByJDK8301190() { } // add all possible callers here as FQCN: - private static final Set<String> VALID_CALLERS = Set.of("org.apache.lucene.util.VectorUtil"); + private static final Set<String> VALID_CALLERS = + Set.of( + "org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil", + "org.apache.lucene.util.VectorUtil"); private static void ensureCaller() { final boolean validCaller = diff --git a/lucene/core/src/java/org/apache/lucene/store/FilterIndexInput.java b/lucene/core/src/java/org/apache/lucene/store/FilterIndexInput.java index 5b4a5c506ee..9e60a51790f 100644 --- a/lucene/core/src/java/org/apache/lucene/store/FilterIndexInput.java +++ b/lucene/core/src/java/org/apache/lucene/store/FilterIndexInput.java @@ -17,6 +17,8 @@ package org.apache.lucene.store; import java.io.IOException; +import java.util.concurrent.CopyOnWriteArrayList; +import org.apache.lucene.internal.tests.TestSecrets; /** * IndexInput implementation that delegates calls to another directory.
This class can be used to @@ -29,6 +31,12 @@ */ public class FilterIndexInput extends IndexInput { + static final CopyOnWriteArrayList<Class<?>> TEST_FILTER_INPUTS = new CopyOnWriteArrayList<>(); + + static { + TestSecrets.setFilterInputIndexAccess(TEST_FILTER_INPUTS::add); + } + /** * Unwraps all FilterIndexInputs until the first non-FilterIndexInput IndexInput instance and * returns it @@ -40,6 +48,17 @@ public static IndexInput unwrap(IndexInput in) { return in; } + /** + * Unwraps all test FilterIndexInputs until the first non-test FilterIndexInput IndexInput + * instance and returns it + */ + public static IndexInput unwrapOnlyTest(IndexInput in) { + while (in instanceof FilterIndexInput && TEST_FILTER_INPUTS.contains(in.getClass())) { + in = ((FilterIndexInput) in).in; + } + return in; + } + protected final IndexInput in; /** Creates a FilterIndexInput with a resource description and wrapped delegate IndexInput */ diff --git a/lucene/core/src/java20/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorer.java b/lucene/core/src/java20/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorer.java new file mode 100644 index 00000000000..aae36204240 --- /dev/null +++ b/lucene/core/src/java20/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorer.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.internal.vectorization; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.util.Optional; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.FilterIndexInput; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.MemorySegmentAccessInput; +import org.apache.lucene.util.hnsw.RandomAccessVectorValues; +import org.apache.lucene.util.hnsw.RandomVectorScorer; + +abstract sealed class Lucene99MemorySegmentByteVectorScorer + extends RandomVectorScorer.AbstractRandomVectorScorer { + + final int vectorByteSize; + final MemorySegmentAccessInput input; + final MemorySegment query; + byte[] scratch; + + /** + * Return an optional whose value, if present, is the scorer. Otherwise, an empty optional is + * returned. 
+ */ + public static Optional create( + VectorSimilarityFunction type, + IndexInput input, + RandomAccessVectorValues values, + byte[] queryVector) { + input = FilterIndexInput.unwrapOnlyTest(input); + if (!(input instanceof MemorySegmentAccessInput msInput)) { + return Optional.empty(); + } + checkInvariants(values.size(), values.getVectorByteLength(), input); + return switch (type) { + case COSINE -> Optional.of(new CosineScorer(msInput, values, queryVector)); + case DOT_PRODUCT -> Optional.of(new DotProductScorer(msInput, values, queryVector)); + case EUCLIDEAN -> Optional.of(new EuclideanScorer(msInput, values, queryVector)); + case MAXIMUM_INNER_PRODUCT -> Optional.of( + new MaxInnerProductScorer(msInput, values, queryVector)); + }; + } + + Lucene99MemorySegmentByteVectorScorer( + MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] queryVector) { + super(values); + this.input = input; + this.vectorByteSize = values.getVectorByteLength(); + this.query = MemorySegment.ofArray(queryVector); + } + + final MemorySegment getSegment(int ord) throws IOException { + checkOrdinal(ord); + long byteOffset = (long) ord * vectorByteSize; + MemorySegment seg = input.segmentSliceOrNull(byteOffset, vectorByteSize); + if (seg == null) { + if (scratch == null) { + scratch = new byte[vectorByteSize]; + } + input.readBytes(byteOffset, scratch, 0, vectorByteSize); + seg = MemorySegment.ofArray(scratch); + } + return seg; + } + + static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) { + if (input.length() < (long) vectorByteLength * maxOrd) { + throw new IllegalArgumentException("input length is less than expected vector data"); + } + } + + final void checkOrdinal(int ord) { + if (ord < 0 || ord >= maxOrd()) { + throw new IllegalArgumentException("illegal ordinal: " + ord); + } + } + + static final class CosineScorer extends Lucene99MemorySegmentByteVectorScorer { + CosineScorer(MemorySegmentAccessInput input, RandomAccessVectorValues 
values, byte[] query) { + super(input, values, query); + } + + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + float raw = PanamaVectorUtilSupport.cosine(query, getSegment(node)); + return (1 + raw) / 2; + } + } + + static final class DotProductScorer extends Lucene99MemorySegmentByteVectorScorer { + DotProductScorer( + MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) { + super(input, values, query); + } + + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + // divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len + float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node)); + return 0.5f + raw / (float) (query.byteSize() * (1 << 15)); + } + } + + static final class EuclideanScorer extends Lucene99MemorySegmentByteVectorScorer { + EuclideanScorer(MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) { + super(input, values, query); + } + + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + float raw = PanamaVectorUtilSupport.squareDistance(query, getSegment(node)); + return 1 / (1f + raw); + } + } + + static final class MaxInnerProductScorer extends Lucene99MemorySegmentByteVectorScorer { + MaxInnerProductScorer( + MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) { + super(input, values, query); + } + + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node)); + if (raw < 0) { + return 1 / (1 + -1 * raw); + } + return raw + 1; + } + } +} diff --git a/lucene/core/src/java20/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorerSupplier.java b/lucene/core/src/java20/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorerSupplier.java new file mode 100644 index 00000000000..90b3bfb014c --- 
/dev/null +++ b/lucene/core/src/java20/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorerSupplier.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.internal.vectorization; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.util.Optional; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.FilterIndexInput; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.MemorySegmentAccessInput; +import org.apache.lucene.util.hnsw.RandomAccessVectorValues; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; + +/** A score supplier of vectors whose element size is byte. */ +public abstract sealed class Lucene99MemorySegmentByteVectorScorerSupplier + implements RandomVectorScorerSupplier { + final int vectorByteSize; + final int maxOrd; + final MemorySegmentAccessInput input; + final RandomAccessVectorValues values; // to support ordToDoc/getAcceptOrds + byte[] scratch1, scratch2; + + /** + * Return an optional whose value, if present, is the scorer supplier. 
Otherwise, an empty + * optional is returned. + */ + static Optional create( + VectorSimilarityFunction type, IndexInput input, RandomAccessVectorValues values) { + input = FilterIndexInput.unwrapOnlyTest(input); + if (!(input instanceof MemorySegmentAccessInput msInput)) { + return Optional.empty(); + } + checkInvariants(values.size(), values.getVectorByteLength(), input); + return switch (type) { + case COSINE -> Optional.of(new CosineSupplier(msInput, values)); + case DOT_PRODUCT -> Optional.of(new DotProductSupplier(msInput, values)); + case EUCLIDEAN -> Optional.of(new EuclideanSupplier(msInput, values)); + case MAXIMUM_INNER_PRODUCT -> Optional.of(new MaxInnerProductSupplier(msInput, values)); + }; + } + + Lucene99MemorySegmentByteVectorScorerSupplier( + MemorySegmentAccessInput input, RandomAccessVectorValues values) { + this.input = input; + this.values = values; + this.vectorByteSize = values.getVectorByteLength(); + this.maxOrd = values.size(); + } + + static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) { + if (input.length() < (long) vectorByteLength * maxOrd) { + throw new IllegalArgumentException("input length is less than expected vector data"); + } + } + + final void checkOrdinal(int ord) { + if (ord < 0 || ord >= maxOrd) { + throw new IllegalArgumentException("illegal ordinal: " + ord); + } + } + + final MemorySegment getFirstSegment(int ord) throws IOException { + long byteOffset = (long) ord * vectorByteSize; + MemorySegment seg = input.segmentSliceOrNull(byteOffset, vectorByteSize); + if (seg == null) { + if (scratch1 == null) { + scratch1 = new byte[vectorByteSize]; + } + input.readBytes(byteOffset, scratch1, 0, vectorByteSize); + seg = MemorySegment.ofArray(scratch1); + } + return seg; + } + + final MemorySegment getSecondSegment(int ord) throws IOException { + long byteOffset = (long) ord * vectorByteSize; + MemorySegment seg = input.segmentSliceOrNull(byteOffset, vectorByteSize); + if (seg == null) { + if (scratch2 
== null) { + scratch2 = new byte[vectorByteSize]; + } + input.readBytes(byteOffset, scratch2, 0, vectorByteSize); + seg = MemorySegment.ofArray(scratch2); + } + return seg; + } + + static final class CosineSupplier extends Lucene99MemorySegmentByteVectorScorerSupplier { + + CosineSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) { + super(input, values); + } + + @Override + public RandomVectorScorer scorer(int ord) { + checkOrdinal(ord); + return new RandomVectorScorer.AbstractRandomVectorScorer(values) { + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + float raw = PanamaVectorUtilSupport.cosine(getFirstSegment(ord), getSecondSegment(node)); + return (1 + raw) / 2; + } + }; + } + + @Override + public CosineSupplier copy() throws IOException { + return new CosineSupplier(input.clone(), values); + } + } + + static final class DotProductSupplier extends Lucene99MemorySegmentByteVectorScorerSupplier { + + DotProductSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) { + super(input, values); + } + + @Override + public RandomVectorScorer scorer(int ord) { + checkOrdinal(ord); + return new RandomVectorScorer.AbstractRandomVectorScorer(values) { + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + // divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len + float raw = + PanamaVectorUtilSupport.dotProduct(getFirstSegment(ord), getSecondSegment(node)); + return 0.5f + raw / (float) (values.dimension() * (1 << 15)); + } + }; + } + + @Override + public DotProductSupplier copy() throws IOException { + return new DotProductSupplier(input.clone(), values); + } + } + + static final class EuclideanSupplier extends Lucene99MemorySegmentByteVectorScorerSupplier { + + EuclideanSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) { + super(input, values); + } + + @Override + public RandomVectorScorer scorer(int ord) { + 
checkOrdinal(ord); + return new RandomVectorScorer.AbstractRandomVectorScorer(values) { + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + float raw = + PanamaVectorUtilSupport.squareDistance(getFirstSegment(ord), getSecondSegment(node)); + return 1 / (1f + raw); + } + }; + } + + @Override + public EuclideanSupplier copy() throws IOException { + return new EuclideanSupplier(input.clone(), values); + } + } + + static final class MaxInnerProductSupplier extends Lucene99MemorySegmentByteVectorScorerSupplier { + + MaxInnerProductSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) { + super(input, values); + } + + @Override + public RandomVectorScorer scorer(int ord) { + checkOrdinal(ord); + return new RandomVectorScorer.AbstractRandomVectorScorer(values) { + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + float raw = + PanamaVectorUtilSupport.dotProduct(getFirstSegment(ord), getSecondSegment(node)); + if (raw < 0) { + return 1 / (1 + -1 * raw); + } + return raw + 1; + } + }; + } + + @Override + public MaxInnerProductSupplier copy() throws IOException { + return new MaxInnerProductSupplier(input.clone(), values); + } + } +} diff --git a/lucene/core/src/java20/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java b/lucene/core/src/java20/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java new file mode 100644 index 00000000000..78dd70d4d83 --- /dev/null +++ b/lucene/core/src/java20/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.internal.vectorization; + +import java.io.IOException; +import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.hnsw.RandomAccessVectorValues; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; + +public class Lucene99MemorySegmentFlatVectorsScorer implements FlatVectorsScorer { + + public static final Lucene99MemorySegmentFlatVectorsScorer INSTANCE = + new Lucene99MemorySegmentFlatVectorsScorer(DefaultFlatVectorScorer.INSTANCE); + + private final FlatVectorsScorer delegate; + + private Lucene99MemorySegmentFlatVectorsScorer(FlatVectorsScorer delegate) { + this.delegate = delegate; + } + + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityType, RandomAccessVectorValues vectorValues) + throws IOException { + // currently only supports binary vectors + if (vectorValues instanceof RandomAccessVectorValues.Bytes && vectorValues.getSlice() != null) { + var scorer = + Lucene99MemorySegmentByteVectorScorerSupplier.create( + similarityType, vectorValues.getSlice(), vectorValues); + if (scorer.isPresent()) { + return scorer.get(); + } + } + return 
delegate.getRandomVectorScorerSupplier(similarityType, vectorValues); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityType, + RandomAccessVectorValues vectorValues, + float[] target) + throws IOException { + // currently only supports binary vectors, so always delegate + return delegate.getRandomVectorScorer(similarityType, vectorValues, target); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityType, + RandomAccessVectorValues vectorValues, + byte[] queryVector) + throws IOException { + checkDimensions(queryVector.length, vectorValues.dimension()); + if (vectorValues instanceof RandomAccessVectorValues.Bytes && vectorValues.getSlice() != null) { + var scorer = + Lucene99MemorySegmentByteVectorScorer.create( + similarityType, vectorValues.getSlice(), vectorValues, queryVector); + if (scorer.isPresent()) { + return scorer.get(); + } + } + return delegate.getRandomVectorScorer(similarityType, vectorValues, queryVector); + } + + static void checkDimensions(int queryLen, int fieldLen) { + if (queryLen != fieldLen) { + throw new IllegalArgumentException( + "vector query dimension: " + queryLen + " differs from field dimension: " + fieldLen); + } + } + + @Override + public String toString() { + return "Lucene99MemorySegmentFlatVectorsScorer()"; + } +} diff --git a/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java b/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java index 9e447612215..867d0c684cb 100644 --- a/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java +++ b/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java @@ -16,6 +16,8 @@ */ package org.apache.lucene.internal.vectorization; +import static java.lang.foreign.ValueLayout.JAVA_BYTE; +import static java.nio.ByteOrder.LITTLE_ENDIAN; import 
static jdk.incubator.vector.VectorOperators.ADD; import static jdk.incubator.vector.VectorOperators.B2I; import static jdk.incubator.vector.VectorOperators.B2S; @@ -23,6 +25,7 @@ import static jdk.incubator.vector.VectorOperators.S2I; import static jdk.incubator.vector.VectorOperators.ZERO_EXTEND_B2S; +import java.lang.foreign.MemorySegment; import jdk.incubator.vector.ByteVector; import jdk.incubator.vector.FloatVector; import jdk.incubator.vector.IntVector; @@ -307,39 +310,44 @@ private float squareDistanceBody(float[] a, float[] b, int limit) { @Override public int dotProduct(byte[] a, byte[] b) { + return dotProduct(MemorySegment.ofArray(a), MemorySegment.ofArray(b)); + } + + public static int dotProduct(MemorySegment a, MemorySegment b) { + assert a.byteSize() == b.byteSize(); int i = 0; int res = 0; // only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit // vectors (256-bit on intel to dodge performance landmines) - if (a.length >= 16 && HAS_FAST_INTEGER_VECTORS) { + if (a.byteSize() >= 16 && HAS_FAST_INTEGER_VECTORS) { // compute vectorized dot product consistent with VPDPBUSD instruction if (VECTOR_BITSIZE >= 512) { - i += BYTE_SPECIES.loopBound(a.length); + i += BYTE_SPECIES.loopBound(a.byteSize()); res += dotProductBody512(a, b, i); } else if (VECTOR_BITSIZE == 256) { - i += BYTE_SPECIES.loopBound(a.length); + i += BYTE_SPECIES.loopBound(a.byteSize()); res += dotProductBody256(a, b, i); } else { // tricky: we don't have SPECIES_32, so we workaround with "overlapping read" - i += ByteVector.SPECIES_64.loopBound(a.length - ByteVector.SPECIES_64.length()); + i += ByteVector.SPECIES_64.loopBound(a.byteSize() - ByteVector.SPECIES_64.length()); res += dotProductBody128(a, b, i); } } // scalar tail - for (; i < a.length; i++) { - res += b[i] * a[i]; + for (; i < a.byteSize(); i++) { + res += b.get(JAVA_BYTE, i) * a.get(JAVA_BYTE, i); } return res; } /** vectorized dot product body (512 bit vectors) */ - private int 
dotProductBody512(byte[] a, byte[] b, int limit) { + private static int dotProductBody512(MemorySegment a, MemorySegment b, int limit) { IntVector acc = IntVector.zero(INT_SPECIES); for (int i = 0; i < limit; i += BYTE_SPECIES.length()) { - ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i); - ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i); + ByteVector va8 = ByteVector.fromMemorySegment(BYTE_SPECIES, a, i, LITTLE_ENDIAN); + ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES, b, i, LITTLE_ENDIAN); // 16-bit multiply: avoid AVX-512 heavy multiply on zmm Vector va16 = va8.convertShape(B2S, SHORT_SPECIES, 0); @@ -355,11 +363,11 @@ private int dotProductBody512(byte[] a, byte[] b, int limit) { } /** vectorized dot product body (256 bit vectors) */ - private int dotProductBody256(byte[] a, byte[] b, int limit) { + private static int dotProductBody256(MemorySegment a, MemorySegment b, int limit) { IntVector acc = IntVector.zero(IntVector.SPECIES_256); for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) { - ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i); - ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i); + ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN); + ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN); // 32-bit multiply and add into accumulator Vector va32 = va8.convertShape(B2I, IntVector.SPECIES_256, 0); @@ -371,13 +379,13 @@ private int dotProductBody256(byte[] a, byte[] b, int limit) { } /** vectorized dot product body (128 bit vectors) */ - private int dotProductBody128(byte[] a, byte[] b, int limit) { + private static int dotProductBody128(MemorySegment a, MemorySegment b, int limit) { IntVector acc = IntVector.zero(IntVector.SPECIES_128); // 4 bytes at a time (re-loading half the vector each time!) 
for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length() >> 1) { // load 8 bytes - ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i); - ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i); + ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN); + ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN); // process first "half" only: 16-bit multiply Vector va16 = va8.convert(B2S, 0); @@ -569,6 +577,10 @@ private int int4DotProductBody128(byte[] a, byte[] b, int limit) { @Override public float cosine(byte[] a, byte[] b) { + return cosine(MemorySegment.ofArray(a), MemorySegment.ofArray(b)); + } + + public static float cosine(MemorySegment a, MemorySegment b) { int i = 0; int sum = 0; int norm1 = 0; @@ -576,17 +588,17 @@ public float cosine(byte[] a, byte[] b) { // only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit // vectors (256-bit on intel to dodge performance landmines) - if (a.length >= 16 && HAS_FAST_INTEGER_VECTORS) { + if (a.byteSize() >= 16 && HAS_FAST_INTEGER_VECTORS) { final float[] ret; if (VECTOR_BITSIZE >= 512) { - i += BYTE_SPECIES.loopBound(a.length); + i += BYTE_SPECIES.loopBound((int) a.byteSize()); ret = cosineBody512(a, b, i); } else if (VECTOR_BITSIZE == 256) { - i += BYTE_SPECIES.loopBound(a.length); + i += BYTE_SPECIES.loopBound((int) a.byteSize()); ret = cosineBody256(a, b, i); } else { // tricky: we don't have SPECIES_32, so we workaround with "overlapping read" - i += ByteVector.SPECIES_64.loopBound(a.length - ByteVector.SPECIES_64.length()); + i += ByteVector.SPECIES_64.loopBound(a.byteSize() - ByteVector.SPECIES_64.length()); ret = cosineBody128(a, b, i); } sum += ret[0]; @@ -595,9 +607,9 @@ public float cosine(byte[] a, byte[] b) { } // scalar tail - for (; i < a.length; i++) { - byte elem1 = a[i]; - byte elem2 = b[i]; + for (; i < a.byteSize(); i++) { + byte elem1 = a.get(JAVA_BYTE, i); + 
byte elem2 = b.get(JAVA_BYTE, i); sum += elem1 * elem2; norm1 += elem1 * elem1; norm2 += elem2 * elem2; @@ -606,13 +618,13 @@ public float cosine(byte[] a, byte[] b) { } /** vectorized cosine body (512 bit vectors) */ - private float[] cosineBody512(byte[] a, byte[] b, int limit) { + private static float[] cosineBody512(MemorySegment a, MemorySegment b, int limit) { IntVector accSum = IntVector.zero(INT_SPECIES); IntVector accNorm1 = IntVector.zero(INT_SPECIES); IntVector accNorm2 = IntVector.zero(INT_SPECIES); for (int i = 0; i < limit; i += BYTE_SPECIES.length()) { - ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i); - ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i); + ByteVector va8 = ByteVector.fromMemorySegment(BYTE_SPECIES, a, i, LITTLE_ENDIAN); + ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES, b, i, LITTLE_ENDIAN); // 16-bit multiply: avoid AVX-512 heavy multiply on zmm Vector va16 = va8.convertShape(B2S, SHORT_SPECIES, 0); @@ -636,13 +648,13 @@ private float[] cosineBody512(byte[] a, byte[] b, int limit) { } /** vectorized cosine body (256 bit vectors) */ - private float[] cosineBody256(byte[] a, byte[] b, int limit) { + private static float[] cosineBody256(MemorySegment a, MemorySegment b, int limit) { IntVector accSum = IntVector.zero(IntVector.SPECIES_256); IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_256); IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_256); for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) { - ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i); - ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i); + ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN); + ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN); // 16-bit multiply, and add into accumulators Vector va32 = va8.convertShape(B2I, IntVector.SPECIES_256, 0); @@ -661,13 +673,13 @@ private float[] cosineBody256(byte[] 
a, byte[] b, int limit) { } /** vectorized cosine body (128 bit vectors) */ - private float[] cosineBody128(byte[] a, byte[] b, int limit) { + private static float[] cosineBody128(MemorySegment a, MemorySegment b, int limit) { IntVector accSum = IntVector.zero(IntVector.SPECIES_128); IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_128); IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_128); for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length() >> 1) { - ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i); - ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i); + ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN); + ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN); // process first half only: 16-bit multiply Vector va16 = va8.convert(B2S, 0); @@ -689,35 +701,40 @@ private float[] cosineBody128(byte[] a, byte[] b, int limit) { @Override public int squareDistance(byte[] a, byte[] b) { + return squareDistance(MemorySegment.ofArray(a), MemorySegment.ofArray(b)); + } + + public static int squareDistance(MemorySegment a, MemorySegment b) { + assert a.byteSize() == b.byteSize(); int i = 0; int res = 0; // only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit // vectors (256-bit on intel to dodge performance landmines) - if (a.length >= 16 && HAS_FAST_INTEGER_VECTORS) { + if (a.byteSize() >= 16 && HAS_FAST_INTEGER_VECTORS) { if (VECTOR_BITSIZE >= 256) { - i += BYTE_SPECIES.loopBound(a.length); + i += BYTE_SPECIES.loopBound((int) a.byteSize()); res += squareDistanceBody256(a, b, i); } else { - i += ByteVector.SPECIES_64.loopBound(a.length); + i += ByteVector.SPECIES_64.loopBound((int) a.byteSize()); res += squareDistanceBody128(a, b, i); } } // scalar tail - for (; i < a.length; i++) { - int diff = a[i] - b[i]; + for (; i < a.byteSize(); i++) { + int diff = a.get(JAVA_BYTE, i) - b.get(JAVA_BYTE, i); res += diff 
* diff; } return res; } /** vectorized square distance body (256+ bit vectors) */ - private int squareDistanceBody256(byte[] a, byte[] b, int limit) { + private static int squareDistanceBody256(MemorySegment a, MemorySegment b, int limit) { IntVector acc = IntVector.zero(INT_SPECIES); for (int i = 0; i < limit; i += BYTE_SPECIES.length()) { - ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i); - ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i); + ByteVector va8 = ByteVector.fromMemorySegment(BYTE_SPECIES, a, i, LITTLE_ENDIAN); + ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES, b, i, LITTLE_ENDIAN); // 32-bit sub, multiply, and add into accumulators // TODO: uses AVX-512 heavy multiply on zmm, should we just use 256-bit vectors on AVX-512? @@ -731,14 +748,14 @@ private int squareDistanceBody256(byte[] a, byte[] b, int limit) { } /** vectorized square distance body (128 bit vectors) */ - private int squareDistanceBody128(byte[] a, byte[] b, int limit) { + private static int squareDistanceBody128(MemorySegment a, MemorySegment b, int limit) { // 128-bit implementation, which must "split up" vectors due to widening conversions // it doesn't help to do the overlapping read trick, due to 32-bit multiply in the formula IntVector acc1 = IntVector.zero(IntVector.SPECIES_128); IntVector acc2 = IntVector.zero(IntVector.SPECIES_128); for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) { - ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i); - ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i); + ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN); + ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN); // 16-bit sub Vector va16 = va8.convertShape(B2S, ShortVector.SPECIES_128, 0); diff --git a/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java 
b/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java index 11901d74f42..87f7cf2baf7 100644 --- a/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java +++ b/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java @@ -21,6 +21,7 @@ import java.util.Locale; import java.util.logging.Logger; import jdk.incubator.vector.FloatVector; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.util.Constants; import org.apache.lucene.util.SuppressForbidden; @@ -73,4 +74,9 @@ private static T doPrivileged(PrivilegedAction action) { public VectorUtilSupport getVectorUtilSupport() { return vectorUtilSupport; } + + @Override + public FlatVectorsScorer getLucene99FlatVectorsScorer() { + return Lucene99MemorySegmentFlatVectorsScorer.INSTANCE; + } } diff --git a/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentAccessInput.java b/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentAccessInput.java new file mode 100644 index 00000000000..7c22eccdcf1 --- /dev/null +++ b/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentAccessInput.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.store; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; + +/** + * Provides access to the backing memory segment. + * + *

Expert API, allows access to the backing memory. + */ +public interface MemorySegmentAccessInput extends RandomAccessInput, Cloneable { + + /** Returns the memory segment for a given position and length, or null. */ + MemorySegment segmentSliceOrNull(long pos, int len) throws IOException; + + MemorySegmentAccessInput clone(); +} diff --git a/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInput.java b/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInput.java index fb304629ab9..3eead55afce 100644 --- a/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInput.java +++ b/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInput.java @@ -34,7 +34,8 @@ * chunkSizePower). */ @SuppressWarnings("preview") -abstract class MemorySegmentIndexInput extends IndexInput implements RandomAccessInput { +abstract class MemorySegmentIndexInput extends IndexInput + implements RandomAccessInput, MemorySegmentAccessInput { static final ValueLayout.OfByte LAYOUT_BYTE = ValueLayout.JAVA_BYTE; static final ValueLayout.OfShort LAYOUT_LE_SHORT = ValueLayout.JAVA_SHORT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN); @@ -478,6 +479,10 @@ MemorySegmentIndexInput buildSlice(String sliceDescription, long offset, long le } } + static boolean checkIndex(long index, long length) { + return index >= 0 && index < length; + } + @Override public final void close() throws IOException { if (curSegment == null) { @@ -578,6 +583,16 @@ public long readLong(long pos) throws IOException { throw alreadyClosed(e); } } + + @Override + public MemorySegment segmentSliceOrNull(long pos, int len) throws IOException { + try { + Objects.checkIndex(pos + len, this.length + 1); + return curSegment.asSlice(pos, len); + } catch (IndexOutOfBoundsException e) { + throw handlePositionalIOOBE(e, "segmentSliceOrNull", pos); + } + } } /** This class adds offset support to MemorySegmentIndexInput, which is needed for slices. 
*/ @@ -638,6 +653,20 @@ public long readLong(long pos) throws IOException { return super.readLong(pos + offset); } + public MemorySegment segmentSliceOrNull(long pos, int len) throws IOException { + if (pos + len > length) { + throw handlePositionalIOOBE(null, "segmentSliceOrNull", pos); + } + pos = pos + offset; + final int si = (int) (pos >> chunkSizePower); + final MemorySegment seg = segments[si]; + final long segOffset = pos & chunkSizeMask; + if (checkIndex(segOffset + len, seg.byteSize() + 1)) { + return seg.asSlice(segOffset, len); + } + return null; + } + @Override MemorySegmentIndexInput buildSlice(String sliceDescription, long ofs, long length) { return super.buildSlice(sliceDescription, this.offset + ofs, length); diff --git a/lucene/core/src/test/org/apache/lucene/codecs/hnsw/TestFlatVectorScorer.java b/lucene/core/src/test/org/apache/lucene/codecs/hnsw/TestFlatVectorScorer.java index 5dbebf08df5..9bce1f10a43 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/hnsw/TestFlatVectorScorer.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/hnsw/TestFlatVectorScorer.java @@ -21,6 +21,8 @@ import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; import java.io.ByteArrayOutputStream; @@ -60,8 +62,9 @@ public TestFlatVectorScorer( public static Iterable parametersFactory() { var scorers = List.of( - new DefaultFlatVectorScorer(), - new Lucene99ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer())); + DefaultFlatVectorScorer.INSTANCE, + new Lucene99ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer()), + FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); var dirs = List.>of( TestFlatVectorScorer::newDirectory, @@ -76,7 +79,14 @@ public 
static Iterable parametersFactory() { return objs; } - // Tests that the creation of another scorer does not perturb previous scorers + public void testDefaultOrMemSegScorer() { + var scorer = FlatVectorScorerUtil.getLucene99FlatVectorsScorer(); + assertThat( + scorer.toString(), + is(oneOf("DefaultFlatVectorScorer()", "Lucene99MemorySegmentFlatVectorsScorer()"))); + } + + // Tests that the creation of another scorer does not disturb previous scorers public void testMultipleByteScorers() throws IOException { byte[] vec0 = new byte[] {0, 0, 0, 0}; byte[] vec1 = new byte[] {1, 1, 1, 1}; diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java index dd9a55b36fb..82adedc26e5 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java @@ -16,11 +16,15 @@ */ package org.apache.lucene.codecs.lucene99; +import static java.lang.String.format; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Locale; import java.util.stream.Collectors; import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.FilterCodec; @@ -222,9 +226,12 @@ public KnnVectorsFormat knnVectorsFormat() { 10, 20, 1, (byte) 4, false, 0.9f, null); } }; - String expectedString = - "Lucene99HnswScalarQuantizedVectorsFormat(name=Lucene99HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, confidenceInterval=0.9, bits=4, compress=false, 
flatVectorScorer=ScalarQuantizedVectorScorer(nonQuantizedDelegate=DefaultFlatVectorScorer()), rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=DefaultFlatVectorScorer())))"; - assertEquals(expectedString, customCodec.knnVectorsFormat().toString()); + String expectedPattern = + "Lucene99HnswScalarQuantizedVectorsFormat(name=Lucene99HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, confidenceInterval=0.9, bits=4, compress=false, flatVectorScorer=ScalarQuantizedVectorScorer(nonQuantizedDelegate=DefaultFlatVectorScorer()), rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s())))"; + var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); + var memSegScorer = + format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); + assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); } public void testLimits() { diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswVectorsFormat.java index 0f84f8ab4ae..aea32b0a13d 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswVectorsFormat.java @@ -16,6 +16,11 @@ */ package org.apache.lucene.codecs.lucene99; +import static java.lang.String.format; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; + +import java.util.Locale; import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.KnnVectorsFormat; @@ -37,9 +42,12 @@ public KnnVectorsFormat knnVectorsFormat() { return new Lucene99HnswVectorsFormat(10, 20); } }; - String expectedString = - "Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, 
maxConn=10, beamWidth=20, flatVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=DefaultFlatVectorScorer()))"; - assertEquals(expectedString, customCodec.knnVectorsFormat().toString()); + String expectedPattern = + "Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s()))"; + var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); + var memSegScorer = + format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); + assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); } public void testLimits() { diff --git a/lucene/core/src/test/org/apache/lucene/index/VectorScorerBenchmark.java b/lucene/core/src/test/org/apache/lucene/index/VectorScorerBenchmark.java new file mode 100644 index 00000000000..c4d3040f283 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/index/VectorScorerBenchmark.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.benchmark.jmh; + +import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; + +import java.io.IOException; +import java.nio.file.Files; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.store.MMapDirectory; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.hnsw.RandomAccessVectorValues; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.openjdk.jmh.annotations.*; + +@BenchmarkMode(Mode.Throughput) +@OutputTimeUnit(TimeUnit.MICROSECONDS) +@State(Scope.Benchmark) +// first iteration is complete garbage, so make sure we really warmup +@Warmup(iterations = 4, time = 1) +// real iterations. 
not useful to spend tons of time here, better to fork more +@Measurement(iterations = 5, time = 1) +// engage some noise reduction +@Fork( + value = 3, + jvmArgsAppend = {"-Xmx2g", "-Xms2g", "-XX:+AlwaysPreTouch"}) +public class VectorScorerBenchmark { + + @Param({"1", "128", "207", "256", "300", "512", "702", "1024"}) + int size; + + Directory dir; + IndexInput in; + RandomAccessVectorValues vectorValues; + byte[] vec1, vec2; + RandomVectorScorer scorer; + + @Setup(Level.Iteration) + public void init() throws IOException { + vec1 = new byte[size]; + vec2 = new byte[size]; + ThreadLocalRandom.current().nextBytes(vec1); + ThreadLocalRandom.current().nextBytes(vec2); + + dir = new MMapDirectory(Files.createTempDirectory("VectorScorerBenchmark")); + try (IndexOutput out = dir.createOutput("vector.data", IOContext.DEFAULT)) { + out.writeBytes(vec1, 0, vec1.length); + out.writeBytes(vec2, 0, vec2.length); + } + in = dir.openInput("vector.data", IOContext.DEFAULT); + vectorValues = vectorValues(size, 2, in, DOT_PRODUCT); + scorer = + FlatVectorScorerUtil.getLucene99FlatVectorsScorer() + .getRandomVectorScorerSupplier(DOT_PRODUCT, vectorValues) + .scorer(0); + } + + @TearDown + public void teardown() throws IOException { + IOUtils.close(dir, in); + } + + @Benchmark + public float binaryDotProductDefault() throws IOException { + return scorer.score(1); + } + + @Benchmark + @Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"}) + public float binaryDotProductMemSeg() throws IOException { + return scorer.score(1); + } + + static RandomAccessVectorValues vectorValues( + int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException { + return new OffHeapByteVectorValues.DenseOffHeapVectorValues( + dims, size, in.slice("test", 0, in.length()), dims, new ThrowingFlatVectorScorer(), sim); + } + + static final class ThrowingFlatVectorScorer implements FlatVectorsScorer { + + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( 
+ VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues) { + throw new UnsupportedOperationException(); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, + RandomAccessVectorValues vectorValues, + float[] target) { + throw new UnsupportedOperationException(); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, + RandomAccessVectorValues vectorValues, + byte[] target) { + throw new UnsupportedOperationException(); + } + } +} diff --git a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorScorer.java b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorScorer.java new file mode 100644 index 00000000000..ce2ad6854a2 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorScorer.java @@ -0,0 +1,398 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.internal.vectorization; + +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; +import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; +import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; + +import com.carrotsearch.randomizedtesting.generators.RandomNumbers; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.Random; +import java.util.concurrent.Callable; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.stream.IntStream; +import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.store.MMapDirectory; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.util.NamedThreadFactory; +import org.apache.lucene.util.hnsw.RandomAccessVectorValues; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.junit.BeforeClass; + +public class TestVectorScorer extends LuceneTestCase { + + private static final double DELTA = 1e-5; + + static final FlatVectorsScorer DEFAULT_SCORER = DefaultFlatVectorScorer.INSTANCE; + static final FlatVectorsScorer MEMSEG_SCORER = + VectorizationProvider.lookup(true).getLucene99FlatVectorsScorer(); + + @BeforeClass + public static void beforeClass() throws 
Exception { + assumeTrue( + "Test only works when the Memory segment scorer is present.", + MEMSEG_SCORER.getClass() != DEFAULT_SCORER.getClass()); + } + + public void testSimpleScorer() throws IOException { + testSimpleScorer(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE); + } + + public void testSimpleScorerSmallChunkSize() throws IOException { + long maxChunkSize = random().nextLong(4, 16); + testSimpleScorer(maxChunkSize); + } + + public void testSimpleScorerMedChunkSize() throws IOException { + // a chunk size where in some vectors will be copied on-heap, while others remain off-heap + testSimpleScorer(64); + } + + void testSimpleScorer(long maxChunkSize) throws IOException { + try (Directory dir = new MMapDirectory(createTempDir("testSimpleScorer"), maxChunkSize)) { + for (int dims : List.of(31, 32, 33)) { + // dimensions that, in some scenarios, cross the mmap chunk sizes + byte[][] vectors = new byte[2][dims]; + String fileName = "bar-" + dims; + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < dims; i++) { + vectors[0][i] = (byte) i; + vectors[1][i] = (byte) (dims - i); + } + byte[] bytes = concat(vectors[0], vectors[1]); + out.writeBytes(bytes, 0, bytes.length); + } + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + for (var sim : List.of(COSINE, EUCLIDEAN, DOT_PRODUCT, MAXIMUM_INNER_PRODUCT)) { + var vectorValues = vectorValues(dims, 2, in, sim); + for (var ords : List.of(List.of(0, 1), List.of(1, 0))) { + int idx0 = ords.get(0); + int idx1 = ords.get(1); + + // getRandomVectorScorerSupplier + var scorer1 = DEFAULT_SCORER.getRandomVectorScorerSupplier(sim, vectorValues); + float expected = scorer1.scorer(idx0).score(idx1); + var scorer2 = MEMSEG_SCORER.getRandomVectorScorerSupplier(sim, vectorValues); + assertEquals(scorer2.scorer(idx0).score(idx1), expected, DELTA); + + // getRandomVectorScorer + var scorer3 = DEFAULT_SCORER.getRandomVectorScorer(sim, vectorValues, vectors[idx0]); + 
assertEquals(scorer3.score(idx1), expected, DELTA); + var scorer4 = MEMSEG_SCORER.getRandomVectorScorer(sim, vectorValues, vectors[idx0]); + assertEquals(scorer4.score(idx1), expected, DELTA); + } + } + } + } + } + + public void testRandomScorer() throws IOException { + testRandomScorer(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_RANDOM_FUNC); + } + + public void testRandomScorerMax() throws IOException { + testRandomScorer(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_MAX_FUNC); + } + + public void testRandomScorerMin() throws IOException { + testRandomScorer(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_MIN_FUNC); + } + + public void testRandomSmallChunkSize() throws IOException { + long maxChunkSize = randomLongBetween(32, 128); + testRandomScorer(maxChunkSize, BYTE_ARRAY_RANDOM_FUNC); + } + + void testRandomScorer(long maxChunkSize, Function<Integer, byte[]> byteArraySupplier) + throws IOException { + try (Directory dir = new MMapDirectory(createTempDir("testRandomScorer"), maxChunkSize)) { + final int dims = randomIntBetween(1, 4096); + final int size = randomIntBetween(2, 100); + final byte[][] vectors = new byte[size][]; + String fileName = "foo-" + dims; + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + var vec = byteArraySupplier.apply(dims); + out.writeBytes(vec, 0, vec.length); + vectors[i] = vec; + } + } + + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + for (int times = 0; times < TIMES; times++) { + for (var sim : List.of(COSINE, EUCLIDEAN, DOT_PRODUCT, MAXIMUM_INNER_PRODUCT)) { + var vectorValues = vectorValues(dims, size, in, sim); + int idx0 = randomIntBetween(0, size - 1); + int idx1 = randomIntBetween(0, size - 1); // may be the same as idx0 - which is ok. 
+ + // getRandomVectorScorerSupplier + var scorer1 = DEFAULT_SCORER.getRandomVectorScorerSupplier(sim, vectorValues); + float expected = scorer1.scorer(idx0).score(idx1); + var scorer2 = MEMSEG_SCORER.getRandomVectorScorerSupplier(sim, vectorValues); + assertEquals(scorer2.scorer(idx0).score(idx1), expected, DELTA); + + // getRandomVectorScorer + var scorer3 = DEFAULT_SCORER.getRandomVectorScorer(sim, vectorValues, vectors[idx0]); + assertEquals(scorer3.score(idx1), expected, DELTA); + var scorer4 = MEMSEG_SCORER.getRandomVectorScorer(sim, vectorValues, vectors[idx0]); + assertEquals(scorer4.score(idx1), expected, DELTA); + } + } + } + } + } + + public void testRandomSliceSmall() throws IOException { + testRandomSliceImpl(30, 64, 1, BYTE_ARRAY_RANDOM_FUNC); + } + + public void testRandomSlice() throws IOException { + int dims = randomIntBetween(1, 4096); + long maxChunkSize = randomLongBetween(32, 128); + int initialOffset = randomIntBetween(1, 129); + testRandomSliceImpl(dims, maxChunkSize, initialOffset, BYTE_ARRAY_RANDOM_FUNC); + } + + // Tests with a slice that has a non-zero initial offset + void testRandomSliceImpl( + int dims, long maxChunkSize, int initialOffset, Function<Integer, byte[]> byteArraySupplier) + throws IOException { + try (Directory dir = new MMapDirectory(createTempDir("testRandomSliceImpl"), maxChunkSize)) { + final int size = randomIntBetween(2, 100); + final byte[][] vectors = new byte[size][]; + String fileName = "baz-" + dims; + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + byte[] ba = new byte[initialOffset]; + out.writeBytes(ba, 0, ba.length); + for (int i = 0; i < size; i++) { + var vec = byteArraySupplier.apply(dims); + out.writeBytes(vec, 0, vec.length); + vectors[i] = vec; + } + } + + try (var outter = dir.openInput(fileName, IOContext.DEFAULT); + var in = outter.slice("slice", initialOffset, outter.length() - initialOffset)) { + for (int times = 0; times < TIMES; times++) { + for (var sim : List.of(COSINE, EUCLIDEAN, 
DOT_PRODUCT, MAXIMUM_INNER_PRODUCT)) { + var vectorValues = vectorValues(dims, size, in, sim); + int idx0 = randomIntBetween(0, size - 1); + int idx1 = randomIntBetween(0, size - 1); // may be the same as idx0 - which is ok. + + // getRandomVectorScorerSupplier + var scorer1 = DEFAULT_SCORER.getRandomVectorScorerSupplier(sim, vectorValues); + float expected = scorer1.scorer(idx0).score(idx1); + var scorer2 = MEMSEG_SCORER.getRandomVectorScorerSupplier(sim, vectorValues); + assertEquals(scorer2.scorer(idx0).score(idx1), expected, DELTA); + + // getRandomVectorScorer + var scorer3 = DEFAULT_SCORER.getRandomVectorScorer(sim, vectorValues, vectors[idx0]); + assertEquals(scorer3.score(idx1), expected, DELTA); + var scorer4 = MEMSEG_SCORER.getRandomVectorScorer(sim, vectorValues, vectors[idx0]); + assertEquals(scorer4.score(idx1), expected, DELTA); + } + } + } + } + } + + // Tests that copies in threads do not interfere with each other + public void testCopiesAcrossThreads() throws Exception { + final long maxChunkSize = 32; + final int dims = 34; // dimensions that are larger than the chunk size, to force fallback + byte[] vec1 = new byte[dims]; + byte[] vec2 = new byte[dims]; + IntStream.range(0, dims).forEach(i -> vec1[i] = 1); + IntStream.range(0, dims).forEach(i -> vec2[i] = 2); + try (Directory dir = new MMapDirectory(createTempDir("testRace"), maxChunkSize)) { + String fileName = "biz-" + dims; + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + byte[] bytes = concat(vec1, vec1, vec2, vec2); + out.writeBytes(bytes, 0, bytes.length); + } + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + for (var sim : List.of(COSINE, EUCLIDEAN, DOT_PRODUCT, MAXIMUM_INNER_PRODUCT)) { + var vectorValues = vectorValues(dims, 4, in, sim); + var scoreSupplier = DEFAULT_SCORER.getRandomVectorScorerSupplier(sim, vectorValues); + var expectedScore1 = scoreSupplier.scorer(0).score(1); + var expectedScore2 = scoreSupplier.scorer(2).score(3); + + 
var scorer = MEMSEG_SCORER.getRandomVectorScorerSupplier(sim, vectorValues); + var tasks = + List.<Callable<Optional<Throwable>>>of( + new AssertingScoreCallable(scorer.copy().scorer(0), 1, expectedScore1), + new AssertingScoreCallable(scorer.copy().scorer(2), 3, expectedScore2)); + var executor = Executors.newFixedThreadPool(2, new NamedThreadFactory("copiesThreads")); + var results = executor.invokeAll(tasks); + executor.shutdown(); + assertTrue(executor.awaitTermination(30, TimeUnit.SECONDS)); + assertEquals(results.stream().filter(Predicate.not(Future::isDone)).count(), 0L); + for (var res : results) { + assertTrue("Unexpected exception" + res.get(), res.get().isEmpty()); + } + } + } + } + } + + // A callable that scores the given ord and scorer and asserts the expected result. + static class AssertingScoreCallable implements Callable<Optional<Throwable>> { + final RandomVectorScorer scorer; + final int ord; + final float expectedScore; + + AssertingScoreCallable(RandomVectorScorer scorer, int ord, float expectedScore) { + this.scorer = scorer; + this.ord = ord; + this.expectedScore = expectedScore; + } + + @Override + public Optional<Throwable> call() throws Exception { + try { + for (int i = 0; i < 100; i++) { + assertEquals(scorer.score(ord), expectedScore, DELTA); + } + } catch (Throwable t) { + return Optional.of(t); + } + return Optional.empty(); + } + } + + // Tests with a large amount of data (> 2GB), which ensures that data offsets do not overflow + @Nightly + public void testLarge() throws IOException { + try (Directory dir = new MMapDirectory(createTempDir("testLarge"))) { + final int dims = 8192; + final int size = 262500; + final String fileName = "large-" + dims; + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + var vec = vector(i, dims); + out.writeBytes(vec, 0, vec.length); + } + } + + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + assert in.length() > Integer.MAX_VALUE; + for (int times = 0; times < TIMES; times++) { + 
for (var sim : List.of(COSINE, EUCLIDEAN, DOT_PRODUCT, MAXIMUM_INNER_PRODUCT)) { + var vectorValues = vectorValues(dims, size, in, sim); + int ord1 = randomIntBetween(0, size - 1); + int ord2 = size - 1; + for (var ords : List.of(List.of(ord1, ord2), List.of(ord2, ord1))) { + int idx0 = ords.getFirst(); + int idx1 = ords.getLast(); + + // getRandomVectorScorerSupplier + var scorer1 = DEFAULT_SCORER.getRandomVectorScorerSupplier(sim, vectorValues); + float expected = scorer1.scorer(idx0).score(idx1); + var scorer2 = MEMSEG_SCORER.getRandomVectorScorerSupplier(sim, vectorValues); + assertEquals(scorer2.scorer(idx0).score(idx1), expected, DELTA); + + // getRandomVectorScorer + var query = vector(idx0, dims); + var scorer3 = DEFAULT_SCORER.getRandomVectorScorer(sim, vectorValues, query); + assertEquals(scorer3.score(idx1), expected, DELTA); + var scorer4 = MEMSEG_SCORER.getRandomVectorScorer(sim, vectorValues, query); + assertEquals(scorer4.score(idx1), expected, DELTA); + } + } + } + } + } + } + + RandomAccessVectorValues vectorValues( + int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException { + return new OffHeapByteVectorValues.DenseOffHeapVectorValues( + dims, size, in.slice("byteValues", 0, in.length()), dims, MEMSEG_SCORER, sim); + } + + // creates the vector based on the given ordinal, which is reproducible given the ord and dims + static byte[] vector(int ord, int dims) { + var random = new Random(Objects.hash(ord, dims)); + byte[] ba = new byte[dims]; + for (int i = 0; i < dims; i++) { + ba[i] = (byte) RandomNumbers.randomIntBetween(random, Byte.MIN_VALUE, Byte.MAX_VALUE); + } + return ba; + } + + /** Concatenates byte arrays. */ + static byte[] concat(byte[]... 
arrays) throws IOException { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) { + for (var ba : arrays) { + baos.write(ba); + } + return baos.toByteArray(); + } + } + + static int randomIntBetween(int minInclusive, int maxInclusive) { + return RandomNumbers.randomIntBetween(random(), minInclusive, maxInclusive); + } + + static long randomLongBetween(long minInclusive, long maxInclusive) { + return RandomNumbers.randomLongBetween(random(), minInclusive, maxInclusive); + } + + static Function<Integer, byte[]> BYTE_ARRAY_RANDOM_FUNC = + size -> { + byte[] ba = new byte[size]; + for (int i = 0; i < size; i++) { + ba[i] = (byte) random().nextInt(); + } + return ba; + }; + + static Function<Integer, byte[]> BYTE_ARRAY_MAX_FUNC = + size -> { + byte[] ba = new byte[size]; + Arrays.fill(ba, Byte.MAX_VALUE); + return ba; + }; + + static Function<Integer, byte[]> BYTE_ARRAY_MIN_FUNC = + size -> { + byte[] ba = new byte[size]; + Arrays.fill(ba, Byte.MIN_VALUE); + return ba; + }; + + static final int TIMES = 100; // number of loop iterations +} diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java index f38f2e0fa8d..df89be2c9de 100644 --- a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java @@ -46,6 +46,7 @@ import org.apache.lucene.search.knn.TopKnnCollectorManager; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.tests.store.BaseDirectoryWrapper; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.TestUtil; import org.apache.lucene.util.BitSet; @@ -73,6 +74,13 @@ abstract Field getKnnVectorField( abstract Field getKnnVectorField(String name, float[] vector); + /** + * Creates a new directory. Subclasses can override to test different directory implementations. 
+ */ + protected BaseDirectoryWrapper newDirectoryForTest() { + return LuceneTestCase.newDirectory(random()); + } + public void testEquals() { AbstractKnnVectorQuery q1 = getKnnVectorQuery("f1", new float[] {0, 1}, 10); Query filter1 = new TermQuery(new Term("id", "id1")); @@ -337,7 +345,7 @@ public void testScoreEuclidean() throws IOException { } public void testScoreCosine() throws IOException { - try (Directory d = newDirectory()) { + try (Directory d = newDirectoryForTest()) { try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) { for (int j = 1; j <= 5; j++) { Document doc = new Document(); @@ -414,7 +422,7 @@ public void testScoreMIP() throws IOException { } public void testExplain() throws IOException { - try (Directory d = newDirectory()) { + try (Directory d = newDirectoryForTest()) { try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) { for (int j = 0; j < 5; j++) { Document doc = new Document(); @@ -441,7 +449,7 @@ public void testExplain() throws IOException { } public void testExplainMultipleSegments() throws IOException { - try (Directory d = newDirectory()) { + try (Directory d = newDirectoryForTest()) { try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) { for (int j = 0; j < 5; j++) { Document doc = new Document(); @@ -474,7 +482,7 @@ public void testSkewedIndex() throws IOException { * number of top K documents, but no more than K documents in total (otherwise we might occasionally * randomly fail to find one). 
*/ - try (Directory d = newDirectory()) { + try (Directory d = newDirectoryForTest()) { try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) { int r = 0; for (int i = 0; i < 5; i++) { @@ -510,7 +518,7 @@ public void testRandom() throws IOException { int dimension = atLeast(5); int numIters = atLeast(10); boolean everyDocHasAVector = random().nextBoolean(); - try (Directory d = newDirectory()) { + try (Directory d = newDirectoryForTest()) { RandomIndexWriter w = new RandomIndexWriter(random(), d); for (int i = 0; i < numDocs; i++) { Document doc = new Document(); @@ -549,7 +557,7 @@ public void testRandomWithFilter() throws IOException { int numDocs = 1000; int dimension = atLeast(5); int numIters = atLeast(10); - try (Directory d = newDirectory()) { + try (Directory d = newDirectoryForTest()) { // Always use the default kNN format to have predictable behavior around when it hits // visitedLimit. This is fine since the test targets AbstractKnnVectorQuery logic, not the kNN // format @@ -635,7 +643,7 @@ public void testRandomWithFilter() throws IOException { public void testFilterWithSameScore() throws IOException { int numDocs = 100; int dimension = atLeast(5); - try (Directory d = newDirectory()) { + try (Directory d = newDirectoryForTest()) { // Always use the default kNN format to have predictable behavior around when it hits // visitedLimit. 
This is fine since the test targets AbstractKnnVectorQuery logic, not the kNN // format @@ -675,7 +683,7 @@ public void testFilterWithSameScore() throws IOException { } public void testDeletes() throws IOException { - try (Directory dir = newDirectory(); + try (Directory dir = newDirectoryForTest(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { final int numDocs = atLeast(100); final int dim = 30; @@ -719,7 +727,7 @@ public void testDeletes() throws IOException { } public void testAllDeletes() throws IOException { - try (Directory dir = newDirectory(); + try (Directory dir = newDirectoryForTest(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { final int numDocs = atLeast(100); final int dim = 30; @@ -748,7 +756,7 @@ public void testAllDeletes() throws IOException { */ public void testNoLiveDocsReader() throws IOException { IndexWriterConfig iwc = newIndexWriterConfig(); - try (Directory dir = newDirectory(); + try (Directory dir = newDirectoryForTest(); IndexWriter w = new IndexWriter(dir, iwc)) { final int numDocs = 10; final int dim = 30; @@ -776,7 +784,7 @@ public void testNoLiveDocsReader() throws IOException { */ public void testBitSetQuery() throws IOException { IndexWriterConfig iwc = newIndexWriterConfig(); - try (Directory dir = newDirectory(); + try (Directory dir = newDirectoryForTest(); IndexWriter w = new IndexWriter(dir, iwc)) { final int numDocs = 100; final int dim = 30; @@ -884,7 +892,7 @@ Directory getIndexStore(String field, float[]... contents) throws IOException { Directory getIndexStore( String field, VectorSimilarityFunction vectorSimilarityFunction, float[]... 
contents) throws IOException { - Directory indexStore = newDirectory(); + Directory indexStore = newDirectoryForTest(); RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore); for (int i = 0; i < contents.length; ++i) { Document doc = new Document(); @@ -917,7 +925,7 @@ Directory getIndexStore( * preserving the order of the added documents. */ private Directory getStableIndexStore(String field, float[]... contents) throws IOException { - Directory indexStore = newDirectory(); + Directory indexStore = newDirectoryForTest(); try (IndexWriter writer = new IndexWriter(indexStore, new IndexWriterConfig())) { for (int i = 0; i < contents.length; ++i) { Document doc = new Document(); diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQueryMMap.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQueryMMap.java new file mode 100644 index 00000000000..3010749da72 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQueryMMap.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.search; + +import java.io.IOException; +import java.io.UncheckedIOException; +import org.apache.lucene.store.MMapDirectory; +import org.apache.lucene.tests.store.BaseDirectoryWrapper; +import org.apache.lucene.tests.store.MockDirectoryWrapper; + +public class TestKnnByteVectorQueryMMap extends TestKnnByteVectorQuery { + + @Override + protected BaseDirectoryWrapper newDirectoryForTest() { + try { + return new MockDirectoryWrapper( + random(), new MMapDirectory(createTempDir("TestKnnByteVectorQueryMMap"))); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } +} diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/store/MockIndexInputWrapper.java b/lucene/test-framework/src/java/org/apache/lucene/tests/store/MockIndexInputWrapper.java index 39c41d46825..09015ed3fcf 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/store/MockIndexInputWrapper.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/store/MockIndexInputWrapper.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.util.Map; import java.util.Set; +import org.apache.lucene.internal.tests.TestSecrets; import org.apache.lucene.store.FilterIndexInput; import org.apache.lucene.store.IndexInput; @@ -27,6 +28,11 @@ * Used by MockDirectoryWrapper to create an input stream that keeps track of when it's been closed. 
*/ public class MockIndexInputWrapper extends FilterIndexInput { + + static { + TestSecrets.getFilterInputIndexAccess().addTestFilterType(MockIndexInputWrapper.class); + } + private MockDirectoryWrapper dir; final String name; private volatile boolean closed; diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/store/SlowClosingMockIndexInputWrapper.java b/lucene/test-framework/src/java/org/apache/lucene/tests/store/SlowClosingMockIndexInputWrapper.java index 436f456cd92..c28415e11d6 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/store/SlowClosingMockIndexInputWrapper.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/store/SlowClosingMockIndexInputWrapper.java @@ -17,6 +17,7 @@ package org.apache.lucene.tests.store; import java.io.IOException; +import org.apache.lucene.internal.tests.TestSecrets; import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.ThreadInterruptedException; @@ -27,6 +28,11 @@ */ class SlowClosingMockIndexInputWrapper extends MockIndexInputWrapper { + static { + TestSecrets.getFilterInputIndexAccess() + .addTestFilterType(SlowClosingMockIndexInputWrapper.class); + } + public SlowClosingMockIndexInputWrapper( MockDirectoryWrapper dir, String name, IndexInput delegate) { super(dir, name, delegate, null); diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/store/SlowOpeningMockIndexInputWrapper.java b/lucene/test-framework/src/java/org/apache/lucene/tests/store/SlowOpeningMockIndexInputWrapper.java index 68e38543ce5..6e8b0042041 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/store/SlowOpeningMockIndexInputWrapper.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/store/SlowOpeningMockIndexInputWrapper.java @@ -17,6 +17,7 @@ package org.apache.lucene.tests.store; import java.io.IOException; +import org.apache.lucene.internal.tests.TestSecrets; import org.apache.lucene.store.IndexInput; import 
org.apache.lucene.util.ThreadInterruptedException; @@ -26,6 +27,11 @@ */ class SlowOpeningMockIndexInputWrapper extends MockIndexInputWrapper { + static { + TestSecrets.getFilterInputIndexAccess() + .addTestFilterType(SlowOpeningMockIndexInputWrapper.class); + } + public SlowOpeningMockIndexInputWrapper( MockDirectoryWrapper dir, String name, IndexInput delegate) throws IOException { super(dir, name, delegate, null);