diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 8ba40ef88d06..350c1ffa40cf 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -984,6 +984,8 @@ API Changes * GITHUB#13820, GITHUB#13825, GITHUB#13830: Corrects DataInput.readGroupVInts to be public and not-final, removes the protected DataInput.readGroupVInt method. (Zhang Chao, Robert Muir, Uwe Schindler, Dawid Weiss) +* GITHUB#15376, GITHUB#15197: Added prefetching to BKD tree traversal, along with a couple of new APIs in PointValues: visitDocIDs from a position, and prepareOrVisitDocIDs to prefetch the IO before visiting doc IDs. (Saurabh Singh) + New Features --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java b/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java index dc3242722636..ca3d1fdccaf1 100644 --- a/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java +++ b/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java @@ -3194,7 +3194,7 @@ private static void checkByteVectorValues( * * @lucene.internal */ - public static class VerifyPointsVisitor implements PointValues.IntersectVisitor { + public static class VerifyPointsVisitor implements IntersectVisitor { private long pointCountSeen; private int lastDocID = -1; private final FixedBitSet docsSeen; diff --git a/lucene/core/src/java/org/apache/lucene/index/PointValues.java b/lucene/core/src/java/org/apache/lucene/index/PointValues.java index c77eec0e5ffd..593e81b79ff7 100644 --- a/lucene/core/src/java/org/apache/lucene/index/PointValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/PointValues.java @@ -345,13 +345,15 @@ default void grow(int count) {} * Finds all documents and points matching the provided visitor. This method does not enforce live * documents, so it's up to the caller to test whether each document is deleted, if necessary. 
*/ - public final void intersect(IntersectVisitor visitor) throws IOException { + public void intersect(IntersectVisitor visitor) throws IOException { final PointTree pointTree = getPointTree(); intersect(visitor, pointTree); assert pointTree.moveToParent() == false; } - private static void intersect(IntersectVisitor visitor, PointTree pointTree) throws IOException { + /** Finds all documents and points matching the provided visitor for the provided point tree. */ + protected static void intersect(IntersectVisitor visitor, PointTree pointTree) + throws IOException { while (true) { Relation compare = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); diff --git a/lucene/core/src/java/org/apache/lucene/util/bkd/BKDReader.java b/lucene/core/src/java/org/apache/lucene/util/bkd/BKDReader.java index 9c991e6b1b4a..9fa956db5bf4 100644 --- a/lucene/core/src/java/org/apache/lucene/util/bkd/BKDReader.java +++ b/lucene/core/src/java/org/apache/lucene/util/bkd/BKDReader.java @@ -17,7 +17,9 @@ package org.apache.lucene.util.bkd; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.index.CorruptIndexException; import org.apache.lucene.index.PointValues; @@ -589,6 +591,65 @@ public void visitDocIDs(PointValues.IntersectVisitor visitor) throws IOException addAll(visitor, false); } + /** prefetch DocIds below current node */ + public void prefetchDocIDs(TwoPhaseIntersectVisitor visitor) throws IOException { + resetNodeDataPosition(); + prefetchAll(visitor, false); + } + + /** visit Doc Ids for a leafNode at provided input position */ + public void visitDocIDs(long position, IntersectVisitor visitor) throws IOException { + visitDocIDs(position, visitor, false); + } + + private void visitDocIDs(long position, IntersectVisitor visitor, boolean grown) + throws IOException { + leafNodes.seek(position); + int count = leafNodes.readVInt(); + 
if (!grown) { + visitor.grow(count); + } + docIdsWriter.readInts(leafNodes, count, visitor, scratchIterator.docIDs); + } + + private int getLeafNodeOrdinal() { + assert isLeafNode() : "nodeID=" + nodeID + " is not a leaf"; + return nodeID - leafNodeOffset; + } + + public void prefetchAll(TwoPhaseIntersectVisitor visitor, boolean grown) throws IOException { + if (grown == false) { + final long size = size(); + if (size <= Integer.MAX_VALUE) { + visitor.grow((int) size); + grown = true; + } + } + if (isLeafNode()) { + // int count = isLastLeaf() ? config.maxPointsInLeafNode() : lastLeafNodePointCount; + long leafFp = getLeafBlockFP(); + int leafNodeOrdinal = getLeafNodeOrdinal(); + // Only call prefetch if this is the first deferred leaf, or the first leaf in a + // contiguous run of matching leaf ordinals + // boolean prefetched = false; + if (visitor.lastDeferredBlockOrdinal() == -1 + || visitor.lastDeferredBlockOrdinal() + 1 < leafNodeOrdinal) { + // System.out.println("Prefetched called on " + leafNodeOrdinal); + leafNodes.prefetch(leafFp, 1); + // prefetched = true; + } + visitor.setLastDeferredBlockOrdinal(leafNodeOrdinal); + visitor.deferBlock(leafFp); + } else { + pushLeft(); + prefetchAll(visitor, grown); + pop(); + pushRight(); + prefetchAll(visitor, grown); + pop(); + } + } + public void addAll(PointValues.IntersectVisitor visitor, boolean grown) throws IOException { if (grown == false) { final long size = size(); @@ -1076,4 +1137,123 @@ public long cost() { return length; } } + + /** + * We can recurse the {@link BKDPointTree} using {@link TwoPhaseIntersectVisitor}. This visitor + * traverses {@link BKDPointTree} in two phases. In the first phase, it recurses over the {@link + * BKDPointTree} optionally triggering IO for some of the blocks and deferring them. In the second + * phase, once the recursion is over it visits the deferred blocks one by one. 
+ * + * @lucene.experimental + */ + public interface TwoPhaseIntersectVisitor extends IntersectVisitor { + /** Returns the ordinal of the last leaf block deferred during recursion. */ + public int lastDeferredBlockOrdinal(); + + /** Sets the ordinal of the last deferred leaf block. */ + public void setLastDeferredBlockOrdinal(int leafNodeOrdinal); + + /** Defer this block for processing in the second phase. */ + public void deferBlock(long leafFp); + + /** Returns a snapshot of the file pointers of the currently deferred blocks. */ + public List deferredBlocks(); + + /** Mark the given block as processed and remove it from the deferred set. */ + public void onProcessingDeferredBlock(long leafFp); + } + + /** + * Base implementation of {@link TwoPhaseIntersectVisitor} that maintains a list of deferred + * blocks from the first phase of traversal so they can be visited in the second phase. + * + * @lucene.experimental + */ + public abstract static class BaseTwoPhaseIntersectVisitor implements TwoPhaseIntersectVisitor { + + int lastDeferredBlockOrdinal = -1; + List deferredBlocks = new ArrayList<>(); + + /** + * Returns the last deferred block ordinal. It is used to skip the prefetch call for contiguous + * ordinals, on the assumption that readahead takes care of blocks with contiguous ordinals. + */ + @Override + public int lastDeferredBlockOrdinal() { + return lastDeferredBlockOrdinal; + } + + /** Sets the last deferred block ordinal. */ + @Override + public void setLastDeferredBlockOrdinal(int leafNodeOrdinal) { + lastDeferredBlockOrdinal = leafNodeOrdinal; + } + + /** Defer this block for processing in the second phase. */ + @Override + public void deferBlock(long leafFp) { + deferredBlocks.add(leafFp); + } + + /** Returns a snapshot of the file pointers of the currently deferred blocks. */ + @Override + public List deferredBlocks() { + return new ArrayList<>(deferredBlocks); + } + + /** Mark the given block as processed and remove it from the deferred set. 
*/ + @Override + public void onProcessingDeferredBlock(long leafFp) { + deferredBlocks.remove(leafFp); + } + } + + /** + * Finds all documents and points matching the provided visitor. This method does not enforce live + * documents, so it's up to the caller to test whether each document is deleted, if necessary. + */ + @Override + public final void intersect(IntersectVisitor visitor) throws IOException { + final BKDPointTree pointTree = (BKDPointTree) getPointTree(); + if (visitor instanceof TwoPhaseIntersectVisitor twoPhaseIntersectVisitor) { + intersect(twoPhaseIntersectVisitor, pointTree); + List fps = twoPhaseIntersectVisitor.deferredBlocks(); + for (int i = 0; i < fps.size(); ++i) { + long fp = fps.get(i); + pointTree.visitDocIDs(fp, visitor); + twoPhaseIntersectVisitor.onProcessingDeferredBlock(fp); + } + } else { + intersect(visitor, pointTree); + } + assert pointTree.moveToParent() == false; + } + + private static void intersect(TwoPhaseIntersectVisitor visitor, BKDPointTree pointTree) + throws IOException { + while (true) { + Relation compare = + visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); + if (compare == Relation.CELL_INSIDE_QUERY) { + // This cell is fully inside the query shape: recursively prefetch all points in this cell + // without filtering + pointTree.prefetchDocIDs(visitor); + } else if (compare == Relation.CELL_CROSSES_QUERY) { + // The cell crosses the shape boundary, or the cell fully contains the query, so we fall + // through and do full filtering: + if (pointTree.moveToChild()) { + continue; + } + // TODO: we can assert that the first value here in fact matches what the pointTree + // claimed? 
+ // Leaf node; scan and filter all points in this block: + pointTree.visitDocValues(visitor); + } + while (pointTree.moveToSibling() == false) { + if (pointTree.moveToParent() == false) { + return; + } + } + } + } } diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene90/TestLucene90PointsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene90/TestLucene90PointsFormat.java index f5940e7ad3d2..c099fbbc6b8c 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene90/TestLucene90PointsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene90/TestLucene90PointsFormat.java @@ -18,6 +18,8 @@ import java.io.IOException; import java.util.Arrays; +import java.util.BitSet; +import java.util.List; import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.PointsFormat; @@ -39,7 +41,10 @@ import org.apache.lucene.tests.index.BasePointsFormatTestCase; import org.apache.lucene.tests.index.MockRandomMergePolicy; import org.apache.lucene.tests.util.TestUtil; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.NumericUtils; import org.apache.lucene.util.bkd.BKDConfig; +import org.apache.lucene.util.bkd.BKDReader; public class TestLucene90PointsFormat extends BasePointsFormatTestCase { @@ -355,4 +360,134 @@ public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { r.close(); dir.close(); } + + public void testBasicWithPrefetchVisitor() throws Exception { + Directory dir = newDirectory(); + IndexWriterConfig iwc = newIndexWriterConfig(); + // Avoid mockRandomMP since it may cause non-optimal merges that make the + // number of points per leaf hard to predict + while (iwc.getMergePolicy() instanceof MockRandomMergePolicy) { + iwc.setMergePolicy(newMergePolicy()); + } + IndexWriter w = new IndexWriter(dir, iwc); + byte[] pointValue = new byte[3]; + byte[] uniquePointValue = new byte[3]; + random().nextBytes(uniquePointValue); + final int numDocs = + 
TEST_NIGHTLY ? atLeast(10000) : atLeast(500); // at night, make sure we have several leaves + final boolean multiValues = random().nextBoolean(); + int totalValues = 0; + for (int i = 0; i < numDocs; ++i) { + Document doc = new Document(); + if (i == numDocs / 2) { + totalValues++; + doc.add(new BinaryPoint("f", uniquePointValue)); + } else { + final int numValues = (multiValues) ? TestUtil.nextInt(random(), 2, 100) : 1; + for (int j = 0; j < numValues; j++) { + do { + random().nextBytes(pointValue); + } while (Arrays.equals(pointValue, uniquePointValue)); + doc.add(new BinaryPoint("f", pointValue)); + totalValues++; + } + } + w.addDocument(doc); + } + w.forceMerge(1); + final IndexReader r = DirectoryReader.open(w); + w.close(); + + final LeafReader lr = getOnlyLeafReader(r); + PointValues points = lr.getPointValues("f"); + + BKDReader.BaseTwoPhaseIntersectVisitor allPointsVisitor = + new BKDReader.BaseTwoPhaseIntersectVisitor() { + @Override + public void visit(int docID, byte[] packedValue) throws IOException {} + + @Override + public void visit(int docID) throws IOException {} + + @Override + public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { + return Relation.CELL_INSIDE_QUERY; + } + }; + + List savedBlocks = allPointsVisitor.deferredBlocks(); + assertEquals(0, savedBlocks.size()); // Test that all deferred blocks were processed + assertEquals(totalValues, points.estimatePointCount(allPointsVisitor)); + assertEquals(numDocs, points.estimateDocCount(allPointsVisitor)); + + r.close(); + dir.close(); + } + + public void testBasicWithPrefetchCapableVisitor() throws Exception { + Directory dir = newDirectory(); + IndexWriterConfig iwc = newIndexWriterConfig(); + iwc.setMergePolicy(newLogMergePolicy()); + IndexWriter w = new IndexWriter(dir, iwc); + byte[] point = new byte[4]; + for (int i = 0; i < 20; i++) { + Document doc = new Document(); + NumericUtils.intToSortableBytes(i, point, 0); + doc.add(new BinaryPoint("dim", point)); + 
w.addDocument(doc); + } + w.forceMerge(1); + w.close(); + + DirectoryReader r = DirectoryReader.open(dir); + LeafReader sub = getOnlyLeafReader(r); + PointValues values = sub.getPointValues("dim"); + + // Simple test: make sure prefetch capable visitor can visit every doc when cell crosses query: + BitSet seen = new BitSet(); + values.intersect( + new BKDReader.BaseTwoPhaseIntersectVisitor() { + @Override + public Relation compare(byte[] minPacked, byte[] maxPacked) { + return Relation.CELL_CROSSES_QUERY; + } + + @Override + public void visit(int docID) { + throw new IllegalStateException(); + } + + @Override + public void visit(int docID, byte[] packedValue) { + seen.set(docID); + assertEquals(docID, NumericUtils.sortableBytesToInt(packedValue, 0)); + } + }); + assertEquals(20, seen.cardinality()); + // Make sure prefetch capable visitor can visit all docs when all docs are inside query + // Also test we are not visiting documents twice based on whether PointTree has a prefetch + // implementation of + // prepareOrVisit or uses the default implementation + seen.clear(); + final int[] docCount = {0}; + values.intersect( + new BKDReader.BaseTwoPhaseIntersectVisitor() { + @Override + public void visit(int docID) throws IOException { + seen.set(docID); + docCount[0]++; + } + + @Override + public void visit(int docID, byte[] packedValue) throws IOException {} + + @Override + public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { + return Relation.CELL_INSIDE_QUERY; + } + }); + assertEquals(20, seen.cardinality()); + assertEquals(20, docCount[0]); + IOUtils.close(r, dir); + } } diff --git a/lucene/core/src/test/org/apache/lucene/util/bkd/TestBKD.java b/lucene/core/src/test/org/apache/lucene/util/bkd/TestBKD.java index 8eb513bf7b6b..af7fc114d76d 100644 --- a/lucene/core/src/test/org/apache/lucene/util/bkd/TestBKD.java +++ b/lucene/core/src/test/org/apache/lucene/util/bkd/TestBKD.java @@ -80,7 +80,10 @@ public void testBasicInts1D() throws Exception { 
final BitSet hits = new BitSet(); r.intersect(getIntersectVisitor(hits, queryMin, queryMax, config)); - + final BitSet hitsFromTwoPhaseVisitor = new BitSet(); + r.intersect( + getTwoPhaseIntersectVisitor(hitsFromTwoPhaseVisitor, queryMin, queryMax, config)); + assertEquals(hits, hitsFromTwoPhaseVisitor); for (int docID = 0; docID < 100; docID++) { boolean expected = docID >= 42 && docID <= 87; boolean actual = hits.get(docID); @@ -177,7 +180,11 @@ public void testRandomIntsNDims() throws Exception { final BitSet hits = new BitSet(); r.intersect(getIntersectVisitor(hits, queryMinBytes, queryMaxBytes, config)); - + final BitSet hitsWithTwoPhaseVisitor = new BitSet(); + r.intersect( + getTwoPhaseIntersectVisitor( + hitsWithTwoPhaseVisitor, queryMinBytes, queryMaxBytes, config)); + assertEquals(hitsWithTwoPhaseVisitor, hits); for (int docID = 0; docID < numDocs; docID++) { int[] docValues = docs[docID]; boolean expected = true; @@ -265,7 +272,11 @@ public void testBigIntNDims() throws Exception { final BitSet hits = new BitSet(); pointValues.intersect(getIntersectVisitor(hits, queryMinBytes, queryMaxBytes, config)); - + final BitSet hitsWithTwoPhaseVisitor = new BitSet(); + pointValues.intersect( + getTwoPhaseIntersectVisitor( + hitsWithTwoPhaseVisitor, queryMinBytes, queryMaxBytes, config)); + assertEquals(hitsWithTwoPhaseVisitor, hits); for (int docID = 0; docID < numDocs; docID++) { BigInteger[] docValues = docs[docID]; boolean expected = true; @@ -854,12 +865,22 @@ private void verify( new BKDConfig(numDataDims, numIndexDims, numBytesPerDim, maxPointsInLeafNode); final BitSet hits = new BitSet(); pointValues.intersect(getIntersectVisitor(hits, queryMin, queryMax, config)); + final BitSet hitsWithTwoPhaseVisitor = new BitSet(); + pointValues.intersect( + getTwoPhaseIntersectVisitor(hitsWithTwoPhaseVisitor, queryMin, queryMax, config)); + assertEquals(hitsWithTwoPhaseVisitor, hits); assertHits(hits, expected); hits.clear(); + hitsWithTwoPhaseVisitor.clear(); 
pointValues .getPointTree() .visitDocValues(getIntersectVisitor(hits, queryMin, queryMax, config)); + pointValues + .getPointTree() + .visitDocValues( + getTwoPhaseIntersectVisitor(hitsWithTwoPhaseVisitor, queryMin, queryMax, config)); + assertEquals(hitsWithTwoPhaseVisitor, hits); assertHits(hits, expected); } in.close(); @@ -883,7 +904,9 @@ private void assertSize(PointValues.PointTree tree) throws IOException { tree = rarely() ? clone : tree; final long[] visitDocIDSize = new long[] {0}; final long[] visitDocValuesSize = new long[] {0}; - final IntersectVisitor visitor = + IntersectVisitor visitor = null; + + visitor = new IntersectVisitor() { @Override public void visit(int docID) { @@ -900,9 +923,11 @@ public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { return Relation.CELL_CROSSES_QUERY; } }; + if (random().nextBoolean()) { tree.visitDocIDs(visitor); tree.visitDocValues(visitor); + } else { tree.visitDocValues(visitor); tree.visitDocIDs(visitor); @@ -1052,6 +1077,116 @@ public Relation compare(byte[] minPacked, byte[] maxPacked) { }; } + private BKDReader.BaseTwoPhaseIntersectVisitor getTwoPhaseIntersectVisitor( + BitSet hits, byte[][] queryMin, byte[][] queryMax, BKDConfig config) { + return new BKDReader.BaseTwoPhaseIntersectVisitor() { + @Override + public void visit(int docID) { + hits.set(docID); + // System.out.println("visit docID=" + docID); + } + + @Override + public void visit(int docID, byte[] packedValue) { + // System.out.println("visit check docID=" + docID); + for (int dim = 0; dim < config.numIndexDims(); dim++) { + if (Arrays.compareUnsigned( + packedValue, + dim * config.bytesPerDim(), + dim * config.bytesPerDim() + config.bytesPerDim(), + queryMin[dim], + 0, + config.bytesPerDim()) + < 0 + || Arrays.compareUnsigned( + packedValue, + dim * config.bytesPerDim(), + dim * config.bytesPerDim() + config.bytesPerDim(), + queryMax[dim], + 0, + config.bytesPerDim()) + > 0) { + // System.out.println(" no"); + return; + } + } + + 
// System.out.println(" yes"); + hits.set(docID); + } + + @Override + public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException { + if (random().nextBoolean()) { + // check the default method is correct + super.visit(iterator, packedValue); + } else { + assertEquals(iterator.docID(), -1); + int cost = Math.toIntExact(iterator.cost()); + int numberOfPoints = 0; + int docID; + while ((docID = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { + assertEquals(iterator.docID(), docID); + visit(docID, packedValue); + numberOfPoints++; + } + assertEquals(cost, numberOfPoints); + assertEquals(iterator.docID(), DocIdSetIterator.NO_MORE_DOCS); + assertEquals(iterator.nextDoc(), DocIdSetIterator.NO_MORE_DOCS); + assertEquals(iterator.docID(), DocIdSetIterator.NO_MORE_DOCS); + } + } + + @Override + public Relation compare(byte[] minPacked, byte[] maxPacked) { + boolean crosses = false; + for (int dim = 0; dim < config.numIndexDims(); dim++) { + if (Arrays.compareUnsigned( + maxPacked, + dim * config.bytesPerDim(), + dim * config.bytesPerDim() + config.bytesPerDim(), + queryMin[dim], + 0, + config.bytesPerDim()) + < 0 + || Arrays.compareUnsigned( + minPacked, + dim * config.bytesPerDim(), + dim * config.bytesPerDim() + config.bytesPerDim(), + queryMax[dim], + 0, + config.bytesPerDim()) + > 0) { + return Relation.CELL_OUTSIDE_QUERY; + } else if (Arrays.compareUnsigned( + minPacked, + dim * config.bytesPerDim(), + dim * config.bytesPerDim() + config.bytesPerDim(), + queryMin[dim], + 0, + config.bytesPerDim()) + < 0 + || Arrays.compareUnsigned( + maxPacked, + dim * config.bytesPerDim(), + dim * config.bytesPerDim() + config.bytesPerDim(), + queryMax[dim], + 0, + config.bytesPerDim()) + > 0) { + crosses = true; + } + } + + if (crosses) { + return Relation.CELL_CROSSES_QUERY; + } else { + return Relation.CELL_INSIDE_QUERY; + } + } + }; + } + private BigInteger randomBigInt(int numBytes) { BigInteger x = new BigInteger(numBytes * 8 - 1, random()); if 
(random().nextBoolean()) { @@ -1344,6 +1479,28 @@ public void test2DLongOrdsOffline() throws Exception { in.seek(fp); PointValues r = getPointValues(in); int[] count = new int[1]; + int[] countWithTwoPhaseVisitor = new int[1]; + r.intersect( + new BKDReader.BaseTwoPhaseIntersectVisitor() { + @Override + public void visit(int docID) throws IOException { + countWithTwoPhaseVisitor[0]++; + } + + @Override + public void visit(int docID, byte[] packedValue) throws IOException { + visit(docID); + } + + @Override + public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { + if (random().nextInt(7) == 1) { + return Relation.CELL_CROSSES_QUERY; + } else { + return Relation.CELL_INSIDE_QUERY; + } + } + }); r.intersect( new IntersectVisitor() { @@ -1366,6 +1523,7 @@ public Relation compare(byte[] minPacked, byte[] maxPacked) { } } }); + assertEquals(count[0], countWithTwoPhaseVisitor[0]); assertEquals(numDocs, count[0]); in.close(); } @@ -1411,6 +1569,31 @@ public void testWastedLeadingBytes() throws Exception { in.seek(fp); PointValues r = getPointValues(in); int[] count = new int[1]; + int[] countWithTwoPhaseVisitor = new int[1]; + r.intersect( + new BKDReader.BaseTwoPhaseIntersectVisitor() { + @Override + public void visit(int docID) throws IOException { + countWithTwoPhaseVisitor[0]++; + } + + @Override + public void visit(int docID, byte[] packedValue) throws IOException { + assert packedValue.length == numDims * bytesPerDim; + visit(docID); + } + + @Override + public Relation compare(byte[] minPacked, byte[] maxPacked) { + assert minPacked.length == numIndexDims * bytesPerDim; + assert maxPacked.length == numIndexDims * bytesPerDim; + if (random().nextInt(7) == 1) { + return Relation.CELL_CROSSES_QUERY; + } else { + return Relation.CELL_INSIDE_QUERY; + } + } + }); r.intersect( new IntersectVisitor() { @@ -1436,6 +1619,7 @@ public Relation compare(byte[] minPacked, byte[] maxPacked) { } } }); + assertEquals(count[0], countWithTwoPhaseVisitor[0]); 
assertEquals(numDocs, count[0]); in.close(); dir.close(); @@ -1480,6 +1664,22 @@ public void testEstimatePointCount() throws IOException { pointsIn.seek(indexFP); PointValues points = getPointValues(pointsIn); + long estimatedCountWithTwoPhaseVisitor = + points.estimatePointCount( + new BKDReader.BaseTwoPhaseIntersectVisitor() { + @Override + public void visit(int docID) throws IOException {} + + @Override + public void visit(int docID, byte[] packedValue) throws IOException {} + + @Override + public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { + return Relation.CELL_INSIDE_QUERY; + } + }); + assertEquals(estimatedCountWithTwoPhaseVisitor, numValues); + // If all points match, then the point count is numValues assertEquals( numValues, @@ -1497,6 +1697,22 @@ public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { } })); + assertEquals( + 0, + points.estimatePointCount( + new BKDReader.BaseTwoPhaseIntersectVisitor() { + @Override + public void visit(int docID, byte[] packedValue) throws IOException {} + + @Override + public void visit(int docID) throws IOException {} + + @Override + public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { + return Relation.CELL_OUTSIDE_QUERY; + } + })); + // Return 0 if no points match assertEquals( 0, @@ -1516,6 +1732,28 @@ public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { // If only one point matches, then the point count is (actualMaxPointsInLeafNode + 1) / 2 // in general, or maybe 2x that if the point is a split value + final long pointCountWithTwoPhaseVisitor = + points.estimatePointCount( + new BKDReader.BaseTwoPhaseIntersectVisitor() { + @Override + public void visit(int docID, byte[] packedValue) throws IOException {} + + @Override + public void visit(int docID) throws IOException {} + + @Override + public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { + if (Arrays.compareUnsigned( + uniquePointValue, 0, numBytesPerDim, maxPackedValue, 
0, numBytesPerDim) + > 0 + || Arrays.compareUnsigned( + uniquePointValue, 0, numBytesPerDim, minPackedValue, 0, numBytesPerDim) + < 0) { + return Relation.CELL_OUTSIDE_QUERY; + } + return Relation.CELL_CROSSES_QUERY; + } + }); final long pointCount = points.estimatePointCount( new IntersectVisitor() { @@ -1539,6 +1777,7 @@ public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { } }); long lastNodePointCount = numValues % maxPointsInLeafNode; + assertEquals(pointCount, pointCountWithTwoPhaseVisitor); assertTrue( "" + pointCount, pointCount == (maxPointsInLeafNode + 1) / 2 // common case