Skip to content

Commit

Permalink
After backporting LUCENE-10577, adapt to JDK11 and add CHANGES entry
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Sokolov committed Aug 24, 2022
1 parent a174412 commit 7d824b1
Show file tree
Hide file tree
Showing 10 changed files with 150 additions and 93 deletions.
4 changes: 2 additions & 2 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ http://s.apache.org/luceneversions

API Changes
---------------------
(No changes)
* LUCENE-10577: Add VectorEncoding to enable byte-encoded HNSW vectors (Michael Sokolov, Julie Tibshirani)

New Features
---------------------
Expand Down Expand Up @@ -18692,4 +18692,4 @@ First open source release.

The code has been re-organized into a new package and directory
structure for this release. It builds OK, but has not been tested
beyond that since the re-organization.
beyond that since the re-organization.
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,16 @@ private void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) {
+ fieldEntry.dimension);
}

int byteSize =
switch (info.getVectorEncoding()) {
case BYTE -> Byte.BYTES;
case FLOAT32 -> Float.BYTES;
};
int byteSize;
switch (info.getVectorEncoding()) {
case BYTE:
byteSize = Byte.BYTES;
break;
default:
case FLOAT32:
byteSize = Float.BYTES;
break;
}
int numBytes = fieldEntry.size * dimension * byteSize;
if (numBytes != fieldEntry.vectorDataLength) {
throw new IllegalStateException(
Expand Down Expand Up @@ -296,11 +301,14 @@ public TopDocs searchExhaustively(
VectorSimilarityFunction similarityFunction = fieldEntry.similarityFunction;
VectorValues vectorValues = getVectorValues(field);

return switch (fieldEntry.vectorEncoding) {
case BYTE -> exhaustiveSearch(
vectorValues, acceptDocs, similarityFunction, toBytesRef(target), k);
case FLOAT32 -> exhaustiveSearch(vectorValues, acceptDocs, similarityFunction, target, k);
};
switch (fieldEntry.vectorEncoding) {
case BYTE:
return exhaustiveSearch(
vectorValues, acceptDocs, similarityFunction, toBytesRef(target), k);
default:
case FLOAT32:
return exhaustiveSearch(vectorValues, acceptDocs, similarityFunction, target, k);
}
}

/** Get knn graph values; used for testing */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,13 @@ private void writeField(FieldWriter<?> fieldData, int maxDoc) throws IOException
// write vector values
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
switch (fieldData.fieldInfo.getVectorEncoding()) {
case BYTE -> writeByteVectors(fieldData);
case FLOAT32 -> writeFloat32Vectors(fieldData);
case BYTE:
writeByteVectors(fieldData);
break;
default:
case FLOAT32:
writeFloat32Vectors(fieldData);
}
;
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;

// write graph
Expand Down Expand Up @@ -237,11 +240,17 @@ private void writeSortingField(FieldWriter<?> fieldData, int maxDoc, Sorter.DocM
}

// write vector values
long vectorDataOffset =
switch (fieldData.fieldInfo.getVectorEncoding()) {
case BYTE -> writeSortedByteVectors(fieldData, ordMap);
case FLOAT32 -> writeSortedFloat32Vectors(fieldData, ordMap);
};
long vectorDataOffset;
switch (fieldData.fieldInfo.getVectorEncoding()) {
case BYTE:
vectorDataOffset = writeSortedByteVectors(fieldData, ordMap);
break;
default:
case FLOAT32:
vectorDataOffset = writeSortedFloat32Vectors(fieldData, ordMap);
break;
}
;
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;

// write graph
Expand Down Expand Up @@ -590,21 +599,24 @@ private abstract static class FieldWriter<T> extends KnnFieldVectorsWriter<T> {
/**
 * Builds a {@link FieldWriter} whose {@code copyValue} snapshots incoming vectors in the
 * representation matching the field's vector encoding: FLOAT32 vectors as {@code float[]},
 * BYTE vectors as {@link BytesRef} (one byte per dimension).
 */
static FieldWriter<?> create(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
    throws IOException {
  int dimension = fieldInfo.getVectorDimension();
  switch (fieldInfo.getVectorEncoding()) {
    case FLOAT32:
    default:
      return new FieldWriter<float[]>(fieldInfo, M, beamWidth, infoStream) {
        @Override
        public float[] copyValue(float[] value) {
          // Copy exactly the first `dimension` floats; the caller may reuse its array.
          return ArrayUtil.copyOfSubArray(value, 0, dimension);
        }
      };
    case BYTE:
      return new FieldWriter<BytesRef>(fieldInfo, M, beamWidth, infoStream) {
        @Override
        public BytesRef copyValue(BytesRef value) {
          // Copy the `dimension` bytes of this vector starting at the ref's offset.
          return new BytesRef(
              ArrayUtil.copyOfSubArray(value.bytes, value.offset, value.offset + dimension));
        }
      };
  }
}

@SuppressWarnings("unchecked")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,7 @@ static OffHeapVectorValues load(
}
IndexInput bytesSlice =
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
int byteSize =
switch (fieldEntry.vectorEncoding) {
case BYTE -> fieldEntry.dimension;
case FLOAT32 -> fieldEntry.dimension * Float.BYTES;
};
int byteSize = fieldEntry.dimension * fieldEntry.vectorEncoding.byteSize;
if (fieldEntry.docsWithFieldOffset == -1) {
return new DenseOffHeapVectorValues(
fieldEntry.dimension, fieldEntry.size, bytesSlice, byteSize);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,17 @@ public void addValue(int docID, Object value) {
+ "\" appears more than once in this document (only one value is allowed per field)");
}
assert docID > lastDocID;
float[] vectorValue =
switch (fieldInfo.getVectorEncoding()) {
case FLOAT32 -> (float[]) value;
case BYTE -> bytesToFloats((BytesRef) value);
};
float[] vectorValue;
switch (fieldInfo.getVectorEncoding()) {
case BYTE:
vectorValue = bytesToFloats((BytesRef) value);
break;
default:
case FLOAT32:
vectorValue = (float[]) value;
break;
}
;
docsWithField.add(docID);
vectors.add(copyValue(vectorValue));
lastDocID = docID;
Expand Down
10 changes: 7 additions & 3 deletions lucene/core/src/java/org/apache/lucene/index/IndexingChain.java
Original file line number Diff line number Diff line change
Expand Up @@ -714,9 +714,13 @@ private boolean processField(int docID, IndexableField field, PerField pf) throw
}
if (fieldType.vectorDimension() != 0) {
switch (fieldType.vectorEncoding()) {
case BYTE -> pf.knnFieldVectorsWriter.addValue(docID, field.binaryValue());
case FLOAT32 -> pf.knnFieldVectorsWriter.addValue(
docID, ((KnnVectorField) field).vectorValue());
case BYTE:
pf.knnFieldVectorsWriter.addValue(docID, field.binaryValue());
break;
default:
case FLOAT32:
pf.knnFieldVectorsWriter.addValue(docID, ((KnnVectorField) field).vectorValue());
break;
}
}
return indexedField;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,13 @@ public void addGraphNode(int node, RandomAccessVectorValues values) throws IOExc

/**
 * Returns the vector for {@code node} in the representation matching this graph's encoding;
 * the cast to {@code T} is unchecked because the encoding fixes the runtime type.
 */
@SuppressWarnings("unchecked")
private T getValue(int node, RandomAccessVectorValues values) throws IOException {
  switch (vectorEncoding) {
    case FLOAT32:
    default:
      // float-encoded vectors are surfaced as float[]
      return (T) values.vectorValue(node);
    case BYTE:
      // byte-encoded vectors are surfaced as a BytesRef
      return (T) values.binaryValue(node);
  }
}

private long printGraphBuildStatus(int node, long start, long t) {
Expand Down Expand Up @@ -280,10 +283,13 @@ private boolean diversityCheck(int candidate, float score, NeighborArray neighbo

// Fetches the candidate's stored vector in the active encoding and dispatches the
// diversity check to the matching overload.
private boolean isDiverse(int candidate, NeighborArray neighbors, float score)
    throws IOException {
  switch (vectorEncoding) {
    case FLOAT32:
    default:
      return isDiverse(vectors.vectorValue(candidate), neighbors, score);
    case BYTE:
      return isDiverse(vectors.binaryValue(candidate), neighbors, score);
  }
}

private boolean isDiverse(float[] candidate, NeighborArray neighbors, float score)
Expand Down Expand Up @@ -325,12 +331,15 @@ private int findWorstNonDiverse(NeighborArray neighbors) throws IOException {

// Resolves the candidate's stored vector per the active encoding and delegates the
// worst-non-diverse test to the matching overload.
private boolean isWorstNonDiverse(
    int candidate, NeighborArray neighbors, float minAcceptedSimilarity) throws IOException {
  switch (vectorEncoding) {
    case FLOAT32:
    default:
      return isWorstNonDiverse(
          candidate, vectors.vectorValue(candidate), neighbors, minAcceptedSimilarity);
    case BYTE:
      return isWorstNonDiverse(
          candidate, vectors.binaryValue(candidate), neighbors, minAcceptedSimilarity);
  }
}

private boolean isWorstNonDiverse(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ public void testOrthogonalCosineBytes() {
float[] v = new float[2];
v[0] = random().nextInt(100);
// ensure the vector is non-zero so that cosine is defined
v[1] = random().nextInt(1, 100);
v[1] = random().nextInt(99) + 1;
float[] u = new float[2];
u[0] = v[1];
u[1] = -v[0];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -473,10 +473,13 @@ private abstract static class VectorReader {
/**
 * Opens a reader over {@code n} vectors of dimension {@code dim} from {@code input},
 * sizing its buffer as {@code n * dim * byteSize} and choosing the subclass that
 * matches the vector encoding.
 */
static VectorReader create(FileChannel input, int dim, VectorEncoding vectorEncoding, int n)
    throws IOException {
  int bufferBytes = n * dim * vectorEncoding.byteSize;
  switch (vectorEncoding) {
    case FLOAT32:
    default:
      return new VectorReaderFloat32(input, dim, bufferBytes);
    case BYTE:
      return new VectorReaderByte(input, dim, bufferBytes);
  }
}

VectorReader(FileChannel input, int dim, int bufferSize) throws IOException {
Expand Down Expand Up @@ -715,10 +718,15 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
for (int i = 0; i < numDocs; i++) {
Document doc = new Document();
switch (vectorEncoding) {
case BYTE -> doc.add(
new KnnVectorField(
KNN_FIELD, ((VectorReaderByte) vectorReader).nextBytes(), fieldType));
case FLOAT32 -> doc.add(new KnnVectorField(KNN_FIELD, vectorReader.next(), fieldType));
case BYTE:
doc.add(
new KnnVectorField(
KNN_FIELD, ((VectorReaderByte) vectorReader).nextBytes(), fieldType));
break;
default:
case FLOAT32:
doc.add(new KnnVectorField(KNN_FIELD, vectorReader.next(), fieldType));
break;
}
doc.add(new StoredField(ID_FIELD, i));
iw.addDocument(doc);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,13 @@ public void init() {
/** Adds a single random 30-dimensional "v2" vector field, built per the test's encoding. */
@Override
protected void addRandomFields(Document doc) {
  switch (vectorEncoding) {
    case FLOAT32:
    default:
      doc.add(new KnnVectorField("v2", randomVector(30), similarityFunction));
      break;
    case BYTE:
      doc.add(new KnnVectorField("v2", randomVector8(30), similarityFunction));
      break;
  }
}

Expand Down Expand Up @@ -626,16 +631,21 @@ public void testSparseVectors() throws Exception {
String fieldName = "int" + field;
if (random().nextInt(100) == 17) {
switch (fieldVectorEncodings[field]) {
case BYTE -> {
BytesRef b = randomVector8(fieldDims[field]);
doc.add(new KnnVectorField(fieldName, b, fieldSimilarityFunctions[field]));
fieldTotals[field] += b.bytes[b.offset];
}
case FLOAT32 -> {
float[] v = randomVector(fieldDims[field]);
doc.add(new KnnVectorField(fieldName, v, fieldSimilarityFunctions[field]));
fieldTotals[field] += v[0];
}
case BYTE:
{
BytesRef b = randomVector8(fieldDims[field]);
doc.add(new KnnVectorField(fieldName, b, fieldSimilarityFunctions[field]));
fieldTotals[field] += b.bytes[b.offset];
break;
}
default:
case FLOAT32:
{
float[] v = randomVector(fieldDims[field]);
doc.add(new KnnVectorField(fieldName, v, fieldSimilarityFunctions[field]));
fieldTotals[field] += v[0];
break;
}
}
fieldDocCounts[field]++;
}
Expand Down Expand Up @@ -1290,16 +1300,20 @@ public void testVectorValuesReportCorrectDocs() throws Exception {
doc.add(new StoredField("id", docID));
if (random().nextInt(4) == 3) {
switch (vectorEncoding) {
case BYTE -> {
BytesRef b = randomVector8(dim);
fieldValuesCheckSum += b.bytes[b.offset];
doc.add(new KnnVectorField("knn_vector", b, similarityFunction));
}
case FLOAT32 -> {
float[] v = randomVector(dim);
fieldValuesCheckSum += v[0];
doc.add(new KnnVectorField("knn_vector", v, similarityFunction));
}
case BYTE:
{
BytesRef b = randomVector8(dim);
fieldValuesCheckSum += b.bytes[b.offset];
doc.add(new KnnVectorField("knn_vector", b, similarityFunction));
break;
}
case FLOAT32:
{
float[] v = randomVector(dim);
fieldValuesCheckSum += v[0];
doc.add(new KnnVectorField("knn_vector", v, similarityFunction));
break;
}
}
fieldDocCount++;
fieldSumDocIDs += docID;
Expand Down

0 comments on commit 7d824b1

Please sign in to comment.