Moving quantization logic to make future quantizer work simpler (#13091)
benwtrent committed Feb 8, 2024
1 parent 2d713d9 commit c70c946
Showing 20 changed files with 127 additions and 94 deletions.
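The net effect of the commit is that the scalar-quantization support classes move out of org.apache.lucene.codecs.lucene99 into a newly exported org.apache.lucene.util.quantization package, and several codec constructors and helpers become public. As a rough illustration of the widened surface (this class is not part of the commit; it only composes constructors whose new public signatures appear in the diff below):

```java
import java.io.IOException;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsReader;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsWriter;
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsReader;
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;

// Hypothetical glue code: a future quantizer (or a custom KnnVectorsFormat) can now
// construct the lucene99 flat and quantized readers/writers directly instead of
// copying their logic.
final class QuantizedVectorsIOSketch {

  static Lucene99ScalarQuantizedVectorsReader openReader(SegmentReadState state) throws IOException {
    // The quantized reader wraps a raw (float32) vectors reader.
    return new Lucene99ScalarQuantizedVectorsReader(state, new Lucene99FlatVectorsReader(state));
  }

  static Lucene99ScalarQuantizedVectorsWriter openWriter(SegmentWriteState state, Float confidenceInterval)
      throws IOException {
    // The quantized writer delegates raw vector writing to the flat writer.
    return new Lucene99ScalarQuantizedVectorsWriter(
        state, confidenceInterval, new Lucene99FlatVectorsWriter(state));
  }
}
```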
2 changes: 2 additions & 0 deletions lucene/core/src/java/module-info.java
@@ -63,6 +63,8 @@
opens org.apache.lucene.document to
org.apache.lucene.test_framework;

exports org.apache.lucene.util.quantization;

provides org.apache.lucene.analysis.TokenizerFactory with
org.apache.lucene.analysis.standard.StandardTokenizerFactory;
provides org.apache.lucene.codecs.Codec with
@@ -17,6 +17,9 @@

package org.apache.lucene.codecs.lucene99;

import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSimilarityFunction;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
@@ -35,7 +38,6 @@
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.IOUtils;
@@ -55,7 +57,7 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
private final Map<String, FieldEntry> fields = new HashMap<>();
private final IndexInput vectorData;

Lucene99FlatVectorsReader(SegmentReadState state) throws IOException {
public Lucene99FlatVectorsReader(SegmentReadState state) throws IOException {
int versionMeta = readMetadata(state);
boolean success = false;
try {
@@ -188,24 +190,6 @@ private void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) {
}
}

private VectorSimilarityFunction readSimilarityFunction(DataInput input) throws IOException {
int similarityFunctionId = input.readInt();
if (similarityFunctionId < 0
|| similarityFunctionId >= VectorSimilarityFunction.values().length) {
throw new CorruptIndexException(
"Invalid similarity function id: " + similarityFunctionId, input);
}
return VectorSimilarityFunction.values()[similarityFunctionId];
}

private VectorEncoding readVectorEncoding(DataInput input) throws IOException {
int encodingId = input.readInt();
if (encodingId < 0 || encodingId >= VectorEncoding.values().length) {
throw new CorruptIndexException("Invalid vector encoding id: " + encodingId, input);
}
return VectorEncoding.values()[encodingId];
}

private FieldEntry readField(IndexInput input) throws IOException {
VectorEncoding vectorEncoding = readVectorEncoding(input);
VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
@@ -69,7 +69,7 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter {
private final List<FieldWriter<?>> fields = new ArrayList<>();
private boolean finished;

Lucene99FlatVectorsWriter(SegmentWriteState state) throws IOException {
public Lucene99FlatVectorsWriter(SegmentWriteState state) throws IOException {
segmentWriteState = state;
String metaFileName =
IndexFileNames.segmentFileName(
@@ -45,12 +45,14 @@
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.ScalarQuantizer;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.packed.DirectMonotonicReader;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.QuantizedVectorsReader;
import org.apache.lucene.util.quantization.ScalarQuantizer;

/**
* Reads vectors from the index segments along with index data structures supporting KNN search.
@@ -68,7 +70,7 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
private final IndexInput vectorIndex;
private final FlatVectorsReader flatVectorsReader;

Lucene99HnswVectorsReader(SegmentReadState state, FlatVectorsReader flatVectorsReader)
public Lucene99HnswVectorsReader(SegmentReadState state, FlatVectorsReader flatVectorsReader)
throws IOException {
this.flatVectorsReader = flatVectorsReader;
boolean success = false;
@@ -169,7 +171,8 @@ private void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) {
}
}

private VectorSimilarityFunction readSimilarityFunction(DataInput input) throws IOException {
public static VectorSimilarityFunction readSimilarityFunction(DataInput input)
throws IOException {
int similarityFunctionId = input.readInt();
if (similarityFunctionId < 0
|| similarityFunctionId >= VectorSimilarityFunction.values().length) {
@@ -179,7 +182,7 @@ private VectorSimilarityFunction readSimilarityFunction(DataInput input) throws
return VectorSimilarityFunction.values()[similarityFunctionId];
}

private VectorEncoding readVectorEncoding(DataInput input) throws IOException {
public static VectorEncoding readVectorEncoding(DataInput input) throws IOException {
int encodingId = input.readInt();
if (encodingId < 0 || encodingId >= VectorEncoding.values().length) {
throw new CorruptIndexException("Invalid vector encoding id: " + encodingId, input);
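Because readSimilarityFunction and readVectorEncoding are now public static on Lucene99HnswVectorsReader, the other readers in this commit import them instead of keeping private copies. A minimal sketch of calling them from outside the class (the helper and the metaIn variable are illustrative; metaIn is assumed to be positioned at the encoding id, matching the order used by readField above):

```java
import java.io.IOException;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.IndexInput;

final class FieldHeaderSketch {
  // Hypothetical helper: decode the per-field vector encoding and similarity function
  // with the shared, bounds-checked readers rather than re-implementing them.
  static String describeField(IndexInput metaIn) throws IOException {
    VectorEncoding encoding = Lucene99HnswVectorsReader.readVectorEncoding(metaIn);
    VectorSimilarityFunction similarity = Lucene99HnswVectorsReader.readSimilarityFunction(metaIn);
    return encoding + " / " + similarity;
  }
}
```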
@@ -72,7 +72,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
private final List<FieldWriter<?>> fields = new ArrayList<>();
private boolean finished;

Lucene99HnswVectorsWriter(
public Lucene99HnswVectorsWriter(
SegmentWriteState state,
int M,
int beamWidth,
@@ -81,7 +81,7 @@ public Lucene99ScalarQuantizedVectorsFormat(Float confidenceInterval) {
this.confidenceInterval = confidenceInterval;
}

static float calculateDefaultConfidenceInterval(int vectorDimension) {
public static float calculateDefaultConfidenceInterval(int vectorDimension) {
return Math.max(MINIMUM_CONFIDENCE_INTERVAL, 1f - (1f / (vectorDimension + 1)));
}

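calculateDefaultConfidenceInterval is now public, so code outside the format class can reuse the same default. Per the method body above, the value is 1 - 1/(dimension + 1), floored at MINIMUM_CONFIDENCE_INTERVAL; for 1024-dimensional vectors the unclamped value is 1 - 1/1025 ≈ 0.999. A minimal usage sketch (the demo class is not part of the commit):

```java
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat;

public class DefaultConfidenceIntervalDemo {
  public static void main(String[] args) {
    // max(MINIMUM_CONFIDENCE_INTERVAL, 1f - 1f / (1024 + 1)) per the method body above
    float ci = Lucene99ScalarQuantizedVectorsFormat.calculateDefaultConfidenceInterval(1024);
    System.out.println("default confidence interval for dim=1024: " + ci);
  }
}
```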
@@ -17,6 +17,9 @@

package org.apache.lucene.codecs.lucene99;

import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSimilarityFunction;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
@@ -33,13 +36,15 @@
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.ScalarQuantizer;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.QuantizedVectorsReader;
import org.apache.lucene.util.quantization.ScalarQuantizedRandomVectorScorer;
import org.apache.lucene.util.quantization.ScalarQuantizer;

/**
* Reads Scalar Quantized vectors from the index segments along with index data structures.
@@ -56,8 +61,8 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
private final IndexInput quantizedVectorData;
private final FlatVectorsReader rawVectorsReader;

Lucene99ScalarQuantizedVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader)
throws IOException {
public Lucene99ScalarQuantizedVectorsReader(
SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException {
this.rawVectorsReader = rawVectorsReader;
int versionMeta = -1;
String metaFileName =
@@ -237,24 +242,6 @@ private FieldEntry readField(IndexInput input) throws IOException {
return new FieldEntry(input, vectorEncoding, similarityFunction);
}

private VectorSimilarityFunction readSimilarityFunction(DataInput input) throws IOException {
int similarityFunctionId = input.readInt();
if (similarityFunctionId < 0
|| similarityFunctionId >= VectorSimilarityFunction.values().length) {
throw new CorruptIndexException(
"Invalid similarity function id: " + similarityFunctionId, input);
}
return VectorSimilarityFunction.values()[similarityFunctionId];
}

private VectorEncoding readVectorEncoding(DataInput input) throws IOException {
int encodingId = input.readInt();
if (encodingId < 0 || encodingId >= VectorEncoding.values().length) {
throw new CorruptIndexException("Invalid vector encoding id: " + encodingId, input);
}
return VectorEncoding.values()[encodingId];
}

@Override
public QuantizedByteVectorValues getQuantizedVectorValues(String fieldName) throws IOException {
FieldEntry fieldEntry = fields.get(fieldName);
@@ -53,11 +53,14 @@
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.ScalarQuantizer;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.QuantizedVectorsReader;
import org.apache.lucene.util.quantization.ScalarQuantizedRandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.ScalarQuantizer;

/**
* Writes quantized vector values and metadata to index segments.
@@ -95,7 +98,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
private final FlatVectorsWriter rawVectorDelegate;
private boolean finished;

Lucene99ScalarQuantizedVectorsWriter(
public Lucene99ScalarQuantizedVectorsWriter(
SegmentWriteState state, Float confidenceInterval, FlatVectorsWriter rawVectorDelegate)
throws IOException {
this.confidenceInterval = confidenceInterval;
@@ -523,7 +526,16 @@ private static ScalarQuantizer getQuantizedState(
return null;
}

static ScalarQuantizer mergeAndRecalculateQuantiles(
/**
* Merges the quantiles of the segments and recalculates the quantiles if necessary.
*
* @param mergeState The merge state
* @param fieldInfo The field info
* @param confidenceInterval The confidence interval
* @return The merged quantiles
* @throws IOException If there is a low-level I/O error
*/
public static ScalarQuantizer mergeAndRecalculateQuantiles(
MergeState mergeState, FieldInfo fieldInfo, float confidenceInterval) throws IOException {
List<ScalarQuantizer> quantizationStates = new ArrayList<>(mergeState.liveDocs.length);
List<Integer> segmentSizes = new ArrayList<>(mergeState.liveDocs.length);
@@ -588,7 +600,7 @@ static boolean shouldRequantize(ScalarQuantizer existingQuantiles, ScalarQuantiz
/**
* Writes the vector values to the output and returns a set of documents that contains vectors.
*/
private static DocsWithFieldSet writeQuantizedVectorData(
public static DocsWithFieldSet writeQuantizedVectorData(
IndexOutput output, QuantizedByteVectorValues quantizedByteVectorValues) throws IOException {
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
for (int docV = quantizedByteVectorValues.nextDoc();
@@ -868,7 +880,7 @@ public int dimension() {
}

@Override
float getScoreCorrectionConstant() throws IOException {
public float getScoreCorrectionConstant() throws IOException {
return current.values.getScoreCorrectionConstant();
}
}
@@ -898,7 +910,7 @@ public QuantizedFloatVectorValues(
}

@Override
float getScoreCorrectionConstant() {
public float getScoreCorrectionConstant() {
return offsetValue;
}

@@ -1007,7 +1019,7 @@ private OffsetCorrectedQuantizedByteVectorValues(
}

@Override
float getScoreCorrectionConstant() throws IOException {
public float getScoreCorrectionConstant() throws IOException {
return scalarQuantizer.recalculateCorrectiveOffset(
in.vectorValue(), oldScalarQuantizer, vectorSimilarityFunction);
}
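The javadoc added above documents mergeAndRecalculateQuantiles, and writeQuantizedVectorData also becomes public in this file. A rough sketch of how a future quantizer's merge path might combine the two (the class and method below are hypothetical; producing the QuantizedByteVectorValues view of the merged segments, and deciding whether requantization is needed via shouldRequantize, are out of scope here):

```java
import java.io.IOException;
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsWriter;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizer;

final class QuantizedMergeSketch {
  // Hypothetical merge step: recompute the quantiles over the merged segments, then
  // stream the (already quantized) vectors to the data file.
  static DocsWithFieldSet mergeField(
      MergeState mergeState,
      FieldInfo fieldInfo,
      float confidenceInterval,
      IndexOutput quantizedVectorData,
      QuantizedByteVectorValues mergedQuantizedVectors)
      throws IOException {
    ScalarQuantizer mergedQuantiles =
        Lucene99ScalarQuantizedVectorsWriter.mergeAndRecalculateQuantiles(
            mergeState, fieldInfo, confidenceInterval);
    // mergedQuantiles would drive requantization when the per-segment quantiles drift too far.
    return Lucene99ScalarQuantizedVectorsWriter.writeQuantizedVectorData(
        quantizedVectorData, mergedQuantizedVectors);
  }
}
```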
@@ -24,12 +24,14 @@
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.packed.DirectMonotonicReader;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;

/**
* Read the quantized vector values and their score correction values from the index input. This
* supports both iterated and random access.
*/
abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValues
public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValues
implements RandomAccessQuantizedByteVectorValues {

protected final int dimension;
@@ -77,7 +79,7 @@ public float getScoreCorrectionConstant() {
return scoreCorrectionConstant[0];
}

static OffHeapQuantizedByteVectorValues load(
public static OffHeapQuantizedByteVectorValues load(
OrdToDocDISIReaderConfiguration configuration,
int dimension,
int size,
@@ -98,7 +100,11 @@ static OffHeapQuantizedByteVectorValues load(
}
}

static class DenseOffHeapVectorValues extends OffHeapQuantizedByteVectorValues {
/**
* Dense vector values that are stored off-heap. This is the most common case when every doc has a
* vector.
*/
public static class DenseOffHeapVectorValues extends OffHeapQuantizedByteVectorValues {

private int doc = -1;

@@ -231,7 +237,7 @@ public int size() {
}

@Override
public byte[] vectorValue() throws IOException {
public byte[] vectorValue() {
throw new UnsupportedOperationException();
}

@@ -246,17 +252,17 @@ public int nextDoc() throws IOException {
}

@Override
public int advance(int target) throws IOException {
public int advance(int target) {
return doc = NO_MORE_DOCS;
}

@Override
public EmptyOffHeapVectorValues copy() throws IOException {
public EmptyOffHeapVectorValues copy() {
throw new UnsupportedOperationException();
}

@Override
public byte[] vectorValue(int targetOrd) throws IOException {
public byte[] vectorValue(int targetOrd) {
throw new UnsupportedOperationException();
}

@@ -14,15 +14,17 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene99;
package org.apache.lucene.util.quantization;

import java.io.IOException;
import org.apache.lucene.index.ByteVectorValues;

/**
* A version of {@link ByteVectorValues}, but additionally retrieving score correction offset for
* Scalar quantization scores.
*
* @lucene.experimental
*/
abstract class QuantizedByteVectorValues extends ByteVectorValues {
abstract float getScoreCorrectionConstant() throws IOException;
public abstract class QuantizedByteVectorValues extends ByteVectorValues {
public abstract float getScoreCorrectionConstant() throws IOException;
}
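With QuantizedByteVectorValues now public in org.apache.lucene.util.quantization and getScoreCorrectionConstant() part of its public contract, consumers can walk the quantized vectors together with their corrective offsets. A small sketch, assuming a values instance obtained elsewhere (for example via QuantizedVectorsReader#getQuantizedVectorValues); the class and method names are illustrative:

```java
import java.io.IOException;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;

final class QuantizedDumpSketch {
  // Iterate the quantized vectors doc by doc; the iteration pattern mirrors
  // writeQuantizedVectorData in Lucene99ScalarQuantizedVectorsWriter above.
  static void dump(QuantizedByteVectorValues values) throws IOException {
    for (int doc = values.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) {
      byte[] quantized = values.vectorValue();
      float correction = values.getScoreCorrectionConstant();
      System.out.println("doc=" + doc + " bytes=" + quantized.length + " correction=" + correction);
    }
  }
}
```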
@@ -14,15 +14,18 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene99;
package org.apache.lucene.util.quantization;

import java.io.Closeable;
import java.io.IOException;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.ScalarQuantizer;

/** Quantized vector reader */
interface QuantizedVectorsReader extends Closeable, Accountable {
/**
* Quantized vector reader
*
* @lucene.experimental
*/
public interface QuantizedVectorsReader extends Closeable, Accountable {

QuantizedByteVectorValues getQuantizedVectorValues(String fieldName) throws IOException;
