diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index befc14a5dc73..33896818f7df 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -95,6 +95,8 @@ Improvements * SOLR-13663: Introduce <SpanPositionRange> into XML Query Parser (Alessandro Benedetti via Mikhail Khludnev) +* LUCENE-8952: Use a sort key instead of true distance in NearestNeighbor (Julie Tibshirani). + Optimizations * LUCENE-8922: DisjunctionMaxQuery more efficiently leverages impacts to skip diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/LatLonPointPrototypeQueries.java b/lucene/sandbox/src/java/org/apache/lucene/search/LatLonPointPrototypeQueries.java index 73cbbb87fe47..3c0d7ff3164f 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/search/LatLonPointPrototypeQueries.java +++ b/lucene/sandbox/src/java/org/apache/lucene/search/LatLonPointPrototypeQueries.java @@ -27,6 +27,7 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.PointValues; import org.apache.lucene.util.Bits; +import org.apache.lucene.util.SloppyMath; import org.apache.lucene.util.bkd.BKDReader; /** @@ -104,7 +105,8 @@ public static TopFieldDocs nearest(IndexSearcher searcher, String field, double ScoreDoc[] scoreDocs = new ScoreDoc[hits.length]; for(int i=0;i { final byte[] maxPacked; final IndexTree index; - /** The closest possible distance of all points in this cell */ - final double distanceMeters; + /** + * The closest distance from a point in this cell to the query point, computed as a sort key through + * {@link SloppyMath#haversinSortKey}. Note that this is an approximation to the closest distance, + * and there could be a point in the cell that is closer. 
+ */ + final double distanceSortKey; - public Cell(IndexTree index, int readerIndex, byte[] minPacked, byte[] maxPacked, double distanceMeters) { + public Cell(IndexTree index, int readerIndex, byte[] minPacked, byte[] maxPacked, double distanceSortKey) { this.index = index; this.readerIndex = readerIndex; this.minPacked = minPacked.clone(); this.maxPacked = maxPacked.clone(); - this.distanceMeters = distanceMeters; + this.distanceSortKey = distanceSortKey; } public int compareTo(Cell other) { - return Double.compare(distanceMeters, other.distanceMeters); + return Double.compare(distanceSortKey, other.distanceSortKey); } @Override @@ -69,7 +73,7 @@ public String toString() { double minLon = decodeLongitude(minPacked, Integer.BYTES); double maxLat = decodeLatitude(maxPacked, 0); double maxLon = decodeLongitude(maxPacked, Integer.BYTES); - return "Cell(readerIndex=" + readerIndex + " nodeID=" + index.getNodeID() + " isLeaf=" + index.isLeafNode() + " lat=" + minLat + " TO " + maxLat + ", lon=" + minLon + " TO " + maxLon + "; distanceMeters=" + distanceMeters + ")"; + return "Cell(readerIndex=" + readerIndex + " nodeID=" + index.getNodeID() + " isLeaf=" + index.isLeafNode() + " lat=" + minLat + " TO " + maxLat + ", lon=" + minLon + " TO " + maxLon + "; distanceSortKey=" + distanceSortKey + ")"; } } @@ -106,7 +110,8 @@ public void visit(int docID) { private void maybeUpdateBBox() { if (setBottomCounter < 1024 || (setBottomCounter & 0x3F) == 0x3F) { NearestHit hit = hitQueue.peek(); - Rectangle box = Rectangle.fromPointDistance(pointLat, pointLon, hit.distanceMeters); + Rectangle box = Rectangle.fromPointDistance(pointLat, pointLon, + SloppyMath.haversinMeters(hit.distanceSortKey)); //System.out.println(" update bbox to " + box); minLat = box.minLat; maxLat = box.maxLat; @@ -134,8 +139,6 @@ public void visit(int docID, byte[] packedValue) { return; } - // TODO: work in int space, use haversinSortKey - double docLatitude = decodeLatitude(packedValue, 0); double 
docLongitude = decodeLongitude(packedValue, Integer.BYTES); @@ -147,21 +150,22 @@ public void visit(int docID, byte[] packedValue) { return; } - double distanceMeters = SloppyMath.haversinMeters(pointLat, pointLon, docLatitude, docLongitude); + // Use the haversin sort key when comparing hits, as it is faster to compute than the true distance. + double distanceSortKey = SloppyMath.haversinSortKey(pointLat, pointLon, docLatitude, docLongitude); - //System.out.println(" visit docID=" + docID + " distanceMeters=" + distanceMeters + " docLat=" + docLatitude + " docLon=" + docLongitude); + //System.out.println(" visit docID=" + docID + " distanceSortKey=" + distanceSortKey + " docLat=" + docLatitude + " docLon=" + docLongitude); int fullDocID = curDocBase + docID; if (hitQueue.size() == topN) { // queue already full NearestHit hit = hitQueue.peek(); - //System.out.println(" bottom distanceMeters=" + hit.distanceMeters); + //System.out.println(" bottom distanceSortKey=" + hit.distanceSortKey); // we don't collect docs in order here, so we must also test the tie-break case ourselves: - if (distanceMeters < hit.distanceMeters || (distanceMeters == hit.distanceMeters && fullDocID < hit.docID)) { + if (distanceSortKey < hit.distanceSortKey || (distanceSortKey == hit.distanceSortKey && fullDocID < hit.docID)) { hitQueue.poll(); hit.docID = fullDocID; - hit.distanceMeters = distanceMeters; + hit.distanceSortKey = distanceSortKey; hitQueue.offer(hit); //System.out.println(" ** keep2, now bottom=" + hit); maybeUpdateBBox(); @@ -170,7 +174,7 @@ public void visit(int docID, byte[] packedValue) { } else { NearestHit hit = new NearestHit(); hit.docID = fullDocID; - hit.distanceMeters = distanceMeters; + hit.distanceSortKey = distanceSortKey; hitQueue.offer(hit); //System.out.println(" ** keep1, now bottom=" + hit); } @@ -182,14 +186,18 @@ public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { } } - /** Holds one hit from {@link LatLonPointPrototypeQueries#nearest} 
*/ + /** Holds one hit from {@link NearestNeighbor#nearest} */ static class NearestHit { public int docID; - public double distanceMeters; + + /** + * The distance from the hit to the query point, computed as a sort key through {@link SloppyMath#haversinSortKey}. + */ + public double distanceSortKey; @Override public String toString() { - return "NearestHit(docID=" + docID + " distanceMeters=" + distanceMeters + ")"; + return "NearestHit(docID=" + docID + " distanceSortKey=" + distanceSortKey + ")"; } } @@ -204,8 +212,8 @@ public static NearestHit[] nearest(double pointLat, double pointLon, List hitQueue = new PriorityQueue<>(n, new Comparator() { @Override public int compare(NearestHit a, NearestHit b) { - // sort by opposite distanceMeters natural order - int cmp = Double.compare(a.distanceMeters, b.distanceMeters); + // sort by opposite distanceSortKey natural order + int cmp = Double.compare(a.distanceSortKey, b.distanceSortKey); if (cmp != 0) { return -cmp; } @@ -319,10 +327,10 @@ private static double approxBestDistance(double minLat, double maxLat, double mi return 0.0; } - double d1 = SloppyMath.haversinMeters(pointLat, pointLon, minLat, minLon); - double d2 = SloppyMath.haversinMeters(pointLat, pointLon, minLat, maxLon); - double d3 = SloppyMath.haversinMeters(pointLat, pointLon, maxLat, maxLon); - double d4 = SloppyMath.haversinMeters(pointLat, pointLon, maxLat, minLon); + double d1 = SloppyMath.haversinSortKey(pointLat, pointLon, minLat, minLon); + double d2 = SloppyMath.haversinSortKey(pointLat, pointLon, minLat, maxLon); + double d3 = SloppyMath.haversinSortKey(pointLat, pointLon, maxLat, maxLon); + double d4 = SloppyMath.haversinSortKey(pointLat, pointLon, maxLat, minLon); return Math.min(Math.min(d1, d2), Math.min(d3, d4)); } diff --git a/lucene/sandbox/src/test/org/apache/lucene/search/TestNearest.java b/lucene/sandbox/src/test/org/apache/lucene/search/TestNearest.java index 40c521e43c19..627b72458813 100644 --- 
a/lucene/sandbox/src/test/org/apache/lucene/search/TestNearest.java +++ b/lucene/sandbox/src/test/org/apache/lucene/search/TestNearest.java @@ -26,7 +26,6 @@ import org.apache.lucene.document.LatLonPoint; import org.apache.lucene.document.StoredField; import org.apache.lucene.document.StringField; -import org.apache.lucene.search.NearestNeighbor.NearestHit; import org.apache.lucene.geo.GeoEncodingUtils; import org.apache.lucene.geo.GeoTestUtil; import org.apache.lucene.index.DirectoryReader; @@ -190,23 +189,22 @@ public void testNearestNeighborRandom() throws Exception { double pointLon = GeoTestUtil.nextLongitude(); // dumb brute force search to get the expected result: - NearestHit[] expectedHits = new NearestHit[lats.length]; + FieldDoc[] expectedHits = new FieldDoc[lats.length]; for(int id=0;id() { + Arrays.sort(expectedHits, new Comparator() { @Override - public int compare(NearestHit a, NearestHit b) { - int cmp = Double.compare(a.distanceMeters, b.distanceMeters); + public int compare(FieldDoc a, FieldDoc b) { + int cmp = Double.compare(((Double) a.fields[0]).doubleValue(), ((Double) b.fields[0]).doubleValue()); if (cmp != 0) { return cmp; } // tie break by smaller docID: - return a.docID - b.docID; + return a.doc - b.doc; } }); @@ -221,22 +219,23 @@ public int compare(NearestHit a, NearestHit b) { ScoreDoc[] hits = LatLonPointPrototypeQueries.nearest(s, "point", pointLat, pointLon, topN).scoreDocs; for(int i=0;i