Expand scalar quantization by adding half-byte (int4) quantization (#13197)

This PR is the culmination of several streams of work:

 - Confidence-interval optimizations, which unlock even smaller quantization bit widths.
 - The ability to quantize below int8 or int7.
 - An optimized int4 (half-byte) vector API comparison for dot product.

Further scalar quantization gives users a choice between:

 - Quantizing further and saving space by compressing pairs of int4 values into single bytes, or
 - Letting quantization guarantee maximal values, which affords faster vector operations.

I didn't add more Panama Vector APIs, as trying to micro-optimize int4 for anything other than dot product seemed a fool's errand. Additionally, I focused only on ARM. I experimented with other architectures but didn't get very far, so they fall back to dotProduct.
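To make the half-byte scheme concrete, here is a minimal scalar sketch of the packing and the packed dot product. This is illustrative only: the commit's optimized kernels use the Panama Vector API on ARM, and the names packInt4 and int4DotProductPacked are hypothetical. Because every quantized value fits in [0, 15], each product is at most 225; bounded intermediate values are exactly the kind of maximal-value guarantee that lets SIMD code accumulate in wider lanes without overflow checks.

```java
// Illustrative sketch; not Lucene code. Values are assumed to be already
// quantized into [0, 15] (4 bits each).
public final class Int4PackingExample {

  /** Packs two 4-bit values per byte: even indices in the low nibble, odd in the high. */
  static byte[] packInt4(byte[] values) {
    byte[] packed = new byte[(values.length + 1) / 2];
    for (int i = 0; i < values.length; i++) {
      int nibble = values[i] & 0x0F;
      packed[i / 2] |= (i % 2 == 0) ? nibble : (nibble << 4);
    }
    return packed;
  }

  /** Dot product of two packed int4 vectors of the same logical dimension. */
  static int int4DotProductPacked(byte[] a, byte[] b) {
    int total = 0;
    for (int i = 0; i < a.length; i++) {
      total += (a[i] & 0x0F) * (b[i] & 0x0F); // low nibbles
      total += ((a[i] >> 4) & 0x0F) * ((b[i] >> 4) & 0x0F); // high nibbles
    }
    return total;
  }

  public static void main(String[] args) {
    byte[] a = {1, 2, 3, 4};
    byte[] b = {5, 6, 7, 8};
    // 1*5 + 2*6 + 3*7 + 4*8 = 70
    System.out.println(int4DotProductPacked(packInt4(a), packInt4(b)));
  }
}
```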
benwtrent committed Apr 2, 2024
1 parent a633334 commit 6aeb198
Showing 24 changed files with 1,067 additions and 211 deletions.
3 changes: 3 additions & 0 deletions lucene/CHANGES.txt
@@ -24,6 +24,9 @@ New Features
   This may improve paging logic especially when large segments are merged under memory pressure.
   (Uwe Schindler, Chris Hegarty, Robert Muir, Adrien Grand)
 
+* GITHUB#13197: Expand support for new scalar bit levels for HNSW vectors. This includes 4-bit vectors and an option
+  to compress them to gain a 50% reduction in memory usage. (Ben Trent)
+
 Improvements
 ---------------------

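To make the 50% figure concrete: at 7 or 8 bits, each dimension occupies a full byte after quantization, so a hypothetical 768-dimensional vector takes 768 bytes; at 4 bits with compression enabled, two dimensions share each byte and the same vector takes 384 bytes, at the cost of some unpacking work during comparisons.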
@@ -511,7 +511,7 @@ private static void doTestHits(ScoreDoc[] hits, int expectedCount, IndexReader r
     }
   }
 
-  private static ScoreDoc[] assertKNNSearch(
+  static ScoreDoc[] assertKNNSearch(
       IndexSearcher searcher,
       float[] queryVector,
       int k,
@@ -82,6 +82,16 @@ public void testCreateSortedIndex() throws IOException {
     sortedTest.createBWCIndex();
   }
 
+  public void testCreateInt8HNSWIndices() throws IOException {
+    TestInt8HnswBackwardsCompatibility int8HnswBackwardsCompatibility =
+        new TestInt8HnswBackwardsCompatibility(
+            Version.LATEST,
+            createPattern(
+                TestInt8HnswBackwardsCompatibility.INDEX_NAME,
+                TestInt8HnswBackwardsCompatibility.SUFFIX));
+    int8HnswBackwardsCompatibility.createBWCIndex();
+  }
+
   private boolean isInitialMajorVersionRelease() {
     return Version.LATEST.equals(Version.fromBits(Version.LATEST.major, 0, 0));
   }
@@ -0,0 +1,150 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.backward_index;

import static org.apache.lucene.backward_index.TestBasicBackwardsCompatibility.assertKNNSearch;

import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
import java.io.IOException;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99Codec;
import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.NoMergePolicy;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.Version;

public class TestInt8HnswBackwardsCompatibility extends BackwardsCompatibilityTestBase {

  static final String INDEX_NAME = "int8_hnsw";
  static final String SUFFIX = "";
  private static final Version FIRST_INT8_HNSW_VERSION = Version.LUCENE_9_10_1;
  private static final String KNN_VECTOR_FIELD = "knn_field";
  private static final int DOC_COUNT = 30;
  private static final FieldType KNN_VECTOR_FIELD_TYPE =
      KnnFloatVectorField.createFieldType(3, VectorSimilarityFunction.COSINE);
  private static final float[] KNN_VECTOR = {0.2f, -0.1f, 0.1f};

  public TestInt8HnswBackwardsCompatibility(Version version, String pattern) {
    super(version, pattern);
  }

  /** Provides all versions to the test framework */
  @ParametersFactory(argumentFormatting = "Lucene-Version:%1$s; Pattern: %2$s")
  public static Iterable<Object[]> testVersionsFactory() throws IllegalAccessException {
    return allVersion(INDEX_NAME, SUFFIX);
  }

  @Override
  protected Codec getCodec() {
    return new Lucene99Codec() {
      @Override
      public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
        return new Lucene99HnswScalarQuantizedVectorsFormat(
            Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
            Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
      }
    };
  }

  @Override
  protected boolean supportsVersion(Version version) {
    return version.onOrAfter(FIRST_INT8_HNSW_VERSION);
  }

  @Override
  void verifyUsesDefaultCodec(Directory dir, String name) throws IOException {
    // We don't use the default codec
  }

  public void testInt8HnswIndexAndSearch() throws Exception {
    IndexWriterConfig indexWriterConfig =
        newIndexWriterConfig(new MockAnalyzer(random()))
            .setOpenMode(IndexWriterConfig.OpenMode.APPEND)
            .setCodec(getCodec())
            .setMergePolicy(newLogMergePolicy());
    try (IndexWriter writer = new IndexWriter(directory, indexWriterConfig)) {
      // add 10 docs
      for (int i = 0; i < 10; i++) {
        writer.addDocument(knnDocument(i + DOC_COUNT));
        if (random().nextBoolean()) {
          writer.flush();
        }
      }
      if (random().nextBoolean()) {
        writer.forceMerge(1);
      }
      writer.commit();
      try (IndexReader reader = DirectoryReader.open(directory)) {
        IndexSearcher searcher = new IndexSearcher(reader);
        assertKNNSearch(searcher, KNN_VECTOR, 1000, DOC_COUNT + 10, "0");
        assertKNNSearch(searcher, KNN_VECTOR, 10, 10, "0");
      }
    }
    // Sanity-check the resulting index
    TestUtil.checkIndex(directory);
  }

  @Override
  protected void createIndex(Directory dir) throws IOException {
    IndexWriterConfig conf =
        new IndexWriterConfig(new MockAnalyzer(random()))
            .setMaxBufferedDocs(10)
            .setCodec(TestUtil.getDefaultCodec())
            .setMergePolicy(NoMergePolicy.INSTANCE);
    try (IndexWriter writer = new IndexWriter(dir, conf)) {
      for (int i = 0; i < DOC_COUNT; i++) {
        writer.addDocument(knnDocument(i));
      }
      writer.forceMerge(1);
    }
    try (DirectoryReader reader = DirectoryReader.open(dir)) {
      IndexSearcher searcher = new IndexSearcher(reader);
      assertKNNSearch(searcher, KNN_VECTOR, 1000, DOC_COUNT, "0");
      assertKNNSearch(searcher, KNN_VECTOR, 10, 10, "0");
    }
  }

  private static Document knnDocument(int id) {
    Document doc = new Document();
    float[] vector = {KNN_VECTOR[0], KNN_VECTOR[1], KNN_VECTOR[2] + 0.1f * id};
    doc.add(new KnnFloatVectorField(KNN_VECTOR_FIELD, vector, KNN_VECTOR_FIELD_TYPE));
    doc.add(new StringField("id", Integer.toString(id), Field.Store.YES));
    return doc;
  }

  public void testReadOldIndices() throws Exception {
    try (DirectoryReader reader = DirectoryReader.open(directory)) {
      IndexSearcher searcher = new IndexSearcher(reader);
      assertKNNSearch(searcher, KNN_VECTOR, 1000, DOC_COUNT, "0");
      assertKNNSearch(searcher, KNN_VECTOR, 10, 10, "0");
    }
  }
}
Binary file not shown.
@@ -39,3 +39,4 @@
 9.9.1
 9.9.2
 9.10.0
+9.10.1
@@ -65,7 +65,7 @@ public final class Lucene99HnswScalarQuantizedVectorsFormat extends KnnVectorsFo
 
   /** Constructs a format using default graph construction parameters */
   public Lucene99HnswScalarQuantizedVectorsFormat() {
-    this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null, null);
+    this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, 7, true, null, null);
   }

   /**
@@ -75,7 +75,7 @@ public Lucene99HnswScalarQuantizedVectorsFormat() {
    * @param beamWidth the size of the queue maintained during graph construction.
    */
   public Lucene99HnswScalarQuantizedVectorsFormat(int maxConn, int beamWidth) {
-    this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null, null);
+    this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, 7, true, null, null);
   }

   /**
@@ -85,6 +85,11 @@ public Lucene99HnswScalarQuantizedVectorsFormat(int maxConn, int beamWidth) {
    * @param beamWidth the size of the queue maintained during graph construction.
    * @param numMergeWorkers number of workers (threads) that will be used when doing merge. If
    *     larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec
+   * @param bits the number of bits to use for scalar quantization (must be between 1 and 8,
+   *     inclusive)
+   * @param compress whether to compress the vectors; if true, vectors quantized with 4 or fewer
+   *     bits are packed two values per byte. If false, the vectors are stored as-is. This trades
+   *     memory usage against speed.
    * @param confidenceInterval the confidenceInterval for scalar quantizing the vectors, when `null`
    *     it is calculated based on the vector field dimensions.
    * @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are
@@ -94,6 +99,8 @@ public Lucene99HnswScalarQuantizedVectorsFormat(
       int maxConn,
       int beamWidth,
       int numMergeWorkers,
+      int bits,
+      boolean compress,
       Float confidenceInterval,
       ExecutorService mergeExec) {
     super("Lucene99HnswScalarQuantizedVectorsFormat");
@@ -127,7 +134,8 @@ public Lucene99HnswScalarQuantizedVectorsFormat(
     } else {
       this.mergeExec = null;
     }
-    this.flatVectorsFormat = new Lucene99ScalarQuantizedVectorsFormat(confidenceInterval);
+    this.flatVectorsFormat =
+        new Lucene99ScalarQuantizedVectorsFormat(confidenceInterval, bits, compress);
   }
 
   @Override
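As a usage sketch (not part of the commit), opting a field into int4 quantization with compression via the expanded constructor might look like this; the class name Int4CodecExample is hypothetical, and the per-field override pattern mirrors the backward-compatibility test above:

```java
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99Codec;
import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;

// Hypothetical helper, sketching how a user could opt into int4 + compression.
public final class Int4CodecExample {
  public static Lucene99Codec int4Codec() {
    return new Lucene99Codec() {
      @Override
      public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
        return new Lucene99HnswScalarQuantizedVectorsFormat(
            16, // maxConn: HNSW graph fan-out
            100, // beamWidth: construction-time search queue size
            1, // numMergeWorkers: single-threaded merges
            4, // bits: half-byte (int4) quantization
            true, // compress: pack two 4-bit values into each byte
            null, // confidenceInterval: null means it is calculated dynamically
            null); // mergeExec: only required when numMergeWorkers > 1
      }
    };
  }
}
```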
@@ -30,12 +30,17 @@
  * @lucene.experimental
  */
 public final class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
+
+  // The bits that are allowed for scalar quantization
+  // We only allow unsigned byte (8), signed byte (7), and half-byte (4)
+  private static final int ALLOWED_BITS = (1 << 8) | (1 << 7) | (1 << 4);
   public static final String QUANTIZED_VECTOR_COMPONENT = "QVEC";
 
   static final String NAME = "Lucene99ScalarQuantizedVectorsFormat";
 
   static final int VERSION_START = 0;
-  static final int VERSION_CURRENT = VERSION_START;
+  static final int VERSION_ADD_BITS = 1;
+  static final int VERSION_CURRENT = VERSION_ADD_BITS;
   static final String META_CODEC_NAME = "Lucene99ScalarQuantizedVectorsFormatMeta";
   static final String VECTOR_DATA_CODEC_NAME = "Lucene99ScalarQuantizedVectorsFormatData";
   static final String META_EXTENSION = "vemq";
@@ -55,18 +60,27 @@ public final class Lucene99ScalarQuantizedVectorsForma
    */
   final Float confidenceInterval;
 
+  final byte bits;
+  final boolean compress;
+
   /** Constructs a format using default graph construction parameters */
   public Lucene99ScalarQuantizedVectorsFormat() {
-    this(null);
+    this(null, 7, true);
   }
 
   /**
    * Constructs a format using the given graph construction parameters.
    *
    * @param confidenceInterval the confidenceInterval for scalar quantizing the vectors, when `null`
-   *     it is calculated based on the vector field dimensions.
+   *     it is calculated dynamically.
+   * @param bits the number of bits to use for scalar quantization (must be between 1 and 8,
+   *     inclusive)
+   * @param compress whether to compress the vectors; if true, vectors quantized with 4 or fewer
+   *     bits are packed two values per byte. If false, the vectors are stored as-is. This trades
+   *     memory usage against speed.
    */
-  public Lucene99ScalarQuantizedVectorsFormat(Float confidenceInterval) {
+  public Lucene99ScalarQuantizedVectorsFormat(
+      Float confidenceInterval, int bits, boolean compress) {
     if (confidenceInterval != null
         && (confidenceInterval < MINIMUM_CONFIDENCE_INTERVAL
             || confidenceInterval > MAXIMUM_CONFIDENCE_INTERVAL)) {
@@ -78,7 +92,12 @@ public Lucene99ScalarQuantizedVectorsFormat(Float confidenceInterval) {
               + "; confidenceInterval="
               + confidenceInterval);
     }
+    if (bits < 1 || bits > 8 || (ALLOWED_BITS & (1 << bits)) == 0) {
+      throw new IllegalArgumentException("bits must be one of: 4, 7, 8; bits=" + bits);
+    }
+    this.bits = (byte) bits;
     this.confidenceInterval = confidenceInterval;
+    this.compress = compress;
   }

  public static float calculateDefaultConfidenceInterval(int vectorDimension) {
@@ -92,6 +111,10 @@ public String toString() {
         + NAME
         + ", confidenceInterval="
         + confidenceInterval
+        + ", bits="
+        + bits
+        + ", compress="
+        + compress
         + ", rawVectorFormat="
         + rawVectorFormat
         + ")";
@@ -100,7 +123,7 @@
   @Override
   public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
     return new Lucene99ScalarQuantizedVectorsWriter(
-        state, confidenceInterval, rawVectorFormat.fieldsWriter(state));
+        state, confidenceInterval, bits, compress, rawVectorFormat.fieldsWriter(state));
   }
 
   @Override
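Finally, a standalone sketch (hypothetical class, not part of the diff) of which bit widths the ALLOWED_BITS mask admits:

```java
// Mirrors the validation added above: only 4, 7, and 8 bits pass.
public final class AllowedBitsDemo {
  private static final int ALLOWED_BITS = (1 << 8) | (1 << 7) | (1 << 4);

  static boolean isAllowed(int bits) {
    return bits >= 1 && bits <= 8 && (ALLOWED_BITS & (1 << bits)) != 0;
  }

  public static void main(String[] args) {
    for (int bits = 1; bits <= 8; bits++) {
      // Prints true only for 4, 7, and 8
      System.out.println(bits + " -> " + isAllowed(bits));
    }
  }
}
```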
