Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,10 @@ protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator accept

private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
int len = topK.scoreDocs.length;

assert len > 0;
float maxScore = topK.scoreDocs[0].score;

Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
int[] docs = new int[len];
float[] scores = new float[len];
Expand All @@ -197,7 +201,7 @@ private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
scores[i] = topK.scoreDocs[i].score;
}
int[] segmentStarts = findSegmentStarts(reader, docs);
return new DocAndScoreQuery(docs, scores, segmentStarts, reader.getContext().id());
return new DocAndScoreQuery(docs, scores, maxScore, segmentStarts, reader.getContext().id());
}

private int[] findSegmentStarts(IndexReader reader, int[] docs) {
Expand Down Expand Up @@ -265,6 +269,7 @@ static class DocAndScoreQuery extends Query {

private final int[] docs;
private final float[] scores;
private final float maxScore;
private final int[] segmentStarts;
private final Object contextIdentity;

Expand All @@ -280,9 +285,11 @@ static class DocAndScoreQuery extends Query {
* @param contextIdentity an object identifying the reader context that was used to build this
* query
*/
DocAndScoreQuery(int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) {
DocAndScoreQuery(
int[] docs, float[] scores, float maxScore, int[] segmentStarts, Object contextIdentity) {
this.docs = docs;
this.scores = scores;
this.maxScore = maxScore;
this.segmentStarts = segmentStarts;
this.contextIdentity = contextIdentity;
}
Expand Down Expand Up @@ -343,11 +350,6 @@ public long cost() {

/**
 * Returns the largest value in {@code scores}, scaled by {@code boost}, over the hits that lie
 * between the current position and the given document.
 *
 * <p>The scan starts at index {@code max(0, upTo)} and continues while the index is below
 * {@code upper} and the stored doc ID is at most {@code docId + context.docBase}. If no hit
 * qualifies, the result is {@code 0 * boost}.
 *
 * @param docId doc ID to which {@code context.docBase} is added before it is compared with the
 *     entries of {@code docs}
 * @return the maximum matching score multiplied by {@code boost}
 */
@Override
public float getMaxScore(int docId) {
  // docs[] is compared against the shifted ID, so translate the argument once up front.
  final int shiftedDocId = docId + context.docBase;
  float best = 0;
  int idx = Math.max(0, upTo);
  while (idx < upper && docs[idx] <= shiftedDocId) {
    best = Math.max(best, scores[idx]);
    idx++;
  }
  return best * boost;
}

Expand All @@ -356,19 +358,6 @@ public float score() {
return scores[upTo] * boost;
}

/**
 * Finds the first stored doc ID at or after {@code docid + context.docBase}, searching
 * {@code docs} from index {@code max(upTo, lower)} (inclusive) to {@code upper} (exclusive).
 *
 * <p>NOTE(review): the value returned is the raw entry of {@code docs}, i.e. it is in the same
 * ID space as {@code docid + context.docBase} and is not shifted back by {@code docBase} —
 * confirm this is what callers expect when {@code docBase} is non-zero.
 *
 * @param docid doc ID to which {@code context.docBase} is added to form the search target
 * @return the matching or next greater entry of {@code docs}, or {@code NO_MORE_DOCS} when the
 *     search position reaches {@code upper}
 */
@Override
public int advanceShallow(int docid) {
  final int from = Math.max(upTo, lower);
  final int target = docid + context.docBase;
  int pos = Arrays.binarySearch(docs, from, upper, target);
  if (pos < 0) {
    // Target absent: binarySearch encodes the insertion point as -(insertionPoint) - 1.
    pos = -(pos + 1);
  }
  return pos >= upper ? NO_MORE_DOCS : docs[pos];
}

/**
* Move the implementation of docID() into a differently-named method so we can call it
* from DocIdSetIterator.docID() even though this class is anonymous.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,36 +244,6 @@ public void testDifferentReader() throws IOException {
}
}

/**
 * Indexes five documents with vectors (j, j) for j = 0..4, rewrites a k=3 vector query against
 * them, and checks what {@code advanceShallow} returns on the first leaf's scorer both before
 * and after the scorer's iterator has been advanced.
 */
public void testAdvanceShallow() throws IOException {
  try (Directory dir = newDirectory()) {
    try (IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig())) {
      int j = 0;
      while (j < 5) {
        Document document = new Document();
        document.add(getKnnVectorField("field", new float[] {j, j}));
        writer.addDocument(document);
        j++;
      }
    }
    try (IndexReader indexReader = DirectoryReader.open(dir)) {
      IndexSearcher indexSearcher = new IndexSearcher(indexReader);
      AbstractKnnVectorQuery knnQuery = getKnnVectorQuery("field", new float[] {2, 3}, 3);
      Query rewritten = knnQuery.rewrite(indexSearcher);
      Scorer leafScorer =
          rewritten
              .createWeight(indexSearcher, ScoreMode.COMPLETE, 1)
              .scorer(indexReader.leaves().get(0));

      // Iterator not yet positioned: the expected values below are 1, 1, then exhaustion.
      assertEquals(1, leafScorer.advanceShallow(0));
      assertEquals(1, leafScorer.advanceShallow(1));
      assertEquals(NO_MORE_DOCS, leafScorer.advanceShallow(10));

      // Iterator advanced to doc 2: targets below 2 are now expected to yield 2.
      leafScorer.iterator().advance(2);
      assertEquals(2, leafScorer.advanceShallow(0));
      assertEquals(2, leafScorer.advanceShallow(2));
      assertEquals(3, leafScorer.advanceShallow(3));
      assertEquals(NO_MORE_DOCS, leafScorer.advanceShallow(10));
    }
  }
}

public void testScoreEuclidean() throws IOException {
float[][] vectors = new float[5][];
for (int j = 0; j < 5; j++) {
Expand All @@ -291,9 +261,6 @@ public void testScoreEuclidean() throws IOException {
assertEquals(-1, scorer.docID());
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);

// test getMaxScore
assertEquals(0, scorer.getMaxScore(-1), 0);
assertEquals(0, scorer.getMaxScore(0), 0);
// This is 1 / ((l2distance((2,3), (2, 2)) = 1) + 1) = 0.5
assertEquals(1 / 2f, scorer.getMaxScore(2), 0);
assertEquals(1 / 2f, scorer.getMaxScore(Integer.MAX_VALUE), 0);
Expand All @@ -304,6 +271,7 @@ public void testScoreEuclidean() throws IOException {
assertEquals(1 / 6f, scorer.score(), 0);
assertEquals(3, it.advance(3));
assertEquals(1 / 2f, scorer.score(), 0);

assertEquals(NO_MORE_DOCS, it.advance(4));
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
}
Expand All @@ -330,32 +298,30 @@ public void testScoreCosine() throws IOException {
assertEquals(-1, scorer.docID());
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);

// test getMaxScore
assertEquals(0, scorer.getMaxScore(-1), 0);
/* maxAtZero = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then
/* score0 = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then
* normalized by (1 + x) / 2.
*/
float maxAtZero =
float score0 =
(float) ((1 + (2 * 1 + 3 * 1) / Math.sqrt((2 * 2 + 3 * 3) * (1 * 1 + 1 * 1))) / 2);
assertEquals(maxAtZero, scorer.getMaxScore(0), 0.001);

/* max at 2 is actually the score for doc 1 which is the highest (since doc 1 vector (2, 4)
* is the closest to (2, 3)). This is ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then
/* score1 = ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then
* normalized by (1 + x) / 2
*/
float expected =
float score1 =
(float) ((1 + (2 * 2 + 3 * 4) / Math.sqrt((2 * 2 + 3 * 3) * (2 * 2 + 4 * 4))) / 2);
assertEquals(expected, scorer.getMaxScore(2), 0);
assertEquals(expected, scorer.getMaxScore(Integer.MAX_VALUE), 0);

// doc 1 happens to have the maximum score
assertEquals(score1, scorer.getMaxScore(2), 0.0001);
assertEquals(score1, scorer.getMaxScore(Integer.MAX_VALUE), 0.0001);

DocIdSetIterator it = scorer.iterator();
assertEquals(3, it.cost());
assertEquals(0, it.nextDoc());
// doc 0 has (1, 1)
assertEquals(maxAtZero, scorer.score(), 0.0001);
assertEquals(score0, scorer.score(), 0.0001);
assertEquals(1, it.advance(1));
assertEquals(expected, scorer.score(), 0);
assertEquals(2, it.nextDoc());
assertEquals(score1, scorer.score(), 0.0001);

// since topK was 3
assertEquals(NO_MORE_DOCS, it.advance(4));
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,32 +133,30 @@ public void testScoreDotProduct() throws IOException {
assertEquals(-1, scorer.docID());
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);

// test getMaxScore
assertEquals(0, scorer.getMaxScore(-1), 0);
/* maxAtZero = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then
/* score0 = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then
* normalized by (1 + x) / 2.
*/
float maxAtZero =
float score0 =
(float) ((1 + (2 * 1 + 3 * 1) / Math.sqrt((2 * 2 + 3 * 3) * (1 * 1 + 1 * 1))) / 2);
assertEquals(maxAtZero, scorer.getMaxScore(0), 0.001);

/* max at 2 is actually the score for doc 1 which is the highest (since doc 1 vector (2, 4)
* is the closest to (2, 3)). This is ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then
/* score1 = ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then
* normalized by (1 + x) / 2
*/
float expected =
float score1 =
(float) ((1 + (2 * 2 + 3 * 4) / Math.sqrt((2 * 2 + 3 * 3) * (2 * 2 + 4 * 4))) / 2);
assertEquals(expected, scorer.getMaxScore(2), 0);
assertEquals(expected, scorer.getMaxScore(Integer.MAX_VALUE), 0);

// doc 1 happens to have the max score
assertEquals(score1, scorer.getMaxScore(2), 0.0001);
assertEquals(score1, scorer.getMaxScore(Integer.MAX_VALUE), 0.0001);

DocIdSetIterator it = scorer.iterator();
assertEquals(3, it.cost());
assertEquals(0, it.nextDoc());
// doc 0 has (1, 1)
assertEquals(maxAtZero, scorer.score(), 0.0001);
assertEquals(score0, scorer.score(), 0.0001);
assertEquals(1, it.advance(1));
assertEquals(expected, scorer.score(), 0);
assertEquals(2, it.nextDoc());
assertEquals(score1, scorer.score(), 0.0001);

// since topK was 3
assertEquals(NO_MORE_DOCS, it.advance(4));
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
Expand Down