Skip to content

Commit

Permalink
script with _score: remove dependency of DocLookup and scorer
Browse files Browse the repository at this point in the history
As pointed out in #7487 DocLookup is a variable that is accessible by all scripts
for one doc while the query is executed. But the _score and therfore the scorer
depends on the current context, that is, which part of query is currently executed.
Instead of setting the scorer for DocLookup
and have Script access the DocLookup for getting the score, the Scorer should just
be explicitely set for each script.
DocLookup should not have any reference to a scorer.
This was similarly discussed in #7043.

This dependency caused a stackoverflow when running script score in combination with an
aggregation on _score. Also the wrong scorer was called when nesting several script scores.

closes #7487
closes #7819
  • Loading branch information
brwe committed Sep 26, 2014
1 parent 9c9cd01 commit 7feb742
Show file tree
Hide file tree
Showing 9 changed files with 54 additions and 53 deletions.
Expand Up @@ -95,7 +95,7 @@ void setLookup(SearchLookup lookup) {

@Override
public void setScorer(Scorer scorer) {
lookup.setScorer(scorer);
throw new UnsupportedOperationException();
}

@Override
Expand Down
9 changes: 5 additions & 4 deletions src/main/java/org/elasticsearch/script/ScoreAccessor.java
Expand Up @@ -19,6 +19,7 @@

package org.elasticsearch.script;

import org.apache.lucene.search.Scorer;
import org.elasticsearch.search.lookup.DocLookup;

import java.io.IOException;
Expand All @@ -31,15 +32,15 @@
*/
public final class ScoreAccessor extends Number {

final DocLookup doc;
Scorer scorer;

public ScoreAccessor(DocLookup d) {
doc = d;
public ScoreAccessor(Scorer scorer) {
this.scorer = scorer;
}

float score() {
try {
return doc.score();
return scorer.score();
} catch (IOException e) {
throw new RuntimeException("Could not get score", e);
}
Expand Down
21 changes: 0 additions & 21 deletions src/main/java/org/elasticsearch/script/ScriptService.java
Expand Up @@ -230,9 +230,6 @@ public ScriptService(Settings settings, Environment env, Set<ScriptEngineService
}
this.scriptEngines = builder.build();

// put some default optimized scripts
staticCache.put("doc.score", new CompiledScript("native", new DocScoreNativeScriptFactory()));

// add file watcher for static scripts
scriptsDirectory = new File(env.configFile(), "scripts");
if (logger.isTraceEnabled()) {
Expand Down Expand Up @@ -574,22 +571,4 @@ public int hashCode() {
return lang.hashCode() + 31 * script.hashCode();
}
}

public static class DocScoreNativeScriptFactory implements NativeScriptFactory {
@Override
public ExecutableScript newScript(@Nullable Map<String, Object> params) {
return new DocScoreSearchScript();
}
}

public static class DocScoreSearchScript extends AbstractFloatSearchScript {
@Override
public float runAsFloat() {
try {
return doc().score();
} catch (IOException e) {
return 0;
}
}
}
}
Expand Up @@ -43,6 +43,7 @@
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.script.*;
import org.elasticsearch.search.lookup.SearchLookup;
import org.elasticsearch.search.suggest.term.TermSuggestion;

import java.io.IOException;
import java.math.BigDecimal;
Expand Down Expand Up @@ -186,6 +187,7 @@ public static final class GroovyScript implements ExecutableScript, SearchScript
private final SearchLookup lookup;
private final Map<String, Object> variables;
private final ESLogger logger;
private Scorer scorer;

public GroovyScript(Script script, ESLogger logger) {
this(script, null, logger);
Expand All @@ -196,17 +198,12 @@ public GroovyScript(Script script, @Nullable SearchLookup lookup, ESLogger logge
this.lookup = lookup;
this.logger = logger;
this.variables = script.getBinding().getVariables();
if (lookup != null) {
// Add the _score variable, which will access score from lookup.doc()
this.variables.put("_score", new ScoreAccessor(lookup.doc()));
}
}

@Override
public void setScorer(Scorer scorer) {
if (lookup != null) {
lookup.setScorer(scorer);
}
this.scorer = scorer;
this.variables.put("_score", new ScoreAccessor(scorer));
}

@Override
Expand Down
14 changes: 0 additions & 14 deletions src/main/java/org/elasticsearch/search/lookup/DocLookup.java
Expand Up @@ -49,8 +49,6 @@ public class DocLookup implements Map {

private AtomicReaderContext reader;

private Scorer scorer;

private int docId = -1;

DocLookup(MapperService mapperService, IndexFieldDataService fieldDataService, @Nullable String[] types) {
Expand All @@ -76,22 +74,10 @@ public void setNextReader(AtomicReaderContext context) {
localCacheFieldData.clear();
}

public void setScorer(Scorer scorer) {
this.scorer = scorer;
}

public void setNextDocId(int docId) {
this.docId = docId;
}

public float score() throws IOException {
return scorer.score();
}

public float getScore() throws IOException {
return scorer.score();
}

@Override
public Object get(Object key) {
// assume its a string...
Expand Down
Expand Up @@ -76,10 +76,6 @@ public DocLookup doc() {
return this.docMap;
}

public void setScorer(Scorer scorer) {
docMap.setScorer(scorer);
}

public void setNextReader(AtomicReaderContext context) {
docMap.setNextReader(context);
sourceLookup.setNextReader(context);
Expand Down
Expand Up @@ -1172,7 +1172,7 @@ public void script_Score() {
.setQuery(functionScoreQuery(matchAllQuery()).add(ScoreFunctionBuilders.scriptFunction("doc['" + SINGLE_VALUED_FIELD_NAME + "'].value")))
.addAggregation(terms("terms")
.collectMode(randomFrom(SubAggCollectionMode.values()))
.script("ceil(_doc.score()/3)")
.script("ceil(_score.doubleValue()/3)")
).execute().actionGet();

assertSearchResponse(response);
Expand Down
Expand Up @@ -270,7 +270,7 @@ public void testFieldCollapsing() throws Exception {
topHits("hits").setSize(1)
)
.subAggregation(
max("max_score").script("_doc.score()")
max("max_score").script("_score.doubleValue()")
)
)
.get();
Expand Down
Expand Up @@ -28,6 +28,7 @@
import org.elasticsearch.index.query.functionscore.FunctionScoreQueryBuilder;
import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilder;
import org.elasticsearch.index.query.functionscore.weight.WeightBuilder;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.test.ElasticsearchIntegrationTest;
import org.junit.Test;

Expand All @@ -40,6 +41,7 @@
import static org.elasticsearch.index.query.QueryBuilders.functionScoreQuery;
import static org.elasticsearch.index.query.QueryBuilders.termQuery;
import static org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders.*;
import static org.elasticsearch.search.aggregations.AggregationBuilders.terms;
import static org.elasticsearch.search.builder.SearchSourceBuilder.searchSource;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse;
Expand Down Expand Up @@ -388,4 +390,44 @@ public void checkWeightOnlyCreatesBoostFunction() throws IOException {
assertSearchResponse(response);
assertThat(response.getHits().getAt(0).score(), equalTo(2.0f));
}

@Test
public void testScriptScoresNested() throws IOException {
index(INDEX, TYPE, "1", jsonBuilder().startObject().field("dummy_field", 1).endObject());
refresh();
SearchResponse response = client().search(
searchRequest().source(
searchSource().query(
functionScoreQuery(
functionScoreQuery(
functionScoreQuery().add(scriptFunction("1")))
.add(scriptFunction("_score.doubleValue()")))
.add(scriptFunction("_score.doubleValue()")
)
)
)
).actionGet();
assertSearchResponse(response);
assertThat(response.getHits().getAt(0).score(), equalTo(1.0f));
}

@Test
public void testScriptScoresWithAgg() throws IOException {
index(INDEX, TYPE, "1", jsonBuilder().startObject().field("dummy_field", 1).endObject());
refresh();
SearchResponse response = client().search(
searchRequest().source(
searchSource().query(
functionScoreQuery()
.add(scriptFunction("_score.doubleValue()")
)
).aggregation(terms("score_agg").script("_score.doubleValue()"))
)
).actionGet();
assertSearchResponse(response);
assertThat(response.getHits().getAt(0).score(), equalTo(1.0f));
assertThat(((Terms) response.getAggregations().asMap().get("score_agg")).getBuckets().get(0).getKeyAsNumber().floatValue(), is(1f));
assertThat(((Terms) response.getAggregations().asMap().get("score_agg")).getBuckets().get(0).getDocCount(), is(1l));
}
}

0 comments on commit 7feb742

Please sign in to comment.