Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding binary Hamming distance as similarity option for byte vectors #13076

Closed
wants to merge 11 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,20 @@ public int binarySquareVector() {
return VectorUtil.squareDistance(bytesA, bytesB);
}

/**
 * Benchmarks {@code VectorUtil.binaryHammingDistance} on the benchmark's two byte vectors.
 *
 * @return the computed distance, returned so the result is consumed by the benchmark harness
 */
@Benchmark
public int binaryHammingDistanceVarHandle() {
  final int distance = VectorUtil.binaryHammingDistance(bytesA, bytesB);
  return distance;
}

/**
 * Scalar baseline for the Hamming-distance benchmarks: XORs the two byte vectors one element at
 * a time and sums the set bits of each result.
 *
 * @return the number of bits that differ between {@code bytesA} and {@code bytesB}
 */
@Benchmark
public int binaryHammingDistanceScalar() {
  int differingBits = 0;
  int index = 0;
  while (index < bytesA.length) {
    // keep only the low 8 bits so that sign extension of negative bytes is not counted
    final int xor = (bytesA[index] ^ bytesB[index]) & 0xFF;
    differingBits += Integer.bitCount(xor);
    index++;
  }
  return differingBits;
}

@Benchmark
public float floatCosineScalar() {
return VectorUtil.cosine(floatsA, floatsB);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ private static float quantizeQuery(
VectorUtil.l2normalize(queryCopy);
yield queryCopy;
}
case BINARY_HAMMING_DISTANCE -> throw new IllegalArgumentException(
"Query quantization is not supported for '"
+ VectorSimilarityFunction.BINARY_HAMMING_DISTANCE.name()
+ "'.");
};
return scalarQuantizer.quantize(processedQuery, quantizedQuery, similarityFunction);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.lucene.index;

import static org.apache.lucene.util.VectorUtil.binaryHammingDistance;
import static org.apache.lucene.util.VectorUtil.cosine;
import static org.apache.lucene.util.VectorUtil.dotProduct;
import static org.apache.lucene.util.VectorUtil.dotProductScore;
Expand Down Expand Up @@ -94,6 +95,29 @@ public float compare(float[] v1, float[] v2) {
public float compare(byte[] v1, byte[] v2) {
return scaleMaxInnerProductScore(dotProduct(v1, v2));
}
},
/**
* Binary Hamming distance; computes how many bits differ between two byte vectors.
*
* <p>Only supported for bytes. To convert the distance to a similarity score we normalize using 1
* / (1 + hammingDistance)
*/
BINARY_HAMMING_DISTANCE {
@Override
public float compare(float[] v1, float[] v2) {
throw new UnsupportedOperationException(
BINARY_HAMMING_DISTANCE.name() + " is only supported for byte vectors");
}

@Override
public float compare(byte[] v1, byte[] v2) {
return (1f / (1 + binaryHammingDistance(v1, v2)));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This depends on vector length, is this intended? I would have expected to have something like dimensions * 8 / (1 + distance). I know, it is not relevant for scoring purposes as it is a constant factor, but we have some normalization on other functions, too.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see your point. The initial idea was to have the score bounded in (0, 1] so as to have a more "natural" way of interpreting it, i.e. 1 will always mean identical, and ~0 will mean that the two vectors are complements of each other (1/(1+dim)). If we are to scale the score based on the number of dimensions, we move this to (0, dimensions*8] which will effectively be the reverse of the distance. So for example if two vectors are identical, they would have a score of dimensions * 8, whereas if one is the complement of the other, their score would be ~1 (dim/(1+dim) ).

Don't have a strong opinion on this, happy to proceed with updating the normalization constant if you prefer.

}

@Override
public boolean supportsVectorEncoding(VectorEncoding encoding) {
return encoding == VectorEncoding.BYTE;
}
};

/**
Expand All @@ -116,4 +140,12 @@ public float compare(byte[] v1, byte[] v2) {
* @return the value of the similarity function applied to the two vectors
*/
public abstract float compare(byte[] v1, byte[] v2);

/**
 * Specify whether the encoding provided is supported by the similarity function. Defaults to
 * true.
 *
 * <p>This base implementation accepts every encoding; an individual similarity function may
 * override it to restrict the encodings it accepts.
 *
 * @param encoding the vector encoding to check
 * @return {@code true} if this similarity function can compare vectors of the given encoding
 */
public boolean supportsVectorEncoding(VectorEncoding encoding) {
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ static ScalarQuantizedVectorSimilarity fromVectorSimilarity(
case EUCLIDEAN -> new Euclidean(constMultiplier);
case COSINE, DOT_PRODUCT -> new DotProduct(constMultiplier);
case MAXIMUM_INNER_PRODUCT -> new MaximumInnerProduct(constMultiplier);
// BINARY_HAMMING_DISTANCE has no scalar-quantized counterpart here, so reject it explicitly.
case BINARY_HAMMING_DISTANCE -> throw new IllegalArgumentException(
    "Cannot use '"
        + VectorSimilarityFunction.BINARY_HAMMING_DISTANCE.name()
        + "' with scalar quantization");
};
}

Expand Down
14 changes: 14 additions & 0 deletions lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -214,4 +214,18 @@ public static float[] checkFinite(float[] v) {
}
return v;
}

public static int binaryHammingDistance(byte[] a, byte[] b) {
uschindler marked this conversation as resolved.
Show resolved Hide resolved
int distance = 0, i = 0;
for (final int upperBound = a.length & ~(Long.BYTES - 1); i < upperBound; i += Long.BYTES) {
distance +=
Long.bitCount(
(long) BitUtil.VH_NATIVE_LONG.get(a, i) ^ (long) BitUtil.VH_NATIVE_LONG.get(b, i));
}
// tail:
for (; i < a.length; i++) {
distance += Integer.bitCount((a[i] ^ b[i]) & 0xFF);
}
return distance;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.NoMergePolicy;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
Expand All @@ -43,6 +44,11 @@

public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase {

// Always use FLOAT32 vectors for the tests in this class.
// NOTE(review): presumably because the quantized format under test only handles float
// vectors - confirm.
@Override
protected VectorEncoding randomVectorEncoding() {
return VectorEncoding.FLOAT32;
}

@Override
protected Codec getCodec() {
return new Lucene99Codec() {
Expand All @@ -58,7 +64,8 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
public void testQuantizedVectorsWriteAndRead() throws Exception {
// create lucene directory with codec
int numVectors = 1 + random().nextInt(50);
VectorSimilarityFunction similarityFunction = randomSimilarity();
VectorSimilarityFunction similarityFunction =
randomSimilarityForEncoding(VectorEncoding.FLOAT32);
boolean normalize = similarityFunction == VectorSimilarityFunction.COSINE;
int dim = random().nextInt(64) + 1;
List<float[]> vectors = new ArrayList<>(numVectors);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@

import com.carrotsearch.randomizedtesting.generators.RandomPicks;
import java.io.IOException;
import java.util.Arrays;
import java.util.Random;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.NoMergePolicy;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.util.LuceneTestCase;
Expand Down Expand Up @@ -89,7 +91,11 @@ private static Field randomDocValuesField(Random random, String fieldName) {

private static Field randomKnnVectorField(Random random, String fieldName) {
VectorSimilarityFunction similarityFunction =
RandomPicks.randomFrom(random, VectorSimilarityFunction.values());
RandomPicks.randomFrom(
random,
Arrays.stream(VectorSimilarityFunction.values())
.filter(x -> x.supportsVectorEncoding(VectorEncoding.FLOAT32))
.toList());
float[] values = new float[randomIntBetween(1, 10)];
for (int i = 0; i < values.length; i++) {
values[i] = randomFloat();
Expand Down
16 changes: 14 additions & 2 deletions lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,13 @@ public void setup() {
M = random().nextInt(256) + 3;
}

int similarity = random().nextInt(VectorSimilarityFunction.values().length - 1) + 1;
similarityFunction = VectorSimilarityFunction.values()[similarity];
// TODO: refactor this to better handle BYTE vs FLOAT encodings.
// Even though the vector encoding could be BYTE, multiple tests make explicit use of
// KnnFloatField, so a similarity that does not work with floats
// (e.g. BINARY_HAMMING_DISTANCE) would fail.
similarityFunction = randomSimilarityForFloatEncoding();
vectorEncoding = randomVectorEncoding();

boolean quantized = randomBoolean();
codec =
new FilterCodec(TestUtil.getDefaultCodec().getName(), TestUtil.getDefaultCodec()) {
Expand Down Expand Up @@ -118,6 +122,14 @@ private VectorEncoding randomVectorEncoding() {
return VectorEncoding.values()[random().nextInt(VectorEncoding.values().length)];
}

/**
 * Picks, at random, one of the similarity functions that report support for {@link
 * VectorEncoding#FLOAT32}.
 */
private VectorSimilarityFunction randomSimilarityForFloatEncoding() {
  final VectorSimilarityFunction[] floatCompatible =
      Arrays.stream(VectorSimilarityFunction.values())
          .filter(function -> function.supportsVectorEncoding(VectorEncoding.FLOAT32))
          .toArray(VectorSimilarityFunction[]::new);
  final int pick = random().nextInt(floatCompatible.length);
  return floatCompatible[pick];
}

@After
public void cleanup() {
M = HnswGraphBuilder.DEFAULT_MAX_CONN;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -825,8 +825,7 @@ private Directory getStableIndexStore(String field, float[]... contents) throws
return indexStore;
}

private void assertMatches(IndexSearcher searcher, Query q, int expectedMatches)
throws IOException {
void assertMatches(IndexSearcher searcher, Query q, int expectedMatches) throws IOException {
ScoreDoc[] result = searcher.search(q, 1000).scoreDocs;
assertEquals(expectedMatches, result.length);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,31 @@ Field getKnnVectorField(String name, float[] vector) {
return new KnnByteVectorField(name, floatToBytes(vector), VectorSimilarityFunction.EUCLIDEAN);
}

/**
 * Checks hit order and scores for {@link VectorSimilarityFunction#BINARY_HAMMING_DISTANCE}. The
 * expected score of each hit is 1 / (1 + number of differing bits).
 */
public void testScoreHammingDistance() throws IOException {
  // Hamming distance is only supported for byte vectors; the float arrays used here are
  // converted at a later step through `floatToBytes`.
  final float[] queryVector = new float[] {0, 1, 8};
  try (Directory indexStore =
          getIndexStore(
              "field",
              VectorSimilarityFunction.BINARY_HAMMING_DISTANCE,
              new float[] {0, 1, 8},
              new float[] {-1, -127, 9},
              new float[] {1, 2, 8});
      IndexReader reader = DirectoryReader.open(indexStore)) {
    IndexSearcher searcher = newSearcher(reader);
    AbstractKnnVectorQuery query = getKnnVectorQuery("field", queryVector, 10);
    assertMatches(searcher, query, 3);

    ScoreDoc[] hits = searcher.search(query, 3).scoreDocs;
    // expected order: the identical vector first, then the closer one, then the farthest
    assertIdMatches(reader, "id0", hits[0]);
    assertIdMatches(reader, "id2", hits[1]);
    assertIdMatches(reader, "id1", hits[2]);

    // expected distances are 0, 3 and 10 bits respectively
    assertEquals(1.0, hits[0].score, 1e-7);
    assertEquals(1 / 4f, hits[1].score, 1e-7);
    assertEquals(1 / 11f, hits[2].score, 1e-7);
  }
}

private static byte[] floatToBytes(float[] query) {
byte[] bytes = new byte[query.length];
for (int i = 0; i < query.length; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ public static void beforeClass() throws Exception {
"knnByteField4",
new byte[] {-127, 127, 127},
VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT));
document.add(
new KnnByteVectorField(
"knnByteField6",
new byte[] {-127, 127, 127},
VectorSimilarityFunction.BINARY_HAMMING_DISTANCE));
iw.addDocument(document);

Document document2 = new Document();
Expand Down Expand Up @@ -107,6 +112,11 @@ public static void beforeClass() throws Exception {
"knnByteField4",
new byte[] {14, 29, 31},
VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT));
document2.add(
new KnnByteVectorField(
"knnByteField6",
new byte[] {1, -1, 13},
VectorSimilarityFunction.BINARY_HAMMING_DISTANCE));
iw.addDocument(document2);

Document document3 = new Document();
Expand Down Expand Up @@ -349,6 +359,26 @@ public void testMaximumProductSimilarityValuesSource() throws Exception {
0.0001);
}

/**
 * Verifies that {@code DoubleValuesSource.similarityToQueryVector} on "knnByteField6" yields the
 * same values as calling {@link VectorSimilarityFunction#BINARY_HAMMING_DISTANCE} directly on
 * the indexed vectors.
 */
public void testHammingDistanceValuesSource() throws Exception {
  final byte[] byteQueryVector = new byte[] {1, 2, 9};
  final byte[] firstDocVector = new byte[] {-127, 127, 127};
  final byte[] secondDocVector = new byte[] {1, -1, 13};
  final VectorSimilarityFunction hamming = VectorSimilarityFunction.BINARY_HAMMING_DISTANCE;

  DoubleValues similarity =
      DoubleValuesSource.similarityToQueryVector(
          searcher.reader.leaves().get(0), byteQueryVector, "knnByteField6");

  assertTrue(similarity.advanceExact(0));
  assertEquals(
      hamming.compare(firstDocVector, byteQueryVector), similarity.doubleValue(), 0.0001);

  assertTrue(similarity.advanceExact(1));
  assertEquals(
      hamming.compare(secondDocVector, byteQueryVector), similarity.doubleValue(), 0.0001);

  // doc 2 is expected to have no value for this field
  assertFalse(similarity.advanceExact(2));
}

public void testFailuresWithSimilarityValuesSource() throws Exception {
float[] floatQueryVector = new float[] {1.1f, 2.2f, 3.3f};
byte[] byteQueryVector = new byte[] {-10, 20, 30};
Expand Down