From 49d39ba9bbef4c031f1ea1d5535b223c29c54f57 Mon Sep 17 00:00:00 2001 From: Atri Sharma Date: Tue, 13 Aug 2019 21:47:58 +0530 Subject: [PATCH 1/2] LUCENE-8949: Allow LeafFieldComparators to publish Feature Values --- .../lucene/document/FeatureSortField.java | 5 +++ .../LatLonPointDistanceComparator.java | 6 +++ .../apache/lucene/search/FieldComparator.java | 40 +++++++++++++++++++ .../lucene/search/LeafFieldComparator.java | 5 ++- .../search/MultiLeafFieldComparator.java | 11 +++++ .../search/TestElevationComparator.java | 6 ++- .../lucene/queries/function/ValueSource.java | 5 +++ .../Geo3DPointDistanceComparator.java | 5 +++ .../Geo3DPointOutsideDistanceComparator.java | 5 +++ .../component/QueryElevationComponent.java | 5 +++ .../apache/solr/schema/RandomSortField.java | 5 +++ 11 files changed, 96 insertions(+), 2 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/document/FeatureSortField.java b/lucene/core/src/java/org/apache/lucene/document/FeatureSortField.java index 1e73f8c274af..8902f8e5b001 100644 --- a/lucene/core/src/java/org/apache/lucene/document/FeatureSortField.java +++ b/lucene/core/src/java/org/apache/lucene/document/FeatureSortField.java @@ -160,5 +160,10 @@ public Float value(int slot) { public int compareTop(int doc) throws IOException { return Float.compare(topValue, getValueForDoc(doc)); } + + @Override + public Float leafValue(int docID) throws IOException { + return getValueForDoc(docID); + } } } diff --git a/lucene/core/src/java/org/apache/lucene/document/LatLonPointDistanceComparator.java b/lucene/core/src/java/org/apache/lucene/document/LatLonPointDistanceComparator.java index 10566b6c8ea7..006d485b23fa 100644 --- a/lucene/core/src/java/org/apache/lucene/document/LatLonPointDistanceComparator.java +++ b/lucene/core/src/java/org/apache/lucene/document/LatLonPointDistanceComparator.java @@ -217,6 +217,12 @@ public int compareTop(int doc) throws IOException { return minValue; } + //TODO: Implement this + @Override + public Double leafValue(int docID) throws IOException { + throw new UnsupportedOperationException("This comparator does not support getting leaf values"); + } + // second half of the haversin calculation, used to convert results from haversin1 (used internally // for sorting) for display purposes. static double haversin2(double partial) { diff --git a/lucene/core/src/java/org/apache/lucene/search/FieldComparator.java b/lucene/core/src/java/org/apache/lucene/search/FieldComparator.java index 2437a338cc6c..b8ca1ff94096 100644 --- a/lucene/core/src/java/org/apache/lucene/search/FieldComparator.java +++ b/lucene/core/src/java/org/apache/lucene/search/FieldComparator.java @@ -219,6 +219,11 @@ public Double value(int slot) { public int compareTop(int doc) throws IOException { return Double.compare(topValue, getValueForDoc(doc)); } + + @Override + public Double leafValue(int docID) throws IOException { + return getValueForDoc(docID); + } } /** Parses field's values as float (using {@link @@ -279,6 +284,11 @@ public Float value(int slot) { public int compareTop(int doc) throws IOException { return Float.compare(topValue, getValueForDoc(doc)); } + + @Override + public Float leafValue(int docID) throws IOException { + return getValueForDoc(docID); + } } /** Parses field's values as int (using {@link @@ -341,6 +351,11 @@ public Integer value(int slot) { public int compareTop(int doc) throws IOException { return Integer.compare(topValue, getValueForDoc(doc)); } + + @Override + public Integer leafValue(int docID) throws IOException { + return getValueForDoc(docID); + } } /** Parses field's values as long (using {@link @@ -401,6 +416,11 @@ public Long value(int slot) { public int compareTop(int doc) throws IOException { return Long.compare(topValue, getValueForDoc(doc)); } + + @Override + public Long leafValue(int docID) throws IOException { + return getValueForDoc(docID); + } } /** Sorts by descending relevance. NOTE: if you are @@ -484,6 +504,11 @@ public int compareTop(int doc) throws IOException { assert !Float.isNaN(docValue); return Float.compare(docValue, topValue); } + + @Override + public Float leafValue(int docID) throws IOException { + return scorer.score(); + } } /** Sorts by ascending docID */ @@ -545,6 +570,11 @@ public int compareTop(int doc) { return Integer.compare(topValue, docValue); } + @Override + public Integer leafValue(int docID) throws IOException { + return docBase + docID; + } + @Override public void setScorer(Scorable scorer) {} } @@ -686,6 +716,11 @@ public int compareBottom(int doc) throws IOException { } } + @Override + public Integer leafValue(int docID) throws IOException { + return getOrdForDoc(docID); + } + @Override public void copy(int slot, int doc) throws IOException { int ord = getOrdForDoc(doc); @@ -927,5 +962,10 @@ public int compareTop(int doc) throws IOException { @Override public void setScorer(Scorable scorer) {} + + @Override + public BytesRef leafValue(int docID) throws IOException { + return getValueForDoc(docID); + } } } diff --git a/lucene/core/src/java/org/apache/lucene/search/LeafFieldComparator.java b/lucene/core/src/java/org/apache/lucene/search/LeafFieldComparator.java index c2c22745d8b9..daf3a238b7da 100644 --- a/lucene/core/src/java/org/apache/lucene/search/LeafFieldComparator.java +++ b/lucene/core/src/java/org/apache/lucene/search/LeafFieldComparator.java @@ -51,7 +51,7 @@ * @see FieldComparator * @lucene.experimental */ -public interface LeafFieldComparator { +public interface LeafFieldComparator { /** * Set the bottom slot, ie the "weakest" (sorted last) @@ -116,4 +116,7 @@ public interface LeafFieldComparator { * obtain the current hit's score, if necessary. */ void setScorer(Scorable scorer) throws IOException; + /** Publishes feature values for the given docID + */ + T leafValue(int doc) throws IOException; } diff --git a/lucene/core/src/java/org/apache/lucene/search/MultiLeafFieldComparator.java b/lucene/core/src/java/org/apache/lucene/search/MultiLeafFieldComparator.java index acec040fc152..d8a3eb980e8c 100644 --- a/lucene/core/src/java/org/apache/lucene/search/MultiLeafFieldComparator.java +++ b/lucene/core/src/java/org/apache/lucene/search/MultiLeafFieldComparator.java @@ -89,4 +89,15 @@ public void setScorer(Scorable scorer) throws IOException { } } + @Override + public Object leafValue(int docID) throws IOException { + Object[] valuesArray = new Object[comparators.length]; + + for (int i = 0; i < comparators.length; i++) { + valuesArray[i] = comparators[i].leafValue(docID); + } + + return valuesArray; + } + } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestElevationComparator.java b/lucene/core/src/test/org/apache/lucene/search/TestElevationComparator.java index d4ea75d55133..38b11a3a079d 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestElevationComparator.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestElevationComparator.java @@ -194,6 +194,11 @@ public void copy(int slot, int doc) throws IOException { @Override public void setScorer(Scorable scorer) {} + + @Override + public Integer leafValue(int docID) throws IOException { + return docVal(docID); + } }; } @@ -212,7 +217,6 @@ public Integer value(int slot) { return Integer.valueOf(values[slot]); } - }; } } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/ValueSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/ValueSource.java index 1ba35804002b..cadac13eecbf 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/ValueSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/ValueSource.java @@ -442,5 +442,10 @@ public int compareTop(int doc) throws IOException { final double docValue = docVals.doubleVal(doc); return Double.compare(topValue, docValue); } + + @Override + public Double leafValue(int docID) throws IOException { + return docVals.doubleVal(docID); + } } } diff --git a/lucene/spatial3d/src/java/org/apache/lucene/spatial3d/Geo3DPointDistanceComparator.java b/lucene/spatial3d/src/java/org/apache/lucene/spatial3d/Geo3DPointDistanceComparator.java index b2cd9c5c45b5..15e4181b172e 100644 --- a/lucene/spatial3d/src/java/org/apache/lucene/spatial3d/Geo3DPointDistanceComparator.java +++ b/lucene/spatial3d/src/java/org/apache/lucene/spatial3d/Geo3DPointDistanceComparator.java @@ -150,6 +150,11 @@ public int compareTop(int doc) throws IOException { return Double.compare(topValue, computeMinimumDistance(doc)); } + @Override + public Double leafValue(int docID) throws IOException { + return computeMinimumDistance(docID); + } + double computeMinimumDistance(final int doc) throws IOException { if (doc > currentDocs.docID()) { currentDocs.advance(doc); diff --git a/lucene/spatial3d/src/java/org/apache/lucene/spatial3d/Geo3DPointOutsideDistanceComparator.java b/lucene/spatial3d/src/java/org/apache/lucene/spatial3d/Geo3DPointOutsideDistanceComparator.java index c45cbbabb6d8..9eca26e794f0 100644 --- a/lucene/spatial3d/src/java/org/apache/lucene/spatial3d/Geo3DPointOutsideDistanceComparator.java +++ b/lucene/spatial3d/src/java/org/apache/lucene/spatial3d/Geo3DPointOutsideDistanceComparator.java @@ -121,6 +121,11 @@ public int compareTop(int doc) throws IOException { return Double.compare(topValue, computeMinimumDistance(doc)); } + @Override + public Double leafValue(int docID) throws IOException { + return computeMinimumDistance(docID); + } + double computeMinimumDistance(final int doc) throws IOException { if (doc > currentDocs.docID()) { currentDocs.advance(doc); diff --git a/solr/core/src/java/org/apache/solr/handler/component/QueryElevationComponent.java b/solr/core/src/java/org/apache/solr/handler/component/QueryElevationComponent.java index 4f4f23299538..12d52c8d78eb 100644 --- a/solr/core/src/java/org/apache/solr/handler/component/QueryElevationComponent.java +++ b/solr/core/src/java/org/apache/solr/handler/component/QueryElevationComponent.java @@ -1220,6 +1220,11 @@ public int compareTop(int doc) { final int docValue = docVal(doc); return topVal - docValue; // values will be small enough that there is no overflow concern } + + @Override + public Integer leafValue(int docID) throws IOException { + return docVal(docID); + } }; } } diff --git a/solr/core/src/java/org/apache/solr/schema/RandomSortField.java b/solr/core/src/java/org/apache/solr/schema/RandomSortField.java index 44bb420947fb..b46f52c7969b 100644 --- a/solr/core/src/java/org/apache/solr/schema/RandomSortField.java +++ b/solr/core/src/java/org/apache/solr/schema/RandomSortField.java @@ -154,6 +154,11 @@ public int compareTop(int doc) { // values will be positive... no overflow possible. return topVal - hash(doc+seed); } + + @Override + public Integer leafValue(int docID) throws IOException { + return hash(docID + seed); + } }; } }; From 35303ee16ee94630dd225f7d1e7ebf32934614fe Mon Sep 17 00:00:00 2001 From: Atri Sharma Date: Mon, 26 Aug 2019 13:53:18 +0530 Subject: [PATCH 2/2] Shared PQ Based Concurrent Early Termination --- .../org/apache/lucene/search/Collector.java | 8 + ...EarlyTerminatingFieldCollectorManager.java | 210 ++++++++++++++++ .../lucene/search/FieldValueHitQueue.java | 112 ++++++++- .../apache/lucene/search/IndexSearcher.java | 1 + .../lucene/search/TopDocsCollector.java | 1 - .../lucene/search/TopFieldCollector.java | 148 +++++++++++- .../search/TestEarlyTerminationFieldCM.java | 226 ++++++++++++++++++ .../search/SmallSliceSizeIndexSearcher.java | 46 ++++ 8 files changed, 749 insertions(+), 3 deletions(-) create mode 100644 lucene/core/src/java/org/apache/lucene/search/EarlyTerminatingFieldCollectorManager.java create mode 100644 lucene/core/src/test/org/apache/lucene/search/TestEarlyTerminationFieldCM.java create mode 100644 lucene/test-framework/src/java/org/apache/lucene/search/SmallSliceSizeIndexSearcher.java diff --git a/lucene/core/src/java/org/apache/lucene/search/Collector.java b/lucene/core/src/java/org/apache/lucene/search/Collector.java index 9818c673507c..477054e5c35b 100644 --- a/lucene/core/src/java/org/apache/lucene/search/Collector.java +++ b/lucene/core/src/java/org/apache/lucene/search/Collector.java @@ -77,4 +77,12 @@ public interface Collector { * Indicates what features are required from the scorer. */ ScoreMode scoreMode(); + + /** + * Indicates that input has ended for the collector. This allows the collector to perform + * post processing (if any). + */ + default void postProcess() { + // No-op + } } diff --git a/lucene/core/src/java/org/apache/lucene/search/EarlyTerminatingFieldCollectorManager.java b/lucene/core/src/java/org/apache/lucene/search/EarlyTerminatingFieldCollectorManager.java new file mode 100644 index 000000000000..8e88b8c24ce6 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/EarlyTerminatingFieldCollectorManager.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.search; + +import java.util.Collection; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.ReentrantLock; + +import static org.apache.lucene.search.TopFieldCollector.EMPTY_SCOREDOCS; + +/** + * CollectorManager which allows early termination across multiple slices + * when the index sort key and the query sort key are the same + */ +public class EarlyTerminatingFieldCollectorManager implements CollectorManager { + private final Sort sort; + private final int numHits; + private final int totalHitsThreshold; + private final AtomicInteger globalTotalHits; + private final ReentrantLock lock; + private int numCollectors; + + private final ConcurrentLinkedQueue mergeableCollectors; + private FieldValueHitQueue globalHitQueue; + private FieldValueHitQueue.Entry bottom; + // We do not make this Atomic since it will be sought under a lock + private int queueSlotCounter; + private final AtomicBoolean mergeStarted; + public final AtomicBoolean mergeCompleted; + + public EarlyTerminatingFieldCollectorManager(Sort sort, int numHits, int totalHitsThreshold) { + this.sort = sort; + this.numHits = numHits; + this.totalHitsThreshold = totalHitsThreshold; + this.globalTotalHits = new AtomicInteger(); + this.lock = new ReentrantLock(); + this.mergeStarted = new AtomicBoolean(); + this.mergeCompleted = new AtomicBoolean(); + this.mergeableCollectors = new ConcurrentLinkedQueue(); + this.globalHitQueue = null; + } + + @Override + public TopFieldCollector.EarlyTerminatingFieldCollector newCollector() { + ++numCollectors; + + return new TopFieldCollector.EarlyTerminatingFieldCollector(sort, FieldValueHitQueue.create(sort.fields, numHits), numHits, + totalHitsThreshold, this, globalTotalHits); + } + + @Override + public TopFieldDocs reduce(Collection collectors) { + + if (globalHitQueue == null) { + final TopFieldDocs[] topDocs = new TopFieldDocs[collectors.size()]; + int i = 0; + for (TopFieldCollector collector : collectors) { + topDocs[i++] = collector.topDocs(); + } + return TopDocs.merge(sort, 0, numHits, topDocs); + } + + ScoreDoc[] results = populateResults(globalHitQueue.size()); + + return newTopDocs(results); + } + + public int compareAndUpdateBottom(int docBase, int doc, Object value) { + + try { + lock.lock(); + + // If not enough hits are accumulated, add this hit to the global hit queue + if (globalHitQueue.size() < numHits) { + FieldValueHitQueue.Entry newEntry = new FieldValueHitQueue.Entry(queueSlotCounter++, (doc + docBase), value); + bottom = (FieldValueHitQueue.Entry) globalHitQueue.add(newEntry); + return 1; + } + + FieldComparator[] comparators = globalHitQueue.getComparators(); + int[] reverseMul = globalHitQueue.getReverseMul(); + Object bottomValues = bottom.values; + Object[] valuesArray; + Object[] bottomValuesArray; + + if (comparators.length > 1) { + assert value instanceof Object[]; + valuesArray = (Object[]) value; + + assert bottomValues instanceof Object[]; + bottomValuesArray = (Object[]) bottomValues; + } else { + valuesArray = new Object[1]; + valuesArray[0] = value; + + bottomValuesArray = new Object[1]; + bottomValuesArray[0] = bottomValues; + } + + int cmp; + int i = 0; + for (FieldComparator comparator : comparators) { + cmp = reverseMul[i] * comparator.compareValues(bottomValuesArray[i], valuesArray[i]); + ++i; + + if (cmp != 0) { + if (cmp > 0) { + updateBottom(docBase, doc, value); + } + + return cmp; + } + } + + // For equal values, we choose the lower docID + if ((doc + docBase) < bottom.doc) { + updateBottom(docBase, doc, value); + + // Return a value greater than 0 to signify replacement + return 1; + } + + return 0; + } finally { + lock.unlock(); + } + } + + private final void updateBottom(int docBase, int doc, Object values) { + bottom.doc = docBase + doc; + bottom.values = values; + bottom = (FieldValueHitQueue.Entry) globalHitQueue.updateTop(); + + assert bottom != null; + } + + FieldValueHitQueue.Entry addCollectorToGlobalQueue(TopFieldCollector.EarlyTerminatingFieldCollector fieldCollector, int docBase) { + FieldValueHitQueue queue = fieldCollector.queue; + + try { + lock.lock(); + if (globalHitQueue == null) { + this.globalHitQueue = FieldValueHitQueue.createValuesComparingQueue(sort.fields, numHits); + } + + FieldValueHitQueue.Entry entry = (FieldValueHitQueue.Entry) queue.pop(); + while (entry != null) { + if (queueSlotCounter > numHits) { + throw new IllegalStateException("Global number exceeds number of hits. Current hit number " + queueSlotCounter + " numHits " + numHits); + } + + + if (globalHitQueue.size() > numHits) { + throw new IllegalStateException("WTF?"); + } + + // If hit count was already achieved, return this entry + if (globalHitQueue.size() == numHits) { + return entry; + } + + FieldValueHitQueue.Entry newEntry = new FieldValueHitQueue.Entry(queueSlotCounter++, (entry.doc + docBase), entry.values); + + bottom = (FieldValueHitQueue.Entry) globalHitQueue.add(newEntry); + + entry = (FieldValueHitQueue.Entry) queue.pop(); + } + } finally { + lock.unlock(); + } + + return null; + } + + private ScoreDoc[] populateResults(int howMany) { + ScoreDoc[] results = new ScoreDoc[howMany]; + // avoid casting if unnecessary. + for (int i = howMany - 1; i >= 0; i--) { + results[i] = globalHitQueue.fillFields((FieldValueHitQueue.Entry) globalHitQueue.pop()); + } + + return results; + } + + protected TopFieldDocs newTopDocs(ScoreDoc[] results) { + if (results == null) { + results = EMPTY_SCOREDOCS; + } + + //TODO: atris -- Is the relation correct, since we are early terminating? + return new TopFieldDocs(new TotalHits(results.length, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), results, globalHitQueue.getFields()); + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/FieldValueHitQueue.java b/lucene/core/src/java/org/apache/lucene/search/FieldValueHitQueue.java index d282a0af54a7..ddf171bfe6a6 100644 --- a/lucene/core/src/java/org/apache/lucene/search/FieldValueHitQueue.java +++ b/lucene/core/src/java/org/apache/lucene/search/FieldValueHitQueue.java @@ -33,14 +33,21 @@ public abstract class FieldValueHitQueue ext /** * Extension of ScoreDoc to also store the - * {@link FieldComparator} slot. + * {@link FieldComparator} slot and optionally the {@link LeafFieldComparator}. */ public static class Entry extends ScoreDoc { public int slot; + public Object values; public Entry(int slot, int doc) { super(doc, Float.NaN); this.slot = slot; + this.values = null; + } + + public Entry(int slot, int doc, Object values) { + this(slot, doc); + this.values = values; } @Override @@ -119,6 +126,83 @@ protected boolean lessThan(final Entry hitA, final Entry hitB) { } } + + /** + * An implementation of {@link FieldValueHitQueue} which uses a single value for comparison + */ + private static final class SingleValueComparisonFieldValueHitQueue extends FieldValueHitQueue { + + public SingleValueComparisonFieldValueHitQueue(SortField[] fields, int size) { + super(fields, size); + } + + @Override + protected boolean lessThan(final Entry hitA, final Entry hitB) { + + assert hitA != hitB; + assert hitA.slot != hitB.slot; + + assert hitA.values != null; + assert hitB.values != null; + + int numComparators = comparators.length; + for (int i = 0; i < numComparators; ++i) { + FieldComparator fieldComparator = comparators[i]; + final int c = reverseMul[i] * fieldComparator.compareValues(hitA.values, hitB.values); + if (c != 0) { + // Short circuit + return c > 0; + } + } + + // avoid random sort order that could lead to duplicates (bug #31241): + return hitA.doc > hitB.doc; + } + + } + + /** + * An implementation of {@link FieldValueHitQueue} which uses values to compare members + */ + private static final class MultiValuesComparisonFieldValueHitQueue extends FieldValueHitQueue { + + public MultiValuesComparisonFieldValueHitQueue(SortField[] fields, int size) { + super(fields, size); + } + + @Override + protected boolean lessThan(final Entry hitA, final Entry hitB) { + + assert hitA != hitB; + assert hitA.slot != hitB.slot; + + assert hitA.values != null; + assert hitB.values != null; + + assert hitA.values instanceof Object[]; + assert hitB.values instanceof Object[]; + + Object[] firstValuesArray = (Object[]) hitA.values; + Object[] secondValuesArray = (Object[]) hitB.values; + + assert firstValuesArray.length == comparators.length; + assert secondValuesArray.length == comparators.length; + + int numComparators = comparators.length; + for (int i = 0; i < numComparators; ++i) { + FieldComparator fieldComparator = comparators[i]; + final int c = reverseMul[i] * fieldComparator.compareValues(firstValuesArray[i], secondValuesArray[i]); + if (c != 0) { + // Short circuit + return c > 0; + } + } + + // avoid random sort order that could lead to duplicates (bug #31241): + return hitA.doc > hitB.doc; + } + + } // prevent instantiation and extension. private FieldValueHitQueue(SortField[] fields, int size) { @@ -165,6 +249,32 @@ public static FieldValueHitQueue create( return new MultiComparatorsFieldValueHitQueue<>(fields, size); } } + + /** + * Creates a hit queue sorted by the given list of fields and using feature values for comparisons between + * member entries + * + *

NOTE: The instances returned by this method + * pre-allocate a full array of length numHits. + * + * @param fields + * SortField array we are sorting by in priority order (highest + * priority first); cannot be null or empty + * @param size + * The number of hits to retain. Must be greater than zero. + */ + public static FieldValueHitQueue createValuesComparingQueue(SortField[] fields, int size) { + + if (fields.length == 0) { + throw new IllegalArgumentException("Sort must contain at least one field"); + } + + if (fields.length == 1) { + return new SingleValueComparisonFieldValueHitQueue<>(fields, size); + } else { + return new MultiValuesComparisonFieldValueHitQueue<>(fields, size); + } + } public FieldComparator[] getComparators() { return comparators; diff --git a/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java b/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java index 3b2ab474d395..4eada2d1ddd7 100644 --- a/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java @@ -726,6 +726,7 @@ protected void search(List leaves, Weight weight, Collector c } } } + collector.postProcess(); } /** Expert: called to re-write queries into primitive queries. diff --git a/lucene/core/src/java/org/apache/lucene/search/TopDocsCollector.java b/lucene/core/src/java/org/apache/lucene/search/TopDocsCollector.java index 3e9f8344e927..8d947d0cae61 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TopDocsCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/TopDocsCollector.java @@ -17,7 +17,6 @@ package org.apache.lucene.search; - import org.apache.lucene.util.PriorityQueue; /** diff --git a/lucene/core/src/java/org/apache/lucene/search/TopFieldCollector.java b/lucene/core/src/java/org/apache/lucene/search/TopFieldCollector.java index 8ba42abafbee..573a5463d57b 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TopFieldCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/TopFieldCollector.java @@ -22,6 +22,7 @@ import java.util.Comparator; import java.util.List; import java.util.Objects; +import java.util.concurrent.atomic.AtomicInteger; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.ReaderUtil; @@ -67,6 +68,22 @@ public void setScorer(Scorable scorer) throws IOException { } } + public static abstract class EarlyTerminatingMultiComparatorLeafCollector extends MultiComparatorLeafCollector { + public final EarlyTerminatingFieldCollector earlyTerminatingFieldCollector; + + EarlyTerminatingMultiComparatorLeafCollector(LeafFieldComparator[] comparators, int[] reverseMul, + EarlyTerminatingFieldCollector collector) { + super(comparators, reverseMul); + this.earlyTerminatingFieldCollector = collector; + } + + @Override + public void setScorer(Scorable scorer) throws IOException { + comparator.setScorer(scorer); + this.scorer = scorer; + } + } + static boolean canEarlyTerminate(Sort searchSort, Sort indexSort) { return canEarlyTerminateOnDocId(searchSort) || canEarlyTerminateOnPrefix(searchSort, indexSort); @@ -160,6 +177,7 @@ public void collect(int doc) throws IOException { // Copy hit into queue comparator.copy(slot, doc); add(slot, doc); + if (queueFull) { comparator.setBottom(bottom.slot); updateMinCompetitiveScore(scorer); @@ -279,7 +297,130 @@ public void collect(int doc) throws IOException { } - private static final ScoreDoc[] EMPTY_SCOREDOCS = new ScoreDoc[0]; + /* + * Collects hits into a local queue until the requested number of hits are collected + * globally. Post that, a global calibration step is performed + */ + public static class EarlyTerminatingFieldCollector extends TopFieldCollector { + + final Sort sort; + final FieldValueHitQueue queue; + final EarlyTerminatingFieldCollectorManager earlyTerminatingFieldCollectorManager; + private final AtomicInteger globalNumberOfHits; + private boolean addedSelfToGlobalQueue; + + //TODO: Refactor this to make an interface only for field collector uses + public EarlyTerminatingFieldCollector(Sort sort, FieldValueHitQueue queue, int numHits, int totalHitsThreshold, + EarlyTerminatingFieldCollectorManager collectorManager, AtomicInteger globalNumberOfHits) { + super(queue, numHits, totalHitsThreshold, sort.needsScores()); + this.sort = sort; + this.queue = queue; + this.earlyTerminatingFieldCollectorManager = collectorManager; + this.globalNumberOfHits = globalNumberOfHits; + } + + @Override + public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException { + docBase = context.docBase; + + final LeafFieldComparator[] comparators = queue.getComparators(context); + final int[] reverseMul = queue.getReverseMul(); + final Sort indexSort = context.reader().getMetaData().getSort(); + final boolean canEarlyTerminate = canEarlyTerminate(sort, indexSort); + + return new EarlyTerminatingMultiComparatorLeafCollector(comparators, reverseMul, this) { + + boolean collectedAllCompetitiveHits = false; + + @Override + public void setScorer(Scorable scorer) throws IOException { + super.setScorer(scorer); + updateMinCompetitiveScore(scorer); + } + + @Override + public void collect(int doc) throws IOException { + + if (globalNumberOfHits.incrementAndGet() > numHits) { + if (addedSelfToGlobalQueue == false) { + Entry returnedEntry = earlyTerminatingFieldCollectorManager.addCollectorToGlobalQueue(earlyTerminatingFieldCollector, docBase); + + if (returnedEntry != null) { + filterCompetitiveHit(returnedEntry.doc, false, returnedEntry.values); + + if (queue.size() > 0) { + Entry entry = queue.pop(); + + while (entry != null) { + filterCompetitiveHit(entry.doc, false, entry.values); + entry = queue.pop(); + } + } + } + addedSelfToGlobalQueue = true; + } + + filterCompetitiveHit(doc, true, comparator.leafValue(doc)); + } else { + // Startup transient: queue hasn't gathered numHits yet + int slot = totalHits; + ++totalHits; + + comparator.copy(slot, doc); + add(slot, doc, comparator.leafValue(doc)); + } + } + + private void filterCompetitiveHit(int doc, boolean doEarlyTermination, Object value) throws IOException { + if (collectedAllCompetitiveHits || earlyTerminatingFieldCollectorManager.compareAndUpdateBottom(docBase, doc, value) <= 0) { + // since docs are visited in doc Id order, if compare is 0, it means + // this document is largest than anything else in the queue, and + // therefore not competitive. + if (canEarlyTerminate) { + if ((globalNumberOfHits.getAcquire() > totalHitsThreshold) && doEarlyTermination) { + totalHitsRelation = Relation.GREATER_THAN_OR_EQUAL_TO; + throw new CollectionTerminatedException(); + } else { + collectedAllCompetitiveHits = true; + } + } else if (totalHitsRelation == Relation.EQUAL_TO) { + // we just reached totalHitsThreshold, we can start setting the min + // competitive score now + updateMinCompetitiveScore(scorer); + } + return; + } + + updateMinCompetitiveScore(scorer); + } + }; + } + + @Override + public void postProcess() { + if (addedSelfToGlobalQueue == false) { + Entry returnedEntry = earlyTerminatingFieldCollectorManager.addCollectorToGlobalQueue(this, docBase); + + if (returnedEntry != null) { + if (returnedEntry != null) { + earlyTerminatingFieldCollectorManager.compareAndUpdateBottom(docBase, returnedEntry.doc, returnedEntry.values); + + if (queue.size() > 0) { + Entry entry = queue.pop(); + + while (entry != null) { + earlyTerminatingFieldCollectorManager.compareAndUpdateBottom(docBase, returnedEntry.doc, returnedEntry.values); + entry = queue.pop(); + } + } + } + } + addedSelfToGlobalQueue = true; + } + } + } + + public static final ScoreDoc[] EMPTY_SCOREDOCS = new ScoreDoc[0]; final int numHits; final int totalHitsThreshold; @@ -458,6 +599,11 @@ final void add(int slot, int doc) { queueFull = totalHits == numHits; } + final void add(int slot, int doc, Object values) { + bottom = pq.add(new Entry(slot, docBase + doc, values)); + queueFull = totalHits == numHits; + } + final void updateBottom(int doc) { // bottom.score is already set to Float.NaN in add(). bottom.doc = docBase + doc; diff --git a/lucene/core/src/test/org/apache/lucene/search/TestEarlyTerminationFieldCM.java b/lucene/core/src/test/org/apache/lucene/search/TestEarlyTerminationFieldCM.java new file mode 100644 index 000000000000..9fad8b846497 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestEarlyTerminationFieldCM.java @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.search; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +import com.carrotsearch.randomizedtesting.generators.RandomPicks; +import org.apache.lucene.analysis.MockAnalyzer; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.NumericDocValuesField; +import org.apache.lucene.document.StringField; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.MockRandomMergePolicy; +import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.index.SerialMergeScheduler; +import org.apache.lucene.index.Term; +import org.apache.lucene.store.Directory; +import org.apache.lucene.util.LuceneTestCase; +import org.apache.lucene.util.NamedThreadFactory; +import org.apache.lucene.util.TestUtil; + +public class TestEarlyTerminationFieldCM extends LuceneTestCase { + private int numDocs; + private List terms; + private Directory dir; + private final Sort sort = new Sort(new SortField("ndv1", SortField.Type.LONG)); + private RandomIndexWriter iw; + private IndexReader reader; + private static final int FORCE_MERGE_MAX_SEGMENT_COUNT = 5; + + private Document randomDocument() { + final Document doc = new Document(); + doc.add(new NumericDocValuesField("ndv1", random().nextInt(10))); + doc.add(new NumericDocValuesField("ndv2", random().nextInt(10))); + doc.add(new StringField("s", RandomPicks.randomFrom(random(), terms), Field.Store.YES)); + return doc; + } + + private void createRandomIndex(boolean singleSortedSegment) throws IOException { + dir = newDirectory(); + numDocs = TestUtil.nextInt(random(), 20, 50); + final int numTerms = TestUtil.nextInt(random(), 1, numDocs / 5); + Set randomTerms = new HashSet<>(); + while (randomTerms.size() < numTerms) { + randomTerms.add(TestUtil.randomSimpleString(random())); + } + terms = new ArrayList<>(randomTerms); + final long seed = random().nextLong(); + final IndexWriterConfig iwc = newIndexWriterConfig(new MockAnalyzer(new Random(seed))); + if (iwc.getMergePolicy() instanceof MockRandomMergePolicy) { + // MockRandomMP randomly wraps the leaf readers which makes merging angry + iwc.setMergePolicy(newTieredMergePolicy()); + } + iwc.setMergeScheduler(new SerialMergeScheduler()); // for reproducible tests + iwc.setIndexSort(sort); + iw = new RandomIndexWriter(new Random(seed), dir, iwc); + iw.setDoRandomForceMerge(false); // don't do this, it may happen anyway with MockRandomMP + for (int i = 0; i < numDocs; ++i) { + final Document doc = randomDocument(); + iw.addDocument(doc); + if (i == numDocs / 2 || (i != numDocs - 1 && random().nextInt(8) == 0)) { + iw.commit(); + } + } + if (singleSortedSegment) { + iw.forceMerge(1); + } + else if (random().nextBoolean()) { + iw.forceMerge(FORCE_MERGE_MAX_SEGMENT_COUNT); + } + reader = iw.getReader(); + if (reader.numDocs() == 0) { + iw.addDocument(new Document()); + reader.close(); + reader = iw.getReader(); + } + } + + private void createStaticIndex() throws IOException { + dir = newDirectory(); + numDocs = TestUtil.nextInt(random(), 20, 50); + final long seed = random().nextLong(); + final IndexWriterConfig iwc = newIndexWriterConfig(new MockAnalyzer(new Random(seed))); + iwc.setIndexSort(sort); + iw = new RandomIndexWriter(new Random(seed), dir, iwc); + + for (int i = 0; i < numDocs; ++i) { + final Document doc = new Document(); + doc.add(new StringField("s1", "foo1", Field.Store.YES)); + doc.add(new StringField("s2", "foo2", Field.Store.YES)); + iw.addDocument(doc); + if (i == numDocs / 5) { + iw.commit(); + } + } + + // One document with a special value + final Document doc = new Document(); + doc.add(new StringField("s3", "foo3", Field.Store.YES)); + + iw.addDocument(doc); + + reader = iw.getReader(); + } + + private void closeIndex() throws IOException { + reader.close(); + iw.close(); + dir.close(); + } + + public void testGlobalStateEarlyTermination() throws IOException { + createRandomIndex(false); + ExecutorService service = new ThreadPoolExecutor(4, 4, 0L, TimeUnit.MILLISECONDS, + new LinkedBlockingQueue(), + new NamedThreadFactory("TestGlobalStateCMEarlyTermination")); + final IndexSearcher searcher = new SmallSliceSizeIndexSearcher(reader, service); + final int numHits = TestUtil.nextInt(random(), 2, 5); + final int totalHitsThreshold = TestUtil.nextInt(random(), 1, (numHits - 1)); + + final EarlyTerminatingFieldCollectorManager collectorManager = new EarlyTerminatingFieldCollectorManager(sort, numHits, totalHitsThreshold); + final TopFieldCollector collector1 = TopFieldCollector.create(sort, numHits, null, totalHitsThreshold); + + final Query query; + if (random().nextBoolean()) { + query = new TermQuery(new Term("s", RandomPicks.randomFrom(random(), terms))); + } else { + query = new MatchAllDocsQuery(); + } + TopDocs td2 = searcher.search(query, collectorManager); + searcher.search(query, collector1); + TopDocs td1 = collector1.topDocs(); + + assertTrue("Values were different " + td2.totalHits.value + " " + td1.scoreDocs.length, + td2.totalHits.value >= td1.scoreDocs.length); + assertTrue(td2.totalHits.value <= reader.maxDoc()); + CheckHits.checkEqual(query, td1.scoreDocs, td2.scoreDocs); + closeIndex(); + service.shutdown(); + } + + public void testGlobalStateSingleSegmentEarlyTermination() throws IOException { + createRandomIndex(true); + ExecutorService service = new ThreadPoolExecutor(4, 4, 0L, TimeUnit.MILLISECONDS, + new LinkedBlockingQueue(), + new NamedThreadFactory("TestGlobalStateCMEarlyTermination")); + final IndexSearcher searcher = new SmallSliceSizeIndexSearcher(reader, service); + final int numHits = TestUtil.nextInt(random(), 1, numDocs); + final int totalHitsThreshold = TestUtil.nextInt(random(), 1, (numHits - 1) > 0 ? (numHits -1) : 2); + + final EarlyTerminatingFieldCollectorManager collectorManager = new EarlyTerminatingFieldCollectorManager(sort, numHits, totalHitsThreshold); + final TopFieldCollector collector1 = TopFieldCollector.create(sort, numHits, null, Integer.MAX_VALUE); + + final Query query; + if (random().nextBoolean()) { + query = new TermQuery(new Term("s", RandomPicks.randomFrom(random(), terms))); + } else { + query = new MatchAllDocsQuery(); + } + TopDocs td2 = searcher.search(query, collectorManager); + searcher.search(query, collector1); + TopDocs td1 = collector1.topDocs(); + + assertFalse(collector1.isEarlyTerminated()); + assertTrue(td2.totalHits.value >= td1.scoreDocs.length); + assertTrue(td2.totalHits.value <= reader.maxDoc()); + CheckHits.checkEqual(query, td1.scoreDocs, td2.scoreDocs); + closeIndex(); + service.shutdown(); + } + + public void testOneDocumentHasHigherScoreThanAll() throws IOException { + createStaticIndex(); + ExecutorService service = new ThreadPoolExecutor(4, 4, 0L, TimeUnit.MILLISECONDS, + new LinkedBlockingQueue(), + new NamedThreadFactory("TestGlobalStateCMEarlyTermination")); + final IndexSearcher searcher = new SmallSliceSizeIndexSearcher(reader, service); + final int numHits = TestUtil.nextInt(random(), 1, numDocs); + final int totalHitsThreshold = TestUtil.nextInt(random(), 1, (numHits - 1) > 0 ? (numHits - 1) : 2); + + final EarlyTerminatingFieldCollectorManager collectorManager = new EarlyTerminatingFieldCollectorManager(sort, numHits, totalHitsThreshold); + final TopFieldCollector collector1 = TopFieldCollector.create(sort, numHits, null, Integer.MAX_VALUE); + + BooleanQuery.Builder builder = new BooleanQuery.Builder(); + builder.add(new TermQuery(new Term("s1", "foo1")), BooleanClause.Occur.MUST); + builder.add(new TermQuery(new Term("s2", "foo2")), BooleanClause.Occur.SHOULD); + Query query = builder.build(); + TopDocs td2 = searcher.search(query, collectorManager); + searcher.search(query, collector1); + TopDocs td1 = collector1.topDocs(); + + assertFalse(collector1.isEarlyTerminated()); + assertTrue("Values did not match " + td1.totalHits.value + " " + td1.scoreDocs.length, + td2.totalHits.value >= td1.scoreDocs.length); + assertTrue(td2.totalHits.value <= reader.maxDoc()); + CheckHits.checkEqual(query, td1.scoreDocs, td2.scoreDocs); + closeIndex(); + service.shutdown(); + } +} diff --git a/lucene/test-framework/src/java/org/apache/lucene/search/SmallSliceSizeIndexSearcher.java b/lucene/test-framework/src/java/org/apache/lucene/search/SmallSliceSizeIndexSearcher.java new file mode 100644 index 000000000000..bca0ac06d181 --- /dev/null +++ b/lucene/test-framework/src/java/org/apache/lucene/search/SmallSliceSizeIndexSearcher.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.util.List; +import java.util.concurrent.Executor; + +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReaderContext; + +/** + * An {@link IndexSearcher} that has a smaller size limit for slices, allowing higher slice count with lesser + * number of documents + */ +public class SmallSliceSizeIndexSearcher extends IndexSearcher { + + private static final int MAX_DOCS_PER_SLICE = 10; + private static final int MAX_SEGMENTS_PER_SLICE = 5; + + /** Creates a searcher searching the provided index. Search on individual + * segments will be run in the provided {@link Executor}. + * @see IndexSearcher#IndexSearcher(IndexReader, Executor) */ + public SmallSliceSizeIndexSearcher(IndexReader r, Executor executor) { + super(r, executor); + } + + @Override + protected LeafSlice[] slices(List leaves) { + return slices(leaves, MAX_DOCS_PER_SLICE, MAX_SEGMENTS_PER_SLICE); + } + +}