-
Notifications
You must be signed in to change notification settings - Fork 985
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for similarity-based vector searches #12679
Merged
Merged
Changes from 9 commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
a384967
Add support for radius-based vector searches
5fc970e
Address comments
5a6f555
Change naming from radius-based search to similarity-based search
cad5654
Refactor into a more appropriate query
5096790
Simplify the query
a40c388
Add tests
8aadd92
Merge branch 'main' into radius-based-vector-search
0c24d79
Minor fixes
f62fb74
Add a CHANGES.txt entry
07f4223
Update CHANGES.txt entry to include query and collector names
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
288 changes: 288 additions & 0 deletions
288
lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,288 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.apache.lucene.search; | ||
|
||
import java.io.IOException; | ||
import java.util.Arrays; | ||
import java.util.Comparator; | ||
import java.util.Objects; | ||
import org.apache.lucene.index.LeafReader; | ||
import org.apache.lucene.index.LeafReaderContext; | ||
import org.apache.lucene.util.BitSet; | ||
import org.apache.lucene.util.BitSetIterator; | ||
import org.apache.lucene.util.Bits; | ||
|
||
/** | ||
* Search for all (approximate) vectors above a similarity threshold. | ||
* | ||
* @lucene.experimental | ||
*/ | ||
abstract class AbstractVectorSimilarityQuery extends Query { | ||
protected final String field; | ||
protected final float traversalSimilarity, resultSimilarity; | ||
protected final Query filter; | ||
|
||
/** | ||
* Search for all (approximate) vectors above a similarity threshold using {@link | ||
* VectorSimilarityCollector}. If a filter is applied, it traverses as many nodes as the cost of | ||
* the filter, and then falls back to exact search if results are incomplete. | ||
* | ||
* @param field a field that has been indexed as a vector field. | ||
* @param traversalSimilarity (lower) similarity score for graph traversal. | ||
* @param resultSimilarity (higher) similarity score for result collection. | ||
* @param filter a filter applied before the vector search. | ||
*/ | ||
AbstractVectorSimilarityQuery( | ||
String field, float traversalSimilarity, float resultSimilarity, Query filter) { | ||
if (traversalSimilarity > resultSimilarity) { | ||
throw new IllegalArgumentException("traversalSimilarity should be <= resultSimilarity"); | ||
} | ||
this.field = Objects.requireNonNull(field, "field"); | ||
this.traversalSimilarity = traversalSimilarity; | ||
this.resultSimilarity = resultSimilarity; | ||
this.filter = filter; | ||
} | ||
|
||
abstract VectorScorer createVectorScorer(LeafReaderContext context) throws IOException; | ||
|
||
protected abstract TopDocs approximateSearch( | ||
LeafReaderContext context, Bits acceptDocs, int visitLimit) throws IOException; | ||
|
||
@Override | ||
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) | ||
throws IOException { | ||
return new Weight(this) { | ||
final Weight filterWeight = | ||
filter == null | ||
? null | ||
: searcher.createWeight(searcher.rewrite(filter), ScoreMode.COMPLETE_NO_SCORES, 1); | ||
|
||
@Override | ||
public Explanation explain(LeafReaderContext context, int doc) throws IOException { | ||
if (filterWeight != null) { | ||
Scorer filterScorer = filterWeight.scorer(context); | ||
if (filterScorer == null || filterScorer.iterator().advance(doc) > doc) { | ||
return Explanation.noMatch("Doc does not match the filter"); | ||
} | ||
} | ||
|
||
VectorScorer scorer = createVectorScorer(context); | ||
if (scorer == null) { | ||
return Explanation.noMatch("Not indexed as the correct vector field"); | ||
} else if (scorer.advanceExact(doc)) { | ||
float score = scorer.score(); | ||
if (score >= resultSimilarity) { | ||
return Explanation.match(boost * score, "Score above threshold"); | ||
} else { | ||
return Explanation.noMatch("Score below threshold"); | ||
} | ||
} else { | ||
return Explanation.noMatch("No vector found for doc"); | ||
} | ||
} | ||
|
||
@Override | ||
public Scorer scorer(LeafReaderContext context) throws IOException { | ||
@SuppressWarnings("resource") | ||
LeafReader leafReader = context.reader(); | ||
Bits liveDocs = leafReader.getLiveDocs(); | ||
|
||
// If there is no filter | ||
if (filterWeight == null) { | ||
// Return exhaustive results | ||
TopDocs results = approximateSearch(context, liveDocs, Integer.MAX_VALUE); | ||
return VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs); | ||
} | ||
|
||
Scorer scorer = filterWeight.scorer(context); | ||
if (scorer == null) { | ||
// If the filter does not match any documents | ||
return null; | ||
} | ||
|
||
BitSet acceptDocs; | ||
if (liveDocs == null && scorer.iterator() instanceof BitSetIterator bitSetIterator) { | ||
// If there are no deletions, and matching docs are already cached | ||
acceptDocs = bitSetIterator.getBitSet(); | ||
} else { | ||
// Else collect all matching docs | ||
FilteredDocIdSetIterator filtered = | ||
new FilteredDocIdSetIterator(scorer.iterator()) { | ||
@Override | ||
protected boolean match(int doc) { | ||
return liveDocs == null || liveDocs.get(doc); | ||
} | ||
}; | ||
acceptDocs = BitSet.of(filtered, leafReader.maxDoc()); | ||
} | ||
|
||
int cardinality = acceptDocs.cardinality(); | ||
if (cardinality == 0) { | ||
// If there are no live matching docs | ||
return null; | ||
} | ||
|
||
// Perform an approximate search | ||
TopDocs results = approximateSearch(context, acceptDocs, cardinality); | ||
|
||
// If the limit was exhausted | ||
if (results.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO) { | ||
// Return a lazy-loading iterator | ||
return VectorSimilarityScorer.fromAcceptDocs( | ||
this, | ||
boost, | ||
createVectorScorer(context), | ||
new BitSetIterator(acceptDocs, cardinality), | ||
resultSimilarity); | ||
} else { | ||
// Return an iterator over the collected results | ||
return VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs); | ||
} | ||
} | ||
|
||
@Override | ||
public boolean isCacheable(LeafReaderContext ctx) { | ||
return true; | ||
} | ||
}; | ||
} | ||
|
||
@Override | ||
public void visit(QueryVisitor visitor) { | ||
if (visitor.acceptField(field)) { | ||
visitor.visitLeaf(this); | ||
} | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
return sameClassAs(o) | ||
&& Objects.equals(field, ((AbstractVectorSimilarityQuery) o).field) | ||
&& Float.compare( | ||
((AbstractVectorSimilarityQuery) o).traversalSimilarity, traversalSimilarity) | ||
== 0 | ||
&& Float.compare(((AbstractVectorSimilarityQuery) o).resultSimilarity, resultSimilarity) | ||
== 0 | ||
&& Objects.equals(filter, ((AbstractVectorSimilarityQuery) o).filter); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(field, traversalSimilarity, resultSimilarity, filter); | ||
} | ||
|
||
private static class VectorSimilarityScorer extends Scorer { | ||
final DocIdSetIterator iterator; | ||
final float[] cachedScore; | ||
|
||
VectorSimilarityScorer(Weight weight, DocIdSetIterator iterator, float[] cachedScore) { | ||
super(weight); | ||
this.iterator = iterator; | ||
this.cachedScore = cachedScore; | ||
} | ||
|
||
static VectorSimilarityScorer fromScoreDocs(Weight weight, float boost, ScoreDoc[] scoreDocs) { | ||
// Sort in ascending order of docid | ||
Arrays.sort(scoreDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc)); | ||
|
||
float[] cachedScore = new float[1]; | ||
DocIdSetIterator iterator = | ||
new DocIdSetIterator() { | ||
int index = -1; | ||
|
||
@Override | ||
public int docID() { | ||
if (index < 0) { | ||
return -1; | ||
} else if (index >= scoreDocs.length) { | ||
return NO_MORE_DOCS; | ||
} else { | ||
cachedScore[0] = boost * scoreDocs[index].score; | ||
return scoreDocs[index].doc; | ||
} | ||
} | ||
|
||
@Override | ||
public int nextDoc() { | ||
index++; | ||
return docID(); | ||
} | ||
|
||
@Override | ||
public int advance(int target) { | ||
index = | ||
Arrays.binarySearch( | ||
scoreDocs, | ||
new ScoreDoc(target, 0), | ||
Comparator.comparingInt(scoreDoc -> scoreDoc.doc)); | ||
if (index < 0) { | ||
index = -1 - index; | ||
} | ||
return docID(); | ||
} | ||
|
||
@Override | ||
public long cost() { | ||
return scoreDocs.length; | ||
} | ||
}; | ||
|
||
return new VectorSimilarityScorer(weight, iterator, cachedScore); | ||
} | ||
|
||
static VectorSimilarityScorer fromAcceptDocs( | ||
Weight weight, | ||
float boost, | ||
VectorScorer scorer, | ||
DocIdSetIterator acceptDocs, | ||
float threshold) { | ||
float[] cachedScore = new float[1]; | ||
DocIdSetIterator iterator = | ||
new FilteredDocIdSetIterator(acceptDocs) { | ||
@Override | ||
protected boolean match(int doc) throws IOException { | ||
// Compute the dot product | ||
float score = scorer.score(); | ||
cachedScore[0] = score * boost; | ||
return score >= threshold; | ||
} | ||
}; | ||
|
||
return new VectorSimilarityScorer(weight, iterator, cachedScore); | ||
} | ||
|
||
@Override | ||
public int docID() { | ||
return iterator.docID(); | ||
} | ||
|
||
@Override | ||
public DocIdSetIterator iterator() { | ||
return iterator; | ||
} | ||
|
||
@Override | ||
public float getMaxScore(int upTo) { | ||
return Float.POSITIVE_INFINITY; | ||
} | ||
|
||
@Override | ||
public float score() { | ||
return cachedScore[0]; | ||
} | ||
} | ||
} |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe add the vector query names?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, didn't get what you mean here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well, what are the query names? :D
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ahh got it.. Updated now :)