Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LUCENE-8952: Use a sort key instead of true distance in NearestNeighbor. #832

Merged
merged 3 commits into from
Aug 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ Improvements

* SOLR-13663: Introduce <SpanPositionRange> into XML Query Parser (Alessandro Benedetti via Mikhail Khludnev)

* LUCENE-8952: Use a sort key instead of true distance in NearestNeighbor (Julie Tibshirani).

Optimizations

* LUCENE-8922: DisjunctionMaxQuery more efficiently leverages impacts to skip
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.SloppyMath;
import org.apache.lucene.util.bkd.BKDReader;

/**
Expand Down Expand Up @@ -104,7 +105,8 @@ public static TopFieldDocs nearest(IndexSearcher searcher, String field, double
ScoreDoc[] scoreDocs = new ScoreDoc[hits.length];
for(int i=0;i<hits.length;i++) {
NearestNeighbor.NearestHit hit = hits[i];
scoreDocs[i] = new FieldDoc(hit.docID, 0.0f, new Object[] {Double.valueOf(hit.distanceMeters)});
double hitDistance = SloppyMath.haversinMeters(hit.distanceSortKey);
scoreDocs[i] = new FieldDoc(hit.docID, 0.0f, new Object[] {Double.valueOf(hitDistance)});
}
return new TopFieldDocs(new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), scoreDocs, null);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.SloppyMath;
import org.apache.lucene.util.bkd.BKDReader;
import org.apache.lucene.util.bkd.BKDReader.IndexTree;
import org.apache.lucene.util.bkd.BKDReader.IntersectState;
import org.apache.lucene.util.bkd.BKDReader;

import static org.apache.lucene.geo.GeoEncodingUtils.decodeLatitude;
import static org.apache.lucene.geo.GeoEncodingUtils.decodeLongitude;
Expand All @@ -48,19 +48,23 @@ static class Cell implements Comparable<Cell> {
final byte[] maxPacked;
final IndexTree index;

/** The closest possible distance of all points in this cell */
final double distanceMeters;
/**
* The closest distance from a point in this cell to the query point, computed as a sort key through
* {@link SloppyMath#haversinSortKey}. Note that this is an approximation to the closest distance,
* and there could be a point in the cell that is closer.
*/
final double distanceSortKey;

public Cell(IndexTree index, int readerIndex, byte[] minPacked, byte[] maxPacked, double distanceMeters) {
public Cell(IndexTree index, int readerIndex, byte[] minPacked, byte[] maxPacked, double distanceSortKey) {
this.index = index;
this.readerIndex = readerIndex;
this.minPacked = minPacked.clone();
this.maxPacked = maxPacked.clone();
this.distanceMeters = distanceMeters;
this.distanceSortKey = distanceSortKey;
}

public int compareTo(Cell other) {
return Double.compare(distanceMeters, other.distanceMeters);
return Double.compare(distanceSortKey, other.distanceSortKey);
}

@Override
Expand All @@ -69,7 +73,7 @@ public String toString() {
double minLon = decodeLongitude(minPacked, Integer.BYTES);
double maxLat = decodeLatitude(maxPacked, 0);
double maxLon = decodeLongitude(maxPacked, Integer.BYTES);
return "Cell(readerIndex=" + readerIndex + " nodeID=" + index.getNodeID() + " isLeaf=" + index.isLeafNode() + " lat=" + minLat + " TO " + maxLat + ", lon=" + minLon + " TO " + maxLon + "; distanceMeters=" + distanceMeters + ")";
return "Cell(readerIndex=" + readerIndex + " nodeID=" + index.getNodeID() + " isLeaf=" + index.isLeafNode() + " lat=" + minLat + " TO " + maxLat + ", lon=" + minLon + " TO " + maxLon + "; distanceSortKey=" + distanceSortKey + ")";
}
}

Expand Down Expand Up @@ -106,7 +110,8 @@ public void visit(int docID) {
private void maybeUpdateBBox() {
if (setBottomCounter < 1024 || (setBottomCounter & 0x3F) == 0x3F) {
NearestHit hit = hitQueue.peek();
Rectangle box = Rectangle.fromPointDistance(pointLat, pointLon, hit.distanceMeters);
Rectangle box = Rectangle.fromPointDistance(pointLat, pointLon,
SloppyMath.haversinMeters(hit.distanceSortKey));
//System.out.println(" update bbox to " + box);
minLat = box.minLat;
maxLat = box.maxLat;
Expand Down Expand Up @@ -134,8 +139,6 @@ public void visit(int docID, byte[] packedValue) {
return;
}

// TODO: work in int space, use haversinSortKey

double docLatitude = decodeLatitude(packedValue, 0);
double docLongitude = decodeLongitude(packedValue, Integer.BYTES);

Expand All @@ -147,21 +150,22 @@ public void visit(int docID, byte[] packedValue) {
return;
}

double distanceMeters = SloppyMath.haversinMeters(pointLat, pointLon, docLatitude, docLongitude);
// Use the haversin sort key when comparing hits, as it is faster to compute than the true distance.
double distanceSortKey = SloppyMath.haversinSortKey(pointLat, pointLon, docLatitude, docLongitude);

//System.out.println(" visit docID=" + docID + " distanceMeters=" + distanceMeters + " docLat=" + docLatitude + " docLon=" + docLongitude);
//System.out.println(" visit docID=" + docID + " distanceSortKey=" + distanceSortKey + " docLat=" + docLatitude + " docLon=" + docLongitude);

int fullDocID = curDocBase + docID;

if (hitQueue.size() == topN) {
// queue already full
NearestHit hit = hitQueue.peek();
//System.out.println(" bottom distanceMeters=" + hit.distanceMeters);
//System.out.println(" bottom distanceSortKey=" + hit.distanceSortKey);
// we don't collect docs in order here, so we must also test the tie-break case ourselves:
if (distanceMeters < hit.distanceMeters || (distanceMeters == hit.distanceMeters && fullDocID < hit.docID)) {
if (distanceSortKey < hit.distanceSortKey || (distanceSortKey == hit.distanceSortKey && fullDocID < hit.docID)) {
hitQueue.poll();
hit.docID = fullDocID;
hit.distanceMeters = distanceMeters;
hit.distanceSortKey = distanceSortKey;
hitQueue.offer(hit);
//System.out.println(" ** keep2, now bottom=" + hit);
maybeUpdateBBox();
Expand All @@ -170,7 +174,7 @@ public void visit(int docID, byte[] packedValue) {
} else {
NearestHit hit = new NearestHit();
hit.docID = fullDocID;
hit.distanceMeters = distanceMeters;
hit.distanceSortKey = distanceSortKey;
hitQueue.offer(hit);
//System.out.println(" ** keep1, now bottom=" + hit);
}
Expand All @@ -182,14 +186,18 @@ public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
}
}

/** Holds one hit from {@link LatLonPointPrototypeQueries#nearest} */
/** Holds one hit from {@link NearestNeighbor#nearest} */
static class NearestHit {
public int docID;
public double distanceMeters;

/**
* The distance from the hit to the query point, computed as a sort key through {@link SloppyMath#haversinSortKey}.
*/
public double distanceSortKey;

@Override
public String toString() {
return "NearestHit(docID=" + docID + " distanceMeters=" + distanceMeters + ")";
return "NearestHit(docID=" + docID + " distanceSortKey=" + distanceSortKey + ")";
}
}

Expand All @@ -204,8 +212,8 @@ public static NearestHit[] nearest(double pointLat, double pointLon, List<BKDRea
final PriorityQueue<NearestHit> hitQueue = new PriorityQueue<>(n, new Comparator<NearestHit>() {
@Override
public int compare(NearestHit a, NearestHit b) {
// sort by opposite distanceMeters natural order
int cmp = Double.compare(a.distanceMeters, b.distanceMeters);
// sort by opposite distanceSortKey natural order
int cmp = Double.compare(a.distanceSortKey, b.distanceSortKey);
if (cmp != 0) {
return -cmp;
}
Expand Down Expand Up @@ -319,10 +327,10 @@ private static double approxBestDistance(double minLat, double maxLat, double mi
return 0.0;
}

double d1 = SloppyMath.haversinMeters(pointLat, pointLon, minLat, minLon);
double d2 = SloppyMath.haversinMeters(pointLat, pointLon, minLat, maxLon);
double d3 = SloppyMath.haversinMeters(pointLat, pointLon, maxLat, maxLon);
double d4 = SloppyMath.haversinMeters(pointLat, pointLon, maxLat, minLon);
double d1 = SloppyMath.haversinSortKey(pointLat, pointLon, minLat, minLon);
double d2 = SloppyMath.haversinSortKey(pointLat, pointLon, minLat, maxLon);
double d3 = SloppyMath.haversinSortKey(pointLat, pointLon, maxLat, maxLon);
double d4 = SloppyMath.haversinSortKey(pointLat, pointLon, maxLat, minLon);
return Math.min(Math.min(d1, d2), Math.min(d3, d4));
}

Expand Down
29 changes: 14 additions & 15 deletions lucene/sandbox/src/test/org/apache/lucene/search/TestNearest.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.apache.lucene.document.LatLonPoint;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.search.NearestNeighbor.NearestHit;
import org.apache.lucene.geo.GeoEncodingUtils;
import org.apache.lucene.geo.GeoTestUtil;
import org.apache.lucene.index.DirectoryReader;
Expand Down Expand Up @@ -190,23 +189,22 @@ public void testNearestNeighborRandom() throws Exception {
double pointLon = GeoTestUtil.nextLongitude();

// dumb brute force search to get the expected result:
NearestHit[] expectedHits = new NearestHit[lats.length];
FieldDoc[] expectedHits = new FieldDoc[lats.length];
for(int id=0;id<lats.length;id++) {
NearestHit hit = new NearestHit();
hit.distanceMeters = SloppyMath.haversinMeters(pointLat, pointLon, lats[id], lons[id]);
hit.docID = id;
double distance = SloppyMath.haversinMeters(pointLat, pointLon, lats[id], lons[id]);
FieldDoc hit = new FieldDoc(id, 0.0f, new Object[] {Double.valueOf(distance)});
expectedHits[id] = hit;
}

Arrays.sort(expectedHits, new Comparator<NearestHit>() {
Arrays.sort(expectedHits, new Comparator<FieldDoc>() {
@Override
public int compare(NearestHit a, NearestHit b) {
int cmp = Double.compare(a.distanceMeters, b.distanceMeters);
public int compare(FieldDoc a, FieldDoc b) {
int cmp = Double.compare(((Double) a.fields[0]).doubleValue(), ((Double) b.fields[0]).doubleValue());
if (cmp != 0) {
return cmp;
}
// tie break by smaller docID:
return a.docID - b.docID;
return a.doc - b.doc;
}
});

Expand All @@ -221,22 +219,23 @@ public int compare(NearestHit a, NearestHit b) {

ScoreDoc[] hits = LatLonPointPrototypeQueries.nearest(s, "point", pointLat, pointLon, topN).scoreDocs;
for(int i=0;i<topN;i++) {
NearestHit expected = expectedHits[i];
FieldDoc expected = expectedHits[i];
FieldDoc expected2 = (FieldDoc) fieldDocs.scoreDocs[i];
FieldDoc actual = (FieldDoc) hits[i];
Document actualDoc = r.document(actual.doc);

if (VERBOSE) {
System.out.println("hit " + i);
System.out.println(" expected id=" + expected.docID + " lat=" + lats[expected.docID] + " lon=" + lons[expected.docID] + " distance=" + expected.distanceMeters + " meters");
System.out.println(" expected id=" + expected.doc+ " lat=" + lats[expected.doc] + " lon=" + lons[expected.doc]
+ " distance=" + ((Double) expected.fields[0]).doubleValue() + " meters");
System.out.println(" actual id=" + actualDoc.getField("id") + " distance=" + actual.fields[0] + " meters");
}

assertEquals(expected.docID, actual.doc);
assertEquals(expected.distanceMeters, ((Double) actual.fields[0]).doubleValue(), 0.0);
assertEquals(expected.doc, actual.doc);
assertEquals(((Double) expected.fields[0]).doubleValue(), ((Double) actual.fields[0]).doubleValue(), 0.0);

assertEquals(expected.docID, expected.docID);
assertEquals(((Double) expected2.fields[0]).doubleValue(), expected.distanceMeters, 0.0);
assertEquals(expected2.doc, actual.doc);
assertEquals(((Double) expected2.fields[0]).doubleValue(), ((Double) actual.fields[0]).doubleValue(), 0.0);
}
}

Expand Down