Skip to content

Commit

Permalink
LUCENE-8952: Use a sort key instead of true distance in NearestNeighb…
Browse files Browse the repository at this point in the history
…or. (#832)
  • Loading branch information
jtibshirani authored and iverase committed Aug 23, 2019
1 parent e94a7b0 commit f3c7bbf
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 41 deletions.
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ Improvements

* SOLR-13663: Introduce <SpanPositionRange> into XML Query Parser (Alessandro Benedetti via Mikhail Khludnev)

* LUCENE-8952: Use a sort key instead of true distance in NearestNeighbor (Julie Tibshirani).

Optimizations

* LUCENE-8922: DisjunctionMaxQuery more efficiently leverages impacts to skip
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.SloppyMath;
import org.apache.lucene.util.bkd.BKDReader;

/**
Expand Down Expand Up @@ -104,7 +105,8 @@ public static TopFieldDocs nearest(IndexSearcher searcher, String field, double
ScoreDoc[] scoreDocs = new ScoreDoc[hits.length];
for(int i=0;i<hits.length;i++) {
NearestNeighbor.NearestHit hit = hits[i];
scoreDocs[i] = new FieldDoc(hit.docID, 0.0f, new Object[] {Double.valueOf(hit.distanceMeters)});
double hitDistance = SloppyMath.haversinMeters(hit.distanceSortKey);
scoreDocs[i] = new FieldDoc(hit.docID, 0.0f, new Object[] {Double.valueOf(hitDistance)});
}
return new TopFieldDocs(new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), scoreDocs, null);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.SloppyMath;
import org.apache.lucene.util.bkd.BKDReader;
import org.apache.lucene.util.bkd.BKDReader.IndexTree;
import org.apache.lucene.util.bkd.BKDReader.IntersectState;
import org.apache.lucene.util.bkd.BKDReader;

import static org.apache.lucene.geo.GeoEncodingUtils.decodeLatitude;
import static org.apache.lucene.geo.GeoEncodingUtils.decodeLongitude;
Expand All @@ -48,19 +48,23 @@ static class Cell implements Comparable<Cell> {
final byte[] maxPacked;
final IndexTree index;

/** The closest possible distance of all points in this cell */
final double distanceMeters;
/**
* The closest distance from a point in this cell to the query point, computed as a sort key through
* {@link SloppyMath#haversinSortKey}. Note that this is an approximation to the closest distance,
* and there could be a point in the cell that is closer.
*/
final double distanceSortKey;

public Cell(IndexTree index, int readerIndex, byte[] minPacked, byte[] maxPacked, double distanceMeters) {
public Cell(IndexTree index, int readerIndex, byte[] minPacked, byte[] maxPacked, double distanceSortKey) {
this.index = index;
this.readerIndex = readerIndex;
this.minPacked = minPacked.clone();
this.maxPacked = maxPacked.clone();
this.distanceMeters = distanceMeters;
this.distanceSortKey = distanceSortKey;
}

public int compareTo(Cell other) {
return Double.compare(distanceMeters, other.distanceMeters);
return Double.compare(distanceSortKey, other.distanceSortKey);
}

@Override
Expand All @@ -69,7 +73,7 @@ public String toString() {
double minLon = decodeLongitude(minPacked, Integer.BYTES);
double maxLat = decodeLatitude(maxPacked, 0);
double maxLon = decodeLongitude(maxPacked, Integer.BYTES);
return "Cell(readerIndex=" + readerIndex + " nodeID=" + index.getNodeID() + " isLeaf=" + index.isLeafNode() + " lat=" + minLat + " TO " + maxLat + ", lon=" + minLon + " TO " + maxLon + "; distanceMeters=" + distanceMeters + ")";
return "Cell(readerIndex=" + readerIndex + " nodeID=" + index.getNodeID() + " isLeaf=" + index.isLeafNode() + " lat=" + minLat + " TO " + maxLat + ", lon=" + minLon + " TO " + maxLon + "; distanceSortKey=" + distanceSortKey + ")";
}
}

Expand Down Expand Up @@ -106,7 +110,8 @@ public void visit(int docID) {
private void maybeUpdateBBox() {
if (setBottomCounter < 1024 || (setBottomCounter & 0x3F) == 0x3F) {
NearestHit hit = hitQueue.peek();
Rectangle box = Rectangle.fromPointDistance(pointLat, pointLon, hit.distanceMeters);
Rectangle box = Rectangle.fromPointDistance(pointLat, pointLon,
SloppyMath.haversinMeters(hit.distanceSortKey));
//System.out.println(" update bbox to " + box);
minLat = box.minLat;
maxLat = box.maxLat;
Expand Down Expand Up @@ -134,8 +139,6 @@ public void visit(int docID, byte[] packedValue) {
return;
}

// TODO: work in int space, use haversinSortKey

double docLatitude = decodeLatitude(packedValue, 0);
double docLongitude = decodeLongitude(packedValue, Integer.BYTES);

Expand All @@ -147,21 +150,22 @@ public void visit(int docID, byte[] packedValue) {
return;
}

double distanceMeters = SloppyMath.haversinMeters(pointLat, pointLon, docLatitude, docLongitude);
// Use the haversin sort key when comparing hits, as it is faster to compute than the true distance.
double distanceSortKey = SloppyMath.haversinSortKey(pointLat, pointLon, docLatitude, docLongitude);

//System.out.println(" visit docID=" + docID + " distanceMeters=" + distanceMeters + " docLat=" + docLatitude + " docLon=" + docLongitude);
//System.out.println(" visit docID=" + docID + " distanceSortKey=" + distanceSortKey + " docLat=" + docLatitude + " docLon=" + docLongitude);

int fullDocID = curDocBase + docID;

if (hitQueue.size() == topN) {
// queue already full
NearestHit hit = hitQueue.peek();
//System.out.println(" bottom distanceMeters=" + hit.distanceMeters);
//System.out.println(" bottom distanceSortKey=" + hit.distanceSortKey);
// we don't collect docs in order here, so we must also test the tie-break case ourselves:
if (distanceMeters < hit.distanceMeters || (distanceMeters == hit.distanceMeters && fullDocID < hit.docID)) {
if (distanceSortKey < hit.distanceSortKey || (distanceSortKey == hit.distanceSortKey && fullDocID < hit.docID)) {
hitQueue.poll();
hit.docID = fullDocID;
hit.distanceMeters = distanceMeters;
hit.distanceSortKey = distanceSortKey;
hitQueue.offer(hit);
//System.out.println(" ** keep2, now bottom=" + hit);
maybeUpdateBBox();
Expand All @@ -170,7 +174,7 @@ public void visit(int docID, byte[] packedValue) {
} else {
NearestHit hit = new NearestHit();
hit.docID = fullDocID;
hit.distanceMeters = distanceMeters;
hit.distanceSortKey = distanceSortKey;
hitQueue.offer(hit);
//System.out.println(" ** keep1, now bottom=" + hit);
}
Expand All @@ -182,14 +186,18 @@ public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
}
}

/** Holds one hit from {@link LatLonPointPrototypeQueries#nearest} */
/** Holds one hit from {@link NearestNeighbor#nearest} */
static class NearestHit {
public int docID;
public double distanceMeters;

/**
* The distance from the hit to the query point, computed as a sort key through {@link SloppyMath#haversinSortKey}.
*/
public double distanceSortKey;

@Override
public String toString() {
return "NearestHit(docID=" + docID + " distanceMeters=" + distanceMeters + ")";
return "NearestHit(docID=" + docID + " distanceSortKey=" + distanceSortKey + ")";
}
}

Expand All @@ -204,8 +212,8 @@ public static NearestHit[] nearest(double pointLat, double pointLon, List<BKDRea
final PriorityQueue<NearestHit> hitQueue = new PriorityQueue<>(n, new Comparator<NearestHit>() {
@Override
public int compare(NearestHit a, NearestHit b) {
// sort by opposite distanceMeters natural order
int cmp = Double.compare(a.distanceMeters, b.distanceMeters);
// sort by opposite distanceSortKey natural order
int cmp = Double.compare(a.distanceSortKey, b.distanceSortKey);
if (cmp != 0) {
return -cmp;
}
Expand Down Expand Up @@ -319,10 +327,10 @@ private static double approxBestDistance(double minLat, double maxLat, double mi
return 0.0;
}

double d1 = SloppyMath.haversinMeters(pointLat, pointLon, minLat, minLon);
double d2 = SloppyMath.haversinMeters(pointLat, pointLon, minLat, maxLon);
double d3 = SloppyMath.haversinMeters(pointLat, pointLon, maxLat, maxLon);
double d4 = SloppyMath.haversinMeters(pointLat, pointLon, maxLat, minLon);
double d1 = SloppyMath.haversinSortKey(pointLat, pointLon, minLat, minLon);
double d2 = SloppyMath.haversinSortKey(pointLat, pointLon, minLat, maxLon);
double d3 = SloppyMath.haversinSortKey(pointLat, pointLon, maxLat, maxLon);
double d4 = SloppyMath.haversinSortKey(pointLat, pointLon, maxLat, minLon);
return Math.min(Math.min(d1, d2), Math.min(d3, d4));
}

Expand Down
29 changes: 14 additions & 15 deletions lucene/sandbox/src/test/org/apache/lucene/search/TestNearest.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.apache.lucene.document.LatLonPoint;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.search.NearestNeighbor.NearestHit;
import org.apache.lucene.geo.GeoEncodingUtils;
import org.apache.lucene.geo.GeoTestUtil;
import org.apache.lucene.index.DirectoryReader;
Expand Down Expand Up @@ -190,23 +189,22 @@ public void testNearestNeighborRandom() throws Exception {
double pointLon = GeoTestUtil.nextLongitude();

// dumb brute force search to get the expected result:
NearestHit[] expectedHits = new NearestHit[lats.length];
FieldDoc[] expectedHits = new FieldDoc[lats.length];
for(int id=0;id<lats.length;id++) {
NearestHit hit = new NearestHit();
hit.distanceMeters = SloppyMath.haversinMeters(pointLat, pointLon, lats[id], lons[id]);
hit.docID = id;
double distance = SloppyMath.haversinMeters(pointLat, pointLon, lats[id], lons[id]);
FieldDoc hit = new FieldDoc(id, 0.0f, new Object[] {Double.valueOf(distance)});
expectedHits[id] = hit;
}

Arrays.sort(expectedHits, new Comparator<NearestHit>() {
Arrays.sort(expectedHits, new Comparator<FieldDoc>() {
@Override
public int compare(NearestHit a, NearestHit b) {
int cmp = Double.compare(a.distanceMeters, b.distanceMeters);
public int compare(FieldDoc a, FieldDoc b) {
int cmp = Double.compare(((Double) a.fields[0]).doubleValue(), ((Double) b.fields[0]).doubleValue());
if (cmp != 0) {
return cmp;
}
// tie break by smaller docID:
return a.docID - b.docID;
return a.doc - b.doc;
}
});

Expand All @@ -221,22 +219,23 @@ public int compare(NearestHit a, NearestHit b) {

ScoreDoc[] hits = LatLonPointPrototypeQueries.nearest(s, "point", pointLat, pointLon, topN).scoreDocs;
for(int i=0;i<topN;i++) {
NearestHit expected = expectedHits[i];
FieldDoc expected = expectedHits[i];
FieldDoc expected2 = (FieldDoc) fieldDocs.scoreDocs[i];
FieldDoc actual = (FieldDoc) hits[i];
Document actualDoc = r.document(actual.doc);

if (VERBOSE) {
System.out.println("hit " + i);
System.out.println(" expected id=" + expected.docID + " lat=" + lats[expected.docID] + " lon=" + lons[expected.docID] + " distance=" + expected.distanceMeters + " meters");
System.out.println(" expected id=" + expected.doc+ " lat=" + lats[expected.doc] + " lon=" + lons[expected.doc]
+ " distance=" + ((Double) expected.fields[0]).doubleValue() + " meters");
System.out.println(" actual id=" + actualDoc.getField("id") + " distance=" + actual.fields[0] + " meters");
}

assertEquals(expected.docID, actual.doc);
assertEquals(expected.distanceMeters, ((Double) actual.fields[0]).doubleValue(), 0.0);
assertEquals(expected.doc, actual.doc);
assertEquals(((Double) expected.fields[0]).doubleValue(), ((Double) actual.fields[0]).doubleValue(), 0.0);

assertEquals(expected.docID, expected.docID);
assertEquals(((Double) expected2.fields[0]).doubleValue(), expected.distanceMeters, 0.0);
assertEquals(expected2.doc, actual.doc);
assertEquals(((Double) expected2.fields[0]).doubleValue(), ((Double) actual.fields[0]).doubleValue(), 0.0);
}
}

Expand Down

0 comments on commit f3c7bbf

Please sign in to comment.