Skip to content

Commit

Permalink
lucene 4: Upgraded o.e.search.dfs package. #2
Browse files Browse the repository at this point in the history
  • Loading branch information
martijnvg committed Oct 30, 2012
1 parent 2b823ef commit f772730
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 32 deletions.
Expand Up @@ -88,21 +88,36 @@ public boolean optimizeSingleShard() {
}

public AggregatedDfs aggregateDfs(Iterable<DfsSearchResult> results) {
TMap<Term, TermStatistics> dfMap = new ExtTHashMap<Term, TermStatistics>(Constants.DEFAULT_CAPACITY, Constants.DEFAULT_LOAD_FACTOR);
TMap<Term, TermStatistics> termStatistics = new ExtTHashMap<Term, TermStatistics>(Constants.DEFAULT_CAPACITY, Constants.DEFAULT_LOAD_FACTOR);
TMap<String, CollectionStatistics> fieldStatistics = new ExtTHashMap<String, CollectionStatistics>(Constants.DEFAULT_CAPACITY, Constants.DEFAULT_LOAD_FACTOR);
long aggMaxDoc = 0;
for (DfsSearchResult result : results) {
for (int i = 0; i < result.termStatistics().length; i++) {
TermStatistics existing = dfMap.get(result.terms()[i]);
TermStatistics existing = termStatistics.get(result.terms()[i]);
if (existing != null) {
dfMap.put(result.terms()[i], new TermStatistics(existing.term(), existing.docFreq() + result.termStatistics()[i].docFreq(), existing.totalTermFreq() + result.termStatistics()[i].totalTermFreq()));
termStatistics.put(result.terms()[i], new TermStatistics(existing.term(), existing.docFreq() + result.termStatistics()[i].docFreq(), existing.totalTermFreq() + result.termStatistics()[i].totalTermFreq()));
} else {
dfMap.put(result.terms()[i], result.termStatistics()[i]);
termStatistics.put(result.terms()[i], result.termStatistics()[i]);
}

}
for (Map.Entry<String, CollectionStatistics> entry : result.fieldStatistics().entrySet()) {
CollectionStatistics existing = fieldStatistics.get(entry.getKey());
if (existing != null) {
CollectionStatistics merged = new CollectionStatistics(
entry.getKey(), existing.maxDoc() + entry.getValue().maxDoc(),
existing.docCount() + entry.getValue().docCount(),
existing.sumTotalTermFreq() + entry.getValue().sumTotalTermFreq(),
existing.sumDocFreq() + entry.getValue().sumDocFreq()
);
fieldStatistics.put(entry.getKey(), merged);
} else {
fieldStatistics.put(entry.getKey(), entry.getValue());
}
}
aggMaxDoc += result.maxDoc();
}
return new AggregatedDfs(dfMap, aggMaxDoc);
return new AggregatedDfs(termStatistics, fieldStatistics, aggMaxDoc);
}

public ShardDoc[] sortDocs(Collection<? extends QuerySearchResultProvider> results1) {
Expand Down
46 changes: 32 additions & 14 deletions src/main/java/org/elasticsearch/search/dfs/AggregatedDfs.java
Expand Up @@ -20,16 +20,14 @@
package org.elasticsearch.search.dfs;

import gnu.trove.impl.Constants;
import gnu.trove.iterator.TObjectIntIterator;
import gnu.trove.map.TMap;
import gnu.trove.map.hash.TObjectIntHashMap;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.TermStatistics;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Streamable;
import org.elasticsearch.common.trove.ExtTHashMap;
import org.elasticsearch.common.trove.ExtTObjectIntHasMap;

import java.io.IOException;
import java.util.Map;
Expand All @@ -39,21 +37,26 @@
*/
public class AggregatedDfs implements Streamable {

private TMap<Term, TermStatistics> dfMap;

private TMap<Term, TermStatistics> termStatistics;
private TMap<String, CollectionStatistics> fieldStatistics;
private long maxDoc;

private AggregatedDfs() {

}

public AggregatedDfs(TMap<Term, TermStatistics> dfMap, long maxDoc) {
this.dfMap = dfMap;
public AggregatedDfs(TMap<Term, TermStatistics> termStatistics, TMap<String, CollectionStatistics> fieldStatistics, long maxDoc) {
this.termStatistics = termStatistics;
this.fieldStatistics = fieldStatistics;
this.maxDoc = maxDoc;
}

public TMap<Term, TermStatistics> dfMap() {
return dfMap;
public TMap<Term, TermStatistics> termStatistics() {
return termStatistics;
}

public TMap<String, CollectionStatistics> fieldStatistics() {
return fieldStatistics;
}

public long maxDoc() {
Expand All @@ -69,20 +72,26 @@ public static AggregatedDfs readAggregatedDfs(StreamInput in) throws IOException
@Override
public void readFrom(StreamInput in) throws IOException {
int size = in.readVInt();
dfMap = new ExtTHashMap<Term, TermStatistics>(size, Constants.DEFAULT_LOAD_FACTOR);
termStatistics = new ExtTHashMap<Term, TermStatistics>(size, Constants.DEFAULT_LOAD_FACTOR);
for (int i = 0; i < size; i++) {
Term term = new Term(in.readString(), in.readBytesRef());
TermStatistics stats = new TermStatistics(in.readBytesRef(), in.readVLong(), in.readVLong());
dfMap.put(term, stats);
termStatistics.put(term, stats);
}
size = in.readVInt();
fieldStatistics = new ExtTHashMap<String, CollectionStatistics>(size, Constants.DEFAULT_LOAD_FACTOR);
for (int i = 0; i < size; i++) {
String field = in.readString();
CollectionStatistics stats = new CollectionStatistics(field, in.readVLong(), in.readVLong(), in.readVLong(), in.readVLong());
fieldStatistics.put(field, stats);
}
maxDoc = in.readVLong();
}

@Override
public void writeTo(final StreamOutput out) throws IOException {
out.writeVInt(dfMap.size());

for (Map.Entry<Term, TermStatistics> termTermStatisticsEntry : dfMap.entrySet()) {
out.writeVInt(termStatistics.size());
for (Map.Entry<Term, TermStatistics> termTermStatisticsEntry : termStatistics.entrySet()) {
Term term = termTermStatisticsEntry.getKey();
out.writeString(term.field());
out.writeBytesRef(term.bytes());
Expand All @@ -92,6 +101,15 @@ public void writeTo(final StreamOutput out) throws IOException {
out.writeVLong(stats.totalTermFreq());
}

out.writeVInt(fieldStatistics.size());
for (Map.Entry<String, CollectionStatistics> entry : fieldStatistics.entrySet()) {
out.writeString(entry.getKey());
out.writeVLong(entry.getValue().maxDoc());
out.writeVLong(entry.getValue().docCount());
out.writeVLong(entry.getValue().sumTotalTermFreq());
out.writeVLong(entry.getValue().sumDocFreq());
}

out.writeVLong(maxDoc);
}
}
19 changes: 11 additions & 8 deletions src/main/java/org/elasticsearch/search/dfs/CachedDfSource.java
Expand Up @@ -24,7 +24,6 @@
import org.apache.lucene.search.*;
import org.apache.lucene.search.similarities.Similarity;
import org.elasticsearch.ElasticSearchIllegalArgumentException;
import org.elasticsearch.search.internal.SearchContext;

import java.io.IOException;
import java.util.List;
Expand All @@ -34,25 +33,25 @@
*/
public class CachedDfSource extends IndexSearcher {

private final AggregatedDfs dfs;
private final AggregatedDfs aggregatedDfs;

private final int maxDoc;

public CachedDfSource(IndexReader reader, AggregatedDfs dfs, Similarity similarity) throws IOException {
public CachedDfSource(IndexReader reader, AggregatedDfs aggregatedDfs, Similarity similarity) throws IOException {
super(reader);
this.dfs = dfs;
this.aggregatedDfs = aggregatedDfs;
setSimilarity(similarity);
if (dfs.maxDoc() > Integer.MAX_VALUE) {
if (aggregatedDfs.maxDoc() > Integer.MAX_VALUE) {
maxDoc = Integer.MAX_VALUE;
} else {
maxDoc = (int) dfs.maxDoc();
maxDoc = (int) aggregatedDfs.maxDoc();
}
}


@Override
public TermStatistics termStatistics(Term term, TermContext context) throws IOException {
TermStatistics termStatistics = dfs.dfMap().get(term);
TermStatistics termStatistics = aggregatedDfs.termStatistics().get(term);
if (termStatistics == null) {
throw new ElasticSearchIllegalArgumentException("Not distributed term statistics for term: " + term);
}
Expand All @@ -61,7 +60,11 @@ public TermStatistics termStatistics(Term term, TermContext context) throws IOEx

@Override
public CollectionStatistics collectionStatistics(String field) throws IOException {
throw new UnsupportedOperationException();
CollectionStatistics collectionStatistics = aggregatedDfs.fieldStatistics().get(field);
if (collectionStatistics == null) {
throw new ElasticSearchIllegalArgumentException("Not distributed collection statistics for field: " + field);
}
return collectionStatistics;
}

public int maxDoc() {
Expand Down
19 changes: 15 additions & 4 deletions src/main/java/org/elasticsearch/search/dfs/DfsPhase.java
Expand Up @@ -20,17 +20,23 @@
package org.elasticsearch.search.dfs;

import com.google.common.collect.ImmutableMap;
import gnu.trove.map.TMap;
import gnu.trove.set.hash.THashSet;
import org.apache.lucene.index.IndexReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermContext;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.TermStatistics;
import org.elasticsearch.common.trove.ExtTHashMap;
import org.elasticsearch.common.util.concurrent.ThreadLocals;
import org.elasticsearch.search.SearchParseElement;
import org.elasticsearch.search.SearchPhase;
import org.elasticsearch.search.internal.SearchContext;

import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
*
Expand Down Expand Up @@ -71,11 +77,16 @@ public void execute(SearchContext context) {
termStatistics[i] = context.searcher().termStatistics(terms[i], termContext);
}

// TODO: LUCENE 4 UPGRADE - add collection stats for each unique field, for distributed scoring
// context.searcher().collectionStatistics()
TMap<String, CollectionStatistics> fieldStatistics = new ExtTHashMap<String, CollectionStatistics>();
for (Term term : terms) {
if (!fieldStatistics.containsKey(term.field())) {
fieldStatistics.put(term.field(), context.searcher().collectionStatistics(term.field()));
}
}

context.dfsResult().termsAndFreqs(terms, termStatistics);
context.dfsResult().maxDoc(context.searcher().getIndexReader().maxDoc());
context.dfsResult().termsStatistics(terms, termStatistics)
.fieldStatistics(fieldStatistics)
.maxDoc(context.searcher().getIndexReader().maxDoc());
} catch (Exception e) {
throw new DfsPhaseExecutionException(context, "Exception during dfs phase", e);
}
Expand Down
32 changes: 31 additions & 1 deletion src/main/java/org/elasticsearch/search/dfs/DfsSearchResult.java
Expand Up @@ -19,16 +19,20 @@

package org.elasticsearch.search.dfs;

import gnu.trove.map.TMap;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.trove.ExtTHashMap;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.transport.TransportResponse;

import java.io.IOException;
import java.util.Map;

/**
*
Expand All @@ -42,6 +46,7 @@ public class DfsSearchResult extends TransportResponse implements SearchPhaseRes
private long id;
private Term[] terms;
private TermStatistics[] termStatistics;
private TMap<String, CollectionStatistics> fieldStatistics = new ExtTHashMap<String, CollectionStatistics>();
private int maxDoc;

public DfsSearchResult() {
Expand Down Expand Up @@ -75,12 +80,17 @@ public int maxDoc() {
return maxDoc;
}

public DfsSearchResult termsAndFreqs(Term[] terms, TermStatistics[] termStatistics) {
public DfsSearchResult termsStatistics(Term[] terms, TermStatistics[] termStatistics) {
this.terms = terms;
this.termStatistics = termStatistics;
return this;
}

public DfsSearchResult fieldStatistics(TMap<String, CollectionStatistics> fieldStatistics) {
this.fieldStatistics = fieldStatistics;
return this;
}

public Term[] terms() {
return terms;
}
Expand All @@ -89,6 +99,10 @@ public TermStatistics[] termStatistics() {
return termStatistics;
}

public TMap<String, CollectionStatistics> fieldStatistics() {
return fieldStatistics;
}

public static DfsSearchResult readDfsSearchResult(StreamInput in) throws IOException, ClassNotFoundException {
DfsSearchResult result = new DfsSearchResult();
result.readFrom(in);
Expand Down Expand Up @@ -121,6 +135,13 @@ public void readFrom(StreamInput in) throws IOException {
termStatistics[i] = new TermStatistics(term, docFreq, totalTermFreq);
}
}
int numFieldStatistics = in.readVInt();
for (int i = 0; i < numFieldStatistics; i++) {
String field = in.readString();
CollectionStatistics stats = new CollectionStatistics(field, in.readVLong(), in.readVLong(), in.readVLong(), in.readVLong());
fieldStatistics.put(field, stats);
}

maxDoc = in.readVInt();
}

Expand All @@ -139,6 +160,15 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeVLong(termStatistic.docFreq());
out.writeVLong(termStatistic.totalTermFreq());
}
out.writeVInt(fieldStatistics.size());
for (Map.Entry<String, CollectionStatistics> entry : fieldStatistics.entrySet()) {
out.writeString(entry.getKey());
out.writeVLong(entry.getValue().maxDoc());
out.writeVLong(entry.getValue().docCount());
out.writeVLong(entry.getValue().sumTotalTermFreq());
out.writeVLong(entry.getValue().sumDocFreq());
}
out.writeVInt(maxDoc);
}

}

0 comments on commit f772730

Please sign in to comment.