Skip to content
Permalink
Browse files
Graceful null handling and correctness in DoubleMean Aggregator (#12320)
* Adding null handling for double mean aggregator

* Updating code to handle nulls in DoubleMean aggregator

* Oops — the last commit had checkstyle issues; fixed

* Updating some code and test cases

* Checking on object is null in case of numeric aggregator

* Adding one more test to improve coverage

* Changing one test as asked in the review

* Changing one test as asked in the review for nulls
  • Loading branch information
somu-imply committed Mar 14, 2022
1 parent 3de1272 commit b5195c5095a1088cb06ed602704fce110232f109
Showing 5 changed files with 113 additions and 5 deletions.
@@ -19,6 +19,7 @@

package org.apache.druid.query.aggregation.mean;

import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.Numbers;
import org.apache.druid.query.aggregation.Aggregator;
import org.apache.druid.segment.ColumnValueSelector;
@@ -43,6 +44,10 @@ public void aggregate()
{
Object update = selector.getObject();

if (update == null && NullHandling.replaceWithDefault() == false) {
return;
}

if (update instanceof DoubleMeanHolder) {
value.update((DoubleMeanHolder) update);
} else if (update instanceof List) {
@@ -19,6 +19,7 @@

package org.apache.druid.query.aggregation.mean;

import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.Numbers;
import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
@@ -51,6 +52,9 @@ public void aggregate(ByteBuffer buf, int position)
{
Object update = selector.getObject();

if (update == null && NullHandling.replaceWithDefault() == false) {
return;
}
if (update instanceof DoubleMeanHolder) {
DoubleMeanHolder.update(buf, position, (DoubleMeanHolder) update);
} else if (update instanceof List) {
@@ -20,6 +20,7 @@
package org.apache.druid.query.aggregation.mean;

import com.google.common.base.Preconditions;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.segment.vector.VectorValueSelector;

@@ -45,11 +46,28 @@ public void init(final ByteBuffer buf, final int position)
public void aggregate(final ByteBuffer buf, final int position, final int startRow, final int endRow)
{
final double[] vector = selector.getDoubleVector();
for (int i = startRow; i < endRow; i++) {
DoubleMeanHolder.update(buf, position, vector[i]);
final boolean[] nulls = selector.getNullVector();

if (nulls != null) {
if (NullHandling.replaceWithDefault()) {
for (int i = startRow; i < endRow; i++) {
DoubleMeanHolder.update(buf, position, vector[i]);
}
} else {
for (int i = startRow; i < endRow; i++) {
if (!nulls[i]) {
DoubleMeanHolder.update(buf, position, vector[i]);
}
}
}
} else {
for (int i = startRow; i < endRow; i++) {
DoubleMeanHolder.update(buf, position, vector[i]);
}
}
}


@Override
public void aggregate(
final ByteBuffer buf,
@@ -60,10 +78,27 @@ public void aggregate(
)
{
final double[] vector = selector.getDoubleVector();
final boolean[] nulls = selector.getNullVector();

for (int i = 0; i < numRows; i++) {
final double val = vector[rows != null ? rows[i] : i];
DoubleMeanHolder.update(buf, positions[i] + positionOffset, val);
if (nulls != null) {
if (NullHandling.replaceWithDefault()) {
for (int i = 0; i < numRows; i++) {
final double val = vector[rows != null ? rows[i] : i];
DoubleMeanHolder.update(buf, positions[i] + positionOffset, val);
}
} else {
for (int j = 0; j < numRows; j++) {
if (!nulls[j]) {
final double val = vector[rows != null ? rows[j] : j];
DoubleMeanHolder.update(buf, positions[j] + positionOffset, val);
}
}
}
} else {
for (int i = 0; i < numRows; i++) {
final double val = vector[rows != null ? rows[i] : i];
DoubleMeanHolder.update(buf, positions[i] + positionOffset, val);
}
}
}

@@ -26,6 +26,7 @@
import com.google.common.collect.Lists;
import junitparams.JUnitParamsRunner;
import junitparams.Parameters;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.data.input.Row;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.guava.Sequence;
@@ -46,6 +47,7 @@
import org.apache.druid.segment.IncrementalIndexSegment;
import org.apache.druid.segment.QueryableIndexSegment;
import org.apache.druid.segment.Segment;
import org.apache.druid.segment.TestIndex;
import org.apache.druid.timeline.SegmentId;
import org.easymock.EasyMock;
import org.junit.Assert;
@@ -72,6 +74,7 @@ public Object[] doVectorize()
private final AggregationTestHelper timeseriesQueryTestHelper;

private final List<Segment> segments;
private final List<Segment> biggerSegments;

public DoubleMeanAggregationTest()
{
@@ -91,6 +94,11 @@ public DoubleMeanAggregationTest()
new IncrementalIndexSegment(SimpleTestIndex.getIncrementalTestIndex(), SegmentId.dummy("test1")),
new QueryableIndexSegment(SimpleTestIndex.getMMappedTestIndex(), SegmentId.dummy("test2"))
);

biggerSegments = ImmutableList.of(
new IncrementalIndexSegment(TestIndex.getIncrementalTestIndex(), SegmentId.dummy("test1")),
new QueryableIndexSegment(TestIndex.getMMappedTestIndex(), SegmentId.dummy("test2"))
);
}

@Test
@@ -145,6 +153,33 @@ public void testVectorAggretatorUsingGroupByQueryOnDoubleColumn(boolean doVector
Assert.assertEquals(6.2d, result.getMetric("meanOnDouble").doubleValue(), 0.0001d);
}

@Test
@Parameters(method = "doVectorize")
public void testVectorAggretatorUsingGroupByQueryOnDoubleColumnOnBiggerSegments(boolean doVectorize) throws Exception
{
  GroupByQuery query = new GroupByQuery.Builder()
      .setDataSource("blah")
      .setGranularity(Granularities.ALL)
      .setInterval("1970/2050")
      .setAggregatorSpecs(
          new DoubleMeanAggregatorFactory("meanOnDouble", TestIndex.COLUMNS[9])
      )
      .setContext(Collections.singletonMap(QueryContexts.VECTORIZE_KEY, doVectorize))
      .build();

  // Round-trip the query through JSON so any serde problem in the aggregator
  // factory surfaces here rather than in production.
  final ObjectMapper mapper = groupByQueryTestHelper.getObjectMapper();
  query = (GroupByQuery) mapper.readValue(mapper.writeValueAsString(query), Query.class);

  final Sequence<ResultRow> resultSequence = groupByQueryTestHelper.runQueryOnSegmentsObjs(biggerSegments, query);
  final Row resultRow = Iterables.getOnlyElement(resultSequence.toList()).toMapBasedRow(query);

  // Default-value mode folds nulls in as 0.0 (lower mean); SQL-compatible mode
  // excludes null rows from both sum and count.
  final double expectedMean = NullHandling.replaceWithDefault() ? 39.2307d : 51.0d;
  Assert.assertEquals(expectedMean, resultRow.getMetric("meanOnDouble").doubleValue(), 0.0001d);
}

@Test
@Parameters(method = "doVectorize")
public void testAggretatorUsingTimeseriesQuery(boolean doVectorize) throws Exception
@@ -33,6 +33,7 @@
import org.apache.druid.collections.CloseableDefaultBlockingPool;
import org.apache.druid.collections.CloseableStupidPool;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.data.input.Row;
import org.apache.druid.data.input.Rows;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.IAE;
@@ -82,6 +83,7 @@
import org.apache.druid.query.aggregation.hyperloglog.HyperUniqueFinalizingPostAggregator;
import org.apache.druid.query.aggregation.hyperloglog.HyperUniquesAggregatorFactory;
import org.apache.druid.query.aggregation.last.LongLastAggregatorFactory;
import org.apache.druid.query.aggregation.mean.DoubleMeanAggregatorFactory;
import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator;
import org.apache.druid.query.aggregation.post.ConstantPostAggregator;
import org.apache.druid.query.aggregation.post.ExpressionPostAggregator;
@@ -5881,6 +5883,33 @@ public void testDifferentIntervalSubquery()
TestHelper.assertExpectedObjects(expectedResults, results, "subquery-different-intervals");
}

@Test
public void testDoubleMeanQuery()
{
  final GroupByQuery query = new GroupByQuery.Builder()
      .setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
      .setGranularity(Granularities.ALL)
      .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD)
      .setAggregatorSpecs(
          new DoubleMeanAggregatorFactory("meanOnDouble", "doubleNumericNull")
      )
      .build();

  final boolean usingV1Strategy = config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1);
  if (usingV1Strategy) {
    // The v1 group-by strategy rejects the mean aggregator's complex
    // intermediate type, so the query is expected to fail.
    expectedException.expect(ISE.class);
    expectedException.expectMessage("Unable to handle complex type");
    GroupByQueryRunnerTestHelper.runQuery(factory, runner, query);
  } else {
    final Iterable<ResultRow> results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query);
    final Row resultRow = Iterables.getOnlyElement(results).toMapBasedRow(query);
    // Expected value depends on the configured null-handling mode.
    final double expectedMean = NullHandling.replaceWithDefault() ? 39.2307d : 51.0d;
    Assert.assertEquals(expectedMean, resultRow.getMetric("meanOnDouble").doubleValue(), 0.0001d);
  }
}

@Test
public void testGroupByTimeExtractionNamedUnderUnderTime()
{

0 comments on commit b5195c5

Please sign in to comment.