Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for similarity-based vector searches #12679

Merged
merged 10 commits into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,10 @@ API Changes

New Features
---------------------
(No changes)

* GITHUB#12679: Add support for similarity-based vector searches. Finds all vectors scoring above a `resultSimilarity`
while traversing the HNSW graph till better-scoring nodes are available, or the best candidate is below a score of
`traversalSimilarity` in the lowest level. (Aditya Prakash, Kaival Parikh)
Copy link
Member

@benwtrent benwtrent Dec 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add the vector query names?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, didn't get what you mean here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add support for similarity-based vector searches

Well, what are the query names? :D

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh got it.. Updated now :)


Improvements
---------------------
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,288 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;

import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Objects;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;

/**
* Search for all (approximate) vectors above a similarity threshold.
*
* @lucene.experimental
*/
abstract class AbstractVectorSimilarityQuery extends Query {
protected final String field;
protected final float traversalSimilarity, resultSimilarity;
protected final Query filter;

/**
* Search for all (approximate) vectors above a similarity threshold using {@link
* VectorSimilarityCollector}. If a filter is applied, it traverses as many nodes as the cost of
* the filter, and then falls back to exact search if results are incomplete.
*
* @param field a field that has been indexed as a vector field.
* @param traversalSimilarity (lower) similarity score for graph traversal.
* @param resultSimilarity (higher) similarity score for result collection.
* @param filter a filter applied before the vector search.
*/
AbstractVectorSimilarityQuery(
String field, float traversalSimilarity, float resultSimilarity, Query filter) {
if (traversalSimilarity > resultSimilarity) {
throw new IllegalArgumentException("traversalSimilarity should be <= resultSimilarity");
}
this.field = Objects.requireNonNull(field, "field");
this.traversalSimilarity = traversalSimilarity;
this.resultSimilarity = resultSimilarity;
this.filter = filter;
}

abstract VectorScorer createVectorScorer(LeafReaderContext context) throws IOException;

protected abstract TopDocs approximateSearch(
LeafReaderContext context, Bits acceptDocs, int visitLimit) throws IOException;

@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
throws IOException {
return new Weight(this) {
final Weight filterWeight =
filter == null
? null
: searcher.createWeight(searcher.rewrite(filter), ScoreMode.COMPLETE_NO_SCORES, 1);

@Override
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
if (filterWeight != null) {
Scorer filterScorer = filterWeight.scorer(context);
if (filterScorer == null || filterScorer.iterator().advance(doc) > doc) {
return Explanation.noMatch("Doc does not match the filter");
}
}

VectorScorer scorer = createVectorScorer(context);
if (scorer == null) {
return Explanation.noMatch("Not indexed as the correct vector field");
} else if (scorer.advanceExact(doc)) {
float score = scorer.score();
if (score >= resultSimilarity) {
return Explanation.match(boost * score, "Score above threshold");
} else {
return Explanation.noMatch("Score below threshold");
}
} else {
return Explanation.noMatch("No vector found for doc");
}
}

@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
@SuppressWarnings("resource")
LeafReader leafReader = context.reader();
Bits liveDocs = leafReader.getLiveDocs();

// If there is no filter
if (filterWeight == null) {
// Return exhaustive results
TopDocs results = approximateSearch(context, liveDocs, Integer.MAX_VALUE);
return VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs);
}

Scorer scorer = filterWeight.scorer(context);
if (scorer == null) {
// If the filter does not match any documents
return null;
}

BitSet acceptDocs;
if (liveDocs == null && scorer.iterator() instanceof BitSetIterator bitSetIterator) {
// If there are no deletions, and matching docs are already cached
acceptDocs = bitSetIterator.getBitSet();
} else {
// Else collect all matching docs
FilteredDocIdSetIterator filtered =
new FilteredDocIdSetIterator(scorer.iterator()) {
@Override
protected boolean match(int doc) {
return liveDocs == null || liveDocs.get(doc);
}
};
acceptDocs = BitSet.of(filtered, leafReader.maxDoc());
}

int cardinality = acceptDocs.cardinality();
if (cardinality == 0) {
// If there are no live matching docs
return null;
}

// Perform an approximate search
TopDocs results = approximateSearch(context, acceptDocs, cardinality);

// If the limit was exhausted
if (results.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO) {
// Return a lazy-loading iterator
return VectorSimilarityScorer.fromAcceptDocs(
this,
boost,
createVectorScorer(context),
new BitSetIterator(acceptDocs, cardinality),
resultSimilarity);
} else {
// Return an iterator over the collected results
return VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs);
}
}

@Override
public boolean isCacheable(LeafReaderContext ctx) {
return true;
}
};
}

@Override
public void visit(QueryVisitor visitor) {
if (visitor.acceptField(field)) {
visitor.visitLeaf(this);
}
}

@Override
public boolean equals(Object o) {
return sameClassAs(o)
&& Objects.equals(field, ((AbstractVectorSimilarityQuery) o).field)
&& Float.compare(
((AbstractVectorSimilarityQuery) o).traversalSimilarity, traversalSimilarity)
== 0
&& Float.compare(((AbstractVectorSimilarityQuery) o).resultSimilarity, resultSimilarity)
== 0
&& Objects.equals(filter, ((AbstractVectorSimilarityQuery) o).filter);
}

@Override
public int hashCode() {
return Objects.hash(field, traversalSimilarity, resultSimilarity, filter);
}

private static class VectorSimilarityScorer extends Scorer {
final DocIdSetIterator iterator;
final float[] cachedScore;

VectorSimilarityScorer(Weight weight, DocIdSetIterator iterator, float[] cachedScore) {
super(weight);
this.iterator = iterator;
this.cachedScore = cachedScore;
}

static VectorSimilarityScorer fromScoreDocs(Weight weight, float boost, ScoreDoc[] scoreDocs) {
// Sort in ascending order of docid
Arrays.sort(scoreDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc));

float[] cachedScore = new float[1];
DocIdSetIterator iterator =
new DocIdSetIterator() {
int index = -1;

@Override
public int docID() {
if (index < 0) {
return -1;
} else if (index >= scoreDocs.length) {
return NO_MORE_DOCS;
} else {
cachedScore[0] = boost * scoreDocs[index].score;
return scoreDocs[index].doc;
}
}

@Override
public int nextDoc() {
index++;
return docID();
}

@Override
public int advance(int target) {
index =
Arrays.binarySearch(
scoreDocs,
new ScoreDoc(target, 0),
Comparator.comparingInt(scoreDoc -> scoreDoc.doc));
if (index < 0) {
index = -1 - index;
}
return docID();
}

@Override
public long cost() {
return scoreDocs.length;
}
};

return new VectorSimilarityScorer(weight, iterator, cachedScore);
}

static VectorSimilarityScorer fromAcceptDocs(
Weight weight,
float boost,
VectorScorer scorer,
DocIdSetIterator acceptDocs,
float threshold) {
float[] cachedScore = new float[1];
DocIdSetIterator iterator =
new FilteredDocIdSetIterator(acceptDocs) {
@Override
protected boolean match(int doc) throws IOException {
// Compute the dot product
float score = scorer.score();
cachedScore[0] = score * boost;
return score >= threshold;
}
};

return new VectorSimilarityScorer(weight, iterator, cachedScore);
}

@Override
public int docID() {
return iterator.docID();
}

@Override
public DocIdSetIterator iterator() {
return iterator;
}

@Override
public float getMaxScore(int upTo) {
return Float.POSITIVE_INFINITY;
}

@Override
public float score() {
return cachedScore[0];
}
}
}
Loading
Loading