diff --git a/src/main/java/org/elasticsearch/search/aggregations/bucket/range/RangeAggregator.java b/src/main/java/org/elasticsearch/search/aggregations/bucket/range/RangeAggregator.java index 00799d98c6c5b..8da74aa133a45 100644 --- a/src/main/java/org/elasticsearch/search/aggregations/bucket/range/RangeAggregator.java +++ b/src/main/java/org/elasticsearch/search/aggregations/bucket/range/RangeAggregator.java @@ -23,16 +23,16 @@ import org.apache.lucene.util.InPlaceMergeSorter; import org.elasticsearch.index.fielddata.DoubleValues; import org.elasticsearch.search.aggregations.Aggregator; +import org.elasticsearch.search.aggregations.AggregatorFactories; import org.elasticsearch.search.aggregations.InternalAggregation; import org.elasticsearch.search.aggregations.InternalAggregations; import org.elasticsearch.search.aggregations.bucket.BucketsAggregator; import org.elasticsearch.search.aggregations.support.AggregationContext; +import org.elasticsearch.search.aggregations.support.ValueSourceAggregatorFactory; import org.elasticsearch.search.aggregations.support.ValuesSourceConfig; import org.elasticsearch.search.aggregations.support.numeric.NumericValuesSource; import org.elasticsearch.search.aggregations.support.numeric.ValueFormatter; import org.elasticsearch.search.aggregations.support.numeric.ValueParser; -import org.elasticsearch.search.aggregations.AggregatorFactories; -import org.elasticsearch.search.aggregations.support.ValueSourceAggregatorFactory; import java.io.IOException; import java.util.ArrayList; @@ -94,7 +94,7 @@ public RangeAggregator(String name, AggregationContext aggregationContext, Aggregator parent) { - super(name, BucketAggregationMode.PER_BUCKET, factories, ranges.size(), aggregationContext, parent); + super(name, BucketAggregationMode.MULTI_BUCKETS, factories, ranges.size() * parent.estimatedBucketCount(), aggregationContext, parent); assert valuesSource != null; this.valuesSource = valuesSource; this.keyed = keyed; @@ -118,19 +118,21 @@ public boolean shouldCollect() { return true; } + private final long subBucketOrdinal(long owningBucketOrdinal, int rangeOrd) { + return owningBucketOrdinal * ranges.length + rangeOrd; + } + @Override public void collect(int doc, long owningBucketOrdinal) throws IOException { - assert owningBucketOrdinal == 0; - final DoubleValues values = valuesSource.doubleValues(); final int valuesCount = values.setDocument(doc); for (int i = 0, lo = 0; i < valuesCount; ++i) { final double value = values.nextValue(); - lo = collect(doc, value, lo); + lo = collect(doc, value, owningBucketOrdinal, lo); } } - private int collect(int doc, double value, int lowBound) throws IOException { + private int collect(int doc, double value, long owningBucketOrdinal, int lowBound) throws IOException { int lo = lowBound, hi = ranges.length - 1; // all candidates are between these indexes int mid = (lo + hi) >>> 1; while (lo <= hi) { @@ -172,7 +174,7 @@ private int collect(int doc, double value, int lowBound) throws IOException { for (int i = startLo; i <= endHi; ++i) { if (ranges[i].matches(value)) { - collectBucket(doc, i); + collectBucket(doc, subBucketOrdinal(owningBucketOrdinal, i)); } } @@ -181,12 +183,12 @@ private int collect(int doc, double value, int lowBound) throws IOException { @Override public InternalAggregation buildAggregation(long owningBucketOrdinal) { - assert owningBucketOrdinal == 0; List buckets = Lists.newArrayListWithCapacity(ranges.length); for (int i = 0; i < ranges.length; i++) { Range range = ranges[i]; - RangeBase.Bucket bucket = rangeFactory.createBucket(range.key, range.from, range.to, bucketDocCount(i), - bucketAggregations(i), valuesSource.formatter()); + final long bucketOrd = subBucketOrdinal(owningBucketOrdinal, i); + RangeBase.Bucket bucket = rangeFactory.createBucket(range.key, range.from, range.to, bucketDocCount(bucketOrd), + bucketAggregations(bucketOrd), valuesSource.formatter()); buckets.add(bucket); } // value source can be null in the case of unmapped fields @@ -246,7 +248,7 @@ public Unmapped(String name, Aggregator parent, AbstractRangeBase.Factory factory) { - super(name, BucketAggregationMode.PER_BUCKET, AggregatorFactories.EMPTY, 0, aggregationContext, parent); + super(name, BucketAggregationMode.MULTI_BUCKETS, AggregatorFactories.EMPTY, 0, aggregationContext, parent); this.ranges = ranges; for (Range range : this.ranges) { range.process(parser, context); diff --git a/src/test/java/org/elasticsearch/search/aggregations/bucket/RangeTests.java b/src/test/java/org/elasticsearch/search/aggregations/bucket/RangeTests.java index c07c76a03a7dd..38d1207dd4b4d 100644 --- a/src/test/java/org/elasticsearch/search/aggregations/bucket/RangeTests.java +++ b/src/test/java/org/elasticsearch/search/aggregations/bucket/RangeTests.java @@ -25,6 +25,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.search.aggregations.bucket.histogram.Histogram; import org.elasticsearch.search.aggregations.bucket.range.Range; +import org.elasticsearch.search.aggregations.bucket.terms.Terms; import org.elasticsearch.search.aggregations.metrics.avg.Avg; import org.elasticsearch.search.aggregations.metrics.sum.Sum; import org.elasticsearch.test.ElasticsearchIntegrationTest; @@ -76,6 +77,56 @@ public void init() throws Exception { ensureSearchable(); } + @Test + public void rangeAsSubAggregation() throws Exception { + SearchResponse response = client().prepareSearch("idx") + .addAggregation(terms("terms").field("values").size(100).subAggregation( + range("range").field("value") + .addUnboundedTo(3) + .addRange(3, 6) + .addUnboundedFrom(6))) + .execute().actionGet(); + + assertSearchResponse(response); + Terms terms = response.getAggregations().get("terms"); + assertThat(terms, notNullValue()); + assertThat(terms.buckets().size(), equalTo(numDocs + 1)); + for (int i = 1; i < numDocs + 2; ++i) { + Terms.Bucket bucket = terms.getByTerm("" + i); + assertThat(bucket, notNullValue()); + final long docCount = i == 1 || i == numDocs + 1 ? 1 : 2; + assertThat(bucket.getDocCount(), equalTo(docCount)); + Range range = bucket.getAggregations().get("range"); + Range.Bucket rangeBucket = range.getByKey("*-3.0"); + assertThat(rangeBucket, notNullValue()); + if (i == 1 || i == 3) { + assertThat(rangeBucket.getDocCount(), equalTo(1L)); + } else if (i == 2) { + assertThat(rangeBucket.getDocCount(), equalTo(2L)); + } else { + assertThat(rangeBucket.getDocCount(), equalTo(0L)); + } + rangeBucket = range.getByKey("3.0-6.0"); + assertThat(rangeBucket, notNullValue()); + if (i == 3 || i == 6) { + assertThat(rangeBucket.getDocCount(), equalTo(1L)); + } else if (i == 4 || i == 5) { + assertThat(rangeBucket.getDocCount(), equalTo(2L)); + } else { + assertThat(rangeBucket.getDocCount(), equalTo(0L)); + } + rangeBucket = range.getByKey("6.0-*"); + assertThat(rangeBucket, notNullValue()); + if (i == 6 || i == numDocs + 1) { + assertThat(rangeBucket.getDocCount(), equalTo(1L)); + } else if (i < 6) { + assertThat(rangeBucket.getDocCount(), equalTo(0L)); + } else { + assertThat(rangeBucket.getDocCount(), equalTo(2L)); + } + } + } + @Test public void singleValueField() throws Exception { SearchResponse response = client().prepareSearch("idx")