Skip to content

Commit

Permalink
Make TermStates#build concurrent (#12183)
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhamvishu committed Sep 20, 2023
1 parent 1cb0d81 commit 408d0fb
Show file tree
Hide file tree
Showing 17 changed files with 111 additions and 46 deletions.
2 changes: 1 addition & 1 deletion lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ Improvements

Optimizations
---------------------
* GITHUB#12183: Make TermStates#build concurrent. (Shubham Chaudhary)

Changes in runtime behavior
---------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.TermFrequencyAttribute;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermStates;
import org.apache.lucene.search.BooleanQuery;
Expand Down Expand Up @@ -345,7 +344,7 @@ public FeatureFunction rewrite(IndexSearcher indexSearcher) throws IOException {
if (pivot != null) {
return super.rewrite(indexSearcher);
}
float newPivot = computePivotFeatureValue(indexSearcher.getIndexReader(), field, feature);
float newPivot = computePivotFeatureValue(indexSearcher, field, feature);
return new SaturationFunction(field, feature, newPivot);
}

Expand Down Expand Up @@ -618,14 +617,14 @@ public static Query newSigmoidQuery(
* store the exponent in the higher bits, it means that the result will be an approximation of the
* geometric mean of all feature values.
*
* @param searcher the {@link IndexSearcher} to perform the search
* @param featureField the field that stores features
* @param featureName the name of the feature
*/
static float computePivotFeatureValue(IndexReader reader, String featureField, String featureName)
throws IOException {
static float computePivotFeatureValue(
IndexSearcher searcher, String featureField, String featureName) throws IOException {
Term term = new Term(featureField, featureName);
TermStates states = TermStates.build(reader.getContext(), term, true);
TermStates states = TermStates.build(searcher, term, true);
if (states.docFreq() == 0) {
// avoid division by 0
// The return value doesn't matter much here, the term doesn't exist,
Expand Down
89 changes: 79 additions & 10 deletions lucene/core/src/java/org/apache/lucene/index/TermStates.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
package org.apache.lucene.index;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.TaskExecutor;

/**
* Maintains a {@link IndexReader} {@link TermState} view over {@link IndexReader} instances
Expand Down Expand Up @@ -86,19 +90,48 @@ public TermStates(
* @param needsStats if {@code true} then all leaf contexts will be visited up-front to collect
* term statistics. Otherwise, the {@link TermState} objects will be built only when requested
*/
// Builds a TermStates view for {@code term} over the searcher's top-level reader context.
// When {@code needsStats} is true, every leaf is visited up-front to collect per-leaf term
// statistics; the visit runs concurrently via the searcher's TaskExecutor when one is
// configured, and sequentially otherwise.
public static TermStates build(IndexSearcher indexSearcher, Term term, boolean needsStats)
    throws IOException {
  IndexReaderContext context = indexSearcher.getTopReaderContext();
  assert context != null;
  final TermStates perReaderTermState = new TermStates(needsStats ? null : term, context);
  if (needsStats) {
    TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
    if (taskExecutor != null) {
      // Build the term states concurrently: one task per leaf segment. A null task result
      // marks a leaf that does not contain the term and is skipped on registration.
      List<TaskExecutor.Task<TermStateInfo>> tasks =
          context.leaves().stream()
              .map(
                  ctx ->
                      taskExecutor.createTask(
                          () -> {
                            TermsEnum termsEnum = loadTermsEnum(ctx, term);
                            if (termsEnum != null) {
                              return new TermStateInfo(
                                  termsEnum.termState(),
                                  ctx.ord,
                                  termsEnum.docFreq(),
                                  termsEnum.totalTermFreq());
                            }
                            return null;
                          }))
              .toList();
      // invokeAll only iterates the given collection, so the unmodifiable list produced by
      // toList() is passed directly — no defensive ArrayList copy is needed.
      for (TermStateInfo info : taskExecutor.invokeAll(tasks)) {
        if (info != null) {
          perReaderTermState.register(
              info.getState(), info.getOrdinal(), info.getDocFreq(), info.getTotalTermFreq());
        }
      }
    } else {
      // Build the term states sequentially, one leaf at a time.
      for (final LeafReaderContext ctx : context.leaves()) {
        TermsEnum termsEnum = loadTermsEnum(ctx, term);
        if (termsEnum != null) {
          perReaderTermState.register(
              termsEnum.termState(), ctx.ord, termsEnum.docFreq(), termsEnum.totalTermFreq());
        }
      }
    }
  }
Expand Down Expand Up @@ -211,4 +244,40 @@ public String toString() {

return sb.toString();
}

/** Wrapper over TermState, ordinal value, term doc frequency and total term frequency */
private static final class TermStateInfo {
private final TermState state;
private final int ordinal;
private final int docFreq;
private final long totalTermFreq;

/** Initialize TermStateInfo */
public TermStateInfo(TermState state, int ordinal, int docFreq, long totalTermFreq) {
this.state = state;
this.ordinal = ordinal;
this.docFreq = docFreq;
this.totalTermFreq = totalTermFreq;
}

/** Get term state */
public TermState getState() {
return state;
}

/** Get ordinal value */
public int getOrdinal() {
return ordinal;
}

/** Get term doc frequency */
public int getDocFreq() {
return docFreq;
}

/** Get total term frequency */
public long getTotalTermFreq() {
return totalTermFreq;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ public final Query rewrite(IndexSearcher indexSearcher) throws IOException {
for (int i = 0; i < contexts.length; ++i) {
if (contexts[i] == null
|| contexts[i].wasBuiltFor(indexSearcher.getTopReaderContext()) == false) {
contexts[i] = TermStates.build(indexSearcher.getTopReaderContext(), terms[i], true);
contexts[i] = TermStates.build(indexSearcher, terms[i], true);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.index.IndexReaderContext;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
Expand Down Expand Up @@ -219,15 +218,14 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo

@Override
protected Similarity.SimScorer getStats(IndexSearcher searcher) throws IOException {
final IndexReaderContext context = searcher.getTopReaderContext();

// compute idf
ArrayList<TermStatistics> allTermStats = new ArrayList<>();
for (final Term[] terms : termArrays) {
for (Term term : terms) {
TermStates ts = termStates.get(term);
if (ts == null) {
ts = TermStates.build(context, term, scoreMode.needsScores());
ts = TermStates.build(searcher, term, scoreMode.needsScores());
termStates.put(term, ts);
}
if (scoreMode.needsScores() && ts.docFreq() > 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import org.apache.lucene.codecs.lucene90.Lucene90PostingsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90PostingsReader;
import org.apache.lucene.index.ImpactsEnum;
import org.apache.lucene.index.IndexReaderContext;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
Expand Down Expand Up @@ -451,13 +450,12 @@ protected Similarity.SimScorer getStats(IndexSearcher searcher) throws IOExcepti
throw new IllegalStateException(
"PhraseWeight requires that the first position is 0, call rewrite first");
}
final IndexReaderContext context = searcher.getTopReaderContext();
states = new TermStates[terms.length];
TermStatistics[] termStats = new TermStatistics[terms.length];
int termUpTo = 0;
for (int i = 0; i < terms.length; i++) {
final Term term = terms[i];
states[i] = TermStates.build(context, term, scoreMode.needsScores());
states[i] = TermStates.build(searcher, term, scoreMode.needsScores());
if (scoreMode.needsScores()) {
TermStates ts = states[i];
if (ts.docFreq() > 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ class SynonymWeight extends Weight {
termStates = new TermStates[terms.length];
for (int i = 0; i < termStates.length; i++) {
Term term = new Term(field, terms[i].term);
TermStates ts = TermStates.build(searcher.getTopReaderContext(), term, true);
TermStates ts = TermStates.build(searcher, term, true);
termStates[i] = ts;
if (ts.docFreq() > 0) {
TermStatistics termStats =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo
final IndexReaderContext context = searcher.getTopReaderContext();
final TermStates termState;
if (perReaderTermState == null || perReaderTermState.wasBuiltFor(context) == false) {
termState = TermStates.build(context, term, scoreMode.needsScores());
termState = TermStates.build(searcher, term, scoreMode.needsScores());
} else {
// PRTS was pre-build for this IS
termState = this.perReaderTermState;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,8 @@ public void testComputePivotFeatureValue() throws IOException {

// Make sure that we create a legal pivot on missing features
DirectoryReader reader = writer.getReader();
float pivot = FeatureField.computePivotFeatureValue(reader, "features", "pagerank");
IndexSearcher searcher = LuceneTestCase.newSearcher(reader);
float pivot = FeatureField.computePivotFeatureValue(searcher, "features", "pagerank");
assertTrue(Float.isFinite(pivot));
assertTrue(pivot > 0);
reader.close();
Expand All @@ -298,7 +299,8 @@ public void testComputePivotFeatureValue() throws IOException {
reader = writer.getReader();
writer.close();

pivot = FeatureField.computePivotFeatureValue(reader, "features", "pagerank");
searcher = LuceneTestCase.newSearcher(reader);
pivot = FeatureField.computePivotFeatureValue(searcher, "features", "pagerank");
double expected = Math.pow(10 * 100 * 1 * 42, 1 / 4.); // geometric mean
assertEquals(expected, pivot, 0.1);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.lucene.index;

import org.apache.lucene.document.Document;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase;
Expand All @@ -30,8 +31,8 @@ public void testToStringOnNullTermState() throws Exception {
RandomIndexWriter w = new RandomIndexWriter(random(), dir);
w.addDocument(new Document());
IndexReader r = w.getReader();
TermStates states =
TermStates.build(r.getContext(), new Term("foo", "bar"), random().nextBoolean());
IndexSearcher s = new IndexSearcher(r);
TermStates states = TermStates.build(s, new Term("foo", "bar"), random().nextBoolean());
assertEquals("TermStates\n state=null\n", states.toString());
IOUtils.close(r, w, dir);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ static class SlowMinShouldMatchScorer extends Scorer {
if (ord >= 0) {
boolean success = ords.add(ord);
assert success; // no dups
TermStates ts = TermStates.build(reader.getContext(), term, true);
TermStates ts = TermStates.build(searcher, term, true);
SimScorer w =
weight.similarity.scorer(
1f,
Expand Down
16 changes: 8 additions & 8 deletions lucene/core/src/test/org/apache/lucene/search/TestTermQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,12 @@ public void testEquals() throws IOException {
final CompositeReaderContext context;
try (MultiReader multiReader = new MultiReader()) {
context = multiReader.getContext();
IndexSearcher searcher = new IndexSearcher(context);
QueryUtils.checkEqual(
new TermQuery(new Term("foo", "bar")),
new TermQuery(
new Term("foo", "bar"), TermStates.build(searcher, new Term("foo", "bar"), true)));
}
QueryUtils.checkEqual(
new TermQuery(new Term("foo", "bar")),
new TermQuery(
new Term("foo", "bar"), TermStates.build(context, new Term("foo", "bar"), true)));
}

public void testCreateWeightDoesNotSeekIfScoresAreNotNeeded() throws IOException {
Expand Down Expand Up @@ -100,8 +101,7 @@ public void testCreateWeightDoesNotSeekIfScoresAreNotNeeded() throws IOException
assertEquals(1, totalHits);
TermQuery queryWithContext =
new TermQuery(
new Term("foo", "bar"),
TermStates.build(reader.getContext(), new Term("foo", "bar"), true));
new Term("foo", "bar"), TermStates.build(searcher, new Term("foo", "bar"), true));
totalHits = searcher.search(queryWithContext, DummyTotalHitCountCollector.createManager());
assertEquals(1, totalHits);

Expand Down Expand Up @@ -160,10 +160,10 @@ public void testGetTermStates() throws Exception {
w.addDocument(new Document());

DirectoryReader reader = w.getReader();
IndexSearcher searcher = new IndexSearcher(reader);
TermQuery queryWithContext =
new TermQuery(
new Term("foo", "bar"),
TermStates.build(reader.getContext(), new Term("foo", "bar"), true));
new Term("foo", "bar"), TermStates.build(searcher, new Term("foo", "bar"), true));
assertNotNull(queryWithContext.getTermStates());
IOUtils.close(reader, w, dir);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public SpanWeight createWeight(IndexSearcher searcher, ScoreMode scoreMode, floa
final TermStates context;
final IndexReaderContext topContext = searcher.getTopReaderContext();
if (termStates == null || termStates.wasBuiltFor(topContext) == false) {
context = TermStates.build(topContext, term, scoreMode.needsScores());
context = TermStates.build(searcher, term, scoreMode.needsScores());
} else {
context = termStates;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ class CombinedFieldWeight extends Weight {
termStates = new TermStates[fieldTerms.length];
for (int i = 0; i < termStates.length; i++) {
FieldAndWeight field = fieldAndWeights.get(fieldTerms[i].field());
TermStates ts = TermStates.build(searcher.getTopReaderContext(), fieldTerms[i], true);
TermStates ts = TermStates.build(searcher, fieldTerms[i], true);
termStates[i] = ts;
if (ts.docFreq() > 0) {
TermStatistics termStats =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ protected int collectSingleTermData(
TermData termData = termsData.getOrCreateTermData(singleTerm.termPosition);
Term term = singleTerm.term;
termData.terms.add(term);
TermStates termStates = TermStates.build(searcher.getIndexReader().getContext(), term, true);
TermStates termStates = TermStates.build(searcher, term, true);

// Collect TermState per segment.
int numMatches = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.index.IndexReaderContext;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.ReaderUtil;
Expand Down Expand Up @@ -209,14 +208,13 @@ public void finish(int determinizeWorkLimit) {
@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
throws IOException {
IndexReaderContext context = searcher.getTopReaderContext();
Map<Integer, TermStates> termStates = new HashMap<>();

for (Map.Entry<BytesRef, Integer> ent : termToID.entrySet()) {
if (ent.getKey() != null) {
termStates.put(
ent.getValue(),
TermStates.build(context, new Term(field, ent.getKey()), scoreMode.needsScores()));
TermStates.build(searcher, new Term(field, ent.getKey()), scoreMode.needsScores()));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ Map<Term, TermStatistics> getNodeTermStats(Set<Term> terms, int nodeID, long ver
}
try {
for (Term term : terms) {
final TermStates ts = TermStates.build(s.getIndexReader().getContext(), term, true);
final TermStates ts = TermStates.build(s, term, true);
if (ts.docFreq() > 0) {
stats.put(term, s.termStatistics(term, ts.docFreq(), ts.totalTermFreq()));
}
Expand Down

0 comments on commit 408d0fb

Please sign in to comment.