From ba1042e7d19b2d197c7c9f449f97d2edf1b1ed91 Mon Sep 17 00:00:00 2001 From: Martijn van Groningen Date: Sun, 20 Jul 2014 21:57:27 +0200 Subject: [PATCH] Aggregations: Track scores should be applied properly for `top_hits` aggregation. Closes #6934 --- .../metrics/tophits/TopHitsAggregator.java | 2 +- .../aggregations/bucket/TopHitsTests.java | 51 +++++++++++++++++-- 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/elasticsearch/search/aggregations/metrics/tophits/TopHitsAggregator.java b/src/main/java/org/elasticsearch/search/aggregations/metrics/tophits/TopHitsAggregator.java index c85838c3b37fb..dd0b232b3c0d3 100644 --- a/src/main/java/org/elasticsearch/search/aggregations/metrics/tophits/TopHitsAggregator.java +++ b/src/main/java/org/elasticsearch/search/aggregations/metrics/tophits/TopHitsAggregator.java @@ -107,7 +107,7 @@ public void collect(int docId, long bucketOrdinal) throws IOException { int topN = topHitsContext.from() + topHitsContext.size(); topDocsCollectors.put( bucketOrdinal, - topDocsCollector = sort != null ? TopFieldCollector.create(sort, topN, true, topHitsContext.trackScores(), true, false) : TopScoreDocCollector.create(topN, false) + topDocsCollector = sort != null ? TopFieldCollector.create(sort, topN, true, topHitsContext.trackScores(), topHitsContext.trackScores(), false) : TopScoreDocCollector.create(topN, false) ); topDocsCollector.setNextReader(currentContext); topDocsCollector.setScorer(currentScorer); diff --git a/src/test/java/org/elasticsearch/search/aggregations/bucket/TopHitsTests.java b/src/test/java/org/elasticsearch/search/aggregations/bucket/TopHitsTests.java index 460576ff9acb8..0fbe2a1428d2a 100644 --- a/src/test/java/org/elasticsearch/search/aggregations/bucket/TopHitsTests.java +++ b/src/test/java/org/elasticsearch/search/aggregations/bucket/TopHitsTests.java @@ -48,6 +48,7 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; import static org.hamcrest.core.IsNull.notNullValue; /** @@ -431,9 +432,6 @@ public void testFailDeferred() throws Exception { assertThat(e.getMessage(), containsString("ElasticsearchParseException")); } } - - - @Test public void testEmptyIndex() throws Exception { @@ -448,4 +446,51 @@ public void testEmptyIndex() throws Exception { assertThat(hits.getHits().totalHits(), equalTo(0l)); } + @Test + public void testTrackScores() throws Exception { + boolean[] trackScores = new boolean[]{true, false}; + for (boolean trackScore : trackScores) { + logger.info("Track score=" + trackScore); + SearchResponse response = client().prepareSearch("idx").setTypes("field-collapsing") + .setQuery(matchQuery("text", "term rare")) + .addAggregation(terms("terms") + .field("group") + .subAggregation( + topHits("hits") + .setTrackScores(trackScore) + .setSize(1) + .addSort("_id", SortOrder.DESC) + ) + ) + .get(); + assertSearchResponse(response); + + Terms terms = response.getAggregations().get("terms"); + assertThat(terms, notNullValue()); + assertThat(terms.getName(), equalTo("terms")); + assertThat(terms.getBuckets().size(), equalTo(3)); + + Terms.Bucket bucket = terms.getBucketByKey("a"); + assertThat(key(bucket), equalTo("a")); + TopHits topHits = bucket.getAggregations().get("hits"); + SearchHits hits = topHits.getHits(); + assertThat(hits.getMaxScore(), trackScore ? not(equalTo(Float.NaN)) : equalTo(Float.NaN)); + assertThat(hits.getAt(0).score(), trackScore ? not(equalTo(Float.NaN)) : equalTo(Float.NaN)); + + bucket = terms.getBucketByKey("b"); + assertThat(key(bucket), equalTo("b")); + topHits = bucket.getAggregations().get("hits"); + hits = topHits.getHits(); + assertThat(hits.getMaxScore(), trackScore ? not(equalTo(Float.NaN)) : equalTo(Float.NaN)); + assertThat(hits.getAt(0).score(), trackScore ? not(equalTo(Float.NaN)) : equalTo(Float.NaN)); + + bucket = terms.getBucketByKey("c"); + assertThat(key(bucket), equalTo("c")); + topHits = bucket.getAggregations().get("hits"); + hits = topHits.getHits(); + assertThat(hits.getMaxScore(), trackScore ? not(equalTo(Float.NaN)) : equalTo(Float.NaN)); + assertThat(hits.getAt(0).score(), trackScore ? not(equalTo(Float.NaN)) : equalTo(Float.NaN)); + } + } + }