Skip to content

Commit

Permalink
Use typecasting comparator for numeric "any" aggregations. (#16494)
Browse files Browse the repository at this point in the history
This brings them in line with the behavior of other numeric aggregations.
It is important because otherwise ClassCastExceptions can arise if comparing
different numeric types that may arise from deserialization.
  • Loading branch information
gianm committed May 22, 2024
1 parent 44ea4e1 commit eb410f7
Show file tree
Hide file tree
Showing 8 changed files with 41 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
*/
public class FloatSumAggregator implements Aggregator
{
static final Comparator COMPARATOR = new Ordering()
public static final Comparator COMPARATOR = new Ordering()
{
@Override
public int compare(Object o, Object o1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
*/
public class LongSumAggregator implements Aggregator
{
static final Comparator COMPARATOR = new Ordering()
public static final Comparator COMPARATOR = new Ordering()
{
@Override
public int compare(Object o, Object o1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.AggregatorUtil;
import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.DoubleSumAggregator;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.segment.BaseDoubleColumnValueSelector;
Expand All @@ -48,8 +49,6 @@

public class DoubleAnyAggregatorFactory extends AggregatorFactory
{
private static final Comparator<Double> VALUE_COMPARATOR = Comparator.nullsFirst(Double::compare);

private static final Aggregator NIL_AGGREGATOR = new DoubleAnyAggregator(
NilColumnValueSelector.instance()
)
Expand Down Expand Up @@ -136,7 +135,7 @@ public boolean canVectorize(ColumnInspector columnInspector)
@Override
public Comparator getComparator()
{
return VALUE_COMPARATOR;
return DoubleSumAggregator.COMPARATOR;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.AggregatorUtil;
import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.FloatSumAggregator;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.segment.BaseFloatColumnValueSelector;
Expand All @@ -47,8 +48,6 @@

public class FloatAnyAggregatorFactory extends AggregatorFactory
{
private static final Comparator<Float> VALUE_COMPARATOR = Comparator.nullsFirst(Float::compare);

private static final Aggregator NIL_AGGREGATOR = new FloatAnyAggregator(
NilColumnValueSelector.instance()
)
Expand Down Expand Up @@ -133,7 +132,7 @@ public boolean canVectorize(ColumnInspector columnInspector)
@Override
public Comparator getComparator()
{
return VALUE_COMPARATOR;
return FloatSumAggregator.COMPARATOR;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.AggregatorUtil;
import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.LongSumAggregator;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.segment.BaseLongColumnValueSelector;
Expand All @@ -46,8 +47,6 @@

public class LongAnyAggregatorFactory extends AggregatorFactory
{
private static final Comparator<Long> VALUE_COMPARATOR = Comparator.nullsFirst(Long::compare);

private static final Aggregator NIL_AGGREGATOR = new LongAnyAggregator(
NilColumnValueSelector.instance()
)
Expand Down Expand Up @@ -132,7 +131,7 @@ public boolean canVectorize(ColumnInspector columnInspector)
@Override
public Comparator getComparator()
{
return VALUE_COMPARATOR;
return LongSumAggregator.COMPARATOR;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,17 @@ public void testComparatorWithNulls()
Assert.assertEquals(-1, comparator.compare(d2, d1));
}

@Test
public void testComparatorWithTypeMismatch()
{
Long n1 = 3L;
Double n2 = 4.0;
Comparator comparator = doubleAnyAggFactory.getComparator();
Assert.assertEquals(0, comparator.compare(n1, n1));
Assert.assertEquals(-1, comparator.compare(n1, n2));
Assert.assertEquals(1, comparator.compare(n2, n1));
}

@Test
public void testDoubleAnyCombiningAggregator()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,17 @@ public void testComparatorWithNulls()
Assert.assertEquals(-1, comparator.compare(f2, f1));
}

@Test
public void testComparatorWithTypeMismatch()
{
Long n1 = 3L;
Float n2 = 4.0f;
Comparator comparator = floatAnyAggFactory.getComparator();
Assert.assertEquals(0, comparator.compare(n1, n1));
Assert.assertEquals(-1, comparator.compare(n1, n2));
Assert.assertEquals(1, comparator.compare(n2, n1));
}

@Test
public void testFloatAnyCombiningAggregator()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,17 @@ public void testComparatorWithNulls()
Assert.assertEquals(-1, comparator.compare(l2, l1));
}

@Test
public void testComparatorWithTypeMismatch()
{
Integer n1 = 3;
Long n2 = 4L;
Comparator comparator = longAnyAggFactory.getComparator();
Assert.assertEquals(0, comparator.compare(n1, n1));
Assert.assertEquals(-1, comparator.compare(n1, n2));
Assert.assertEquals(1, comparator.compare(n2, n1));
}

@Test
public void testLongAnyCombiningAggregator()
{
Expand Down

0 comments on commit eb410f7

Please sign in to comment.