Skip to content

Commit

Permalink
End-to-end support for grouping count aggregator function (#306)
Browse files Browse the repository at this point in the history
End-to-end support for grouping count aggregator function, e.g. _from
test | stats count(count) by data_.
  • Loading branch information
ChrisHegarty committed Oct 19, 2022
1 parent 8571e52 commit e0ca1e8
Show file tree
Hide file tree
Showing 8 changed files with 361 additions and 25 deletions.
Expand Up @@ -33,6 +33,14 @@ public interface GroupingAggregatorFunction {
}
};

// Factory for the grouping count aggregator: modes whose input is already partially
// aggregated get the intermediate-consuming variant; every other mode gets an
// aggregator that counts raw values read from inputChannel.
BiFunction<AggregatorMode, Integer, GroupingAggregatorFunction> count = (mode, inputChannel) -> mode.isInputPartial()
    ? GroupingCountAggregator.createIntermediate()
    : GroupingCountAggregator.create(inputChannel);

BiFunction<AggregatorMode, Integer, GroupingAggregatorFunction> min = (AggregatorMode mode, Integer inputChannel) -> {
if (mode.isInputPartial()) {
return GroupingMinAggregator.createIntermediate();
Expand Down
@@ -0,0 +1,89 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.compute.aggregation;

import org.elasticsearch.compute.Experimental;
import org.elasticsearch.compute.data.AggregatorStateBlock;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.LongArrayBlock;
import org.elasticsearch.compute.data.Page;

/**
 * Grouping aggregator that keeps one running count per group id.
 *
 * An instance either consumes raw pages (created through {@link #create(int)}) or
 * previously serialized intermediate state (created through {@link #createIntermediate()}).
 */
@Experimental
public class GroupingCountAggregator implements GroupingAggregatorFunction {

    // running counts, addressed by group id
    private final LongArrayState state;
    // channel of the page to read raw input from; -1 for the intermediate-input variant
    private final int channel;

    /**
     * Creates an aggregator that counts raw input taken from the given channel.
     *
     * @param inputChannel index of the block to read from each page; must not be negative
     * @throws IllegalArgumentException if {@code inputChannel} is negative
     */
    static GroupingCountAggregator create(int inputChannel) {
        if (inputChannel >= 0) {
            return new GroupingCountAggregator(inputChannel, new LongArrayState(0));
        }
        throw new IllegalArgumentException();
    }

    /** Creates an aggregator that merges intermediate state; it has no raw input channel. */
    static GroupingCountAggregator createIntermediate() {
        return new GroupingCountAggregator(-1, new LongArrayState(0));
    }

    private GroupingCountAggregator(int channel, LongArrayState state) {
        this.channel = channel;
        this.state = state;
    }

    /**
     * Adds one to the count of the group found at each position of the input block.
     *
     * @param groupIdBlock block supplying the group id for every position
     * @param page         page whose block at {@code channel} determines how many positions are visited
     */
    @Override
    public void addRawInput(Block groupIdBlock, Page page) {
        assert channel >= 0;
        final Block input = page.getBlock(channel);
        final int positionCount = input.getPositionCount();
        int position = 0;
        while (position < positionCount) {
            final int group = (int) groupIdBlock.getLong(position);
            final long previous = state.getOrDefault(group, 0);
            state.set(previous + 1, group);
            position++;
        }
    }

    /**
     * Merges partial counts carried by an {@link AggregatorStateBlock} into this aggregator.
     * The partial counts are read positionally: element {@code i} of the incoming values is
     * added to the group named at position {@code i} of the group id block.
     *
     * @param groupIdBlock block supplying the group id for every position
     * @param block        block holding the serialized partial state
     * @throws RuntimeException if {@code block} is not an {@link AggregatorStateBlock}
     */
    @Override
    public void addIntermediateInput(Block groupIdBlock, Block block) {
        assert channel == -1;
        if ((block instanceof AggregatorStateBlock) == false) {
            throw new RuntimeException("expected AggregatorStateBlock, got:" + block);
        }
        @SuppressWarnings("unchecked")
        final AggregatorStateBlock<LongArrayState> stateBlock = (AggregatorStateBlock<LongArrayState>) block;
        // materialize the partial state stored at position 0 of the block
        final LongArrayState incoming = new LongArrayState(0);
        stateBlock.get(0, incoming);
        final long[] partialCounts = incoming.getValues();
        final int positionCount = groupIdBlock.getPositionCount();
        for (int position = 0; position < positionCount; position++) {
            final int group = (int) groupIdBlock.getLong(position);
            state.set(state.getOrDefault(group, 0) + partialCounts[position], group);
        }
    }

    /** Wraps the current state in an {@link AggregatorStateBlock} for a later aggregation stage. */
    @Override
    public Block evaluateIntermediate() {
        AggregatorStateBlock.Builder<AggregatorStateBlock<LongArrayState>, LongArrayState> stateBuilder = AggregatorStateBlock
            .builderOfAggregatorState(LongArrayState.class);
        stateBuilder.add(state);
        return stateBuilder.build();
    }

    /** Copies the counts for indices {@code 0..largestIndex} into a {@link LongArrayBlock}. */
    @Override
    public Block evaluateFinal() {
        final int groupCount = state.largestIndex + 1;
        final long[] counts = new long[groupCount];
        for (int group = 0; group < groupCount; group++) {
            counts[group] = state.get(group);
        }
        return new LongArrayBlock(counts, groupCount);
    }
}
@@ -0,0 +1,107 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.compute.aggregation;

import org.elasticsearch.compute.Experimental;

import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;
import java.nio.ByteOrder;
import java.util.Arrays;
import java.util.Objects;

/**
 * Aggregator state backed by a growable {@code long[]}, addressed by index.
 *
 * The state tracks the largest index written so far; serialization and
 * {@link #getOrDefault(int, long)} only consider indices up to that point.
 */
@Experimental
final class LongArrayState implements AggregatorState<LongArrayState> {

    private long[] values;
    // largest index written so far (not a count); largestIndex + 1 positions are serialized
    int largestIndex;

    private final LongArrayStateSerializer serializer;

    /**
     * @param values initial backing array; note that {@code largestIndex} starts at 0
     *               regardless of how many values are supplied
     */
    LongArrayState(long... values) {
        this.values = values;
        this.serializer = new LongArrayStateSerializer();
    }

    /** Returns the backing array itself (not a copy); it may be longer than {@code largestIndex + 1}. */
    long[] getValues() {
        return values;
    }

    /** Returns the value stored at {@code index}; no bounds check beyond the array's own. */
    long get(int index) {
        // TODO bounds check
        return values[index];
    }

    /** Returns the value at {@code index}, or {@code defaultValue} when {@code index} is past the largest index written. */
    long getOrDefault(int index, long defaultValue) {
        if (index > largestIndex) {
            return defaultValue;
        } else {
            return values[index];
        }
    }

    /**
     * Stores {@code value} at {@code index}, growing the backing array as needed and
     * advancing {@code largestIndex} when {@code index} lies beyond it.
     */
    void set(long value, int index) {
        ensureCapacity(index);
        if (index > largestIndex) {
            largestIndex = index;
        }
        values[index] = value;
    }

    /**
     * Grows the backing array so that {@code position} is a valid index.
     *
     * Growth doubles the array, but a single doubling is not sufficient when
     * {@code position} is at least twice the current length, or when the array is
     * empty (doubling zero yields zero). The new size is therefore the larger of
     * the doubled length and {@code position + 1}.
     */
    private void ensureCapacity(int position) {
        if (position >= values.length) {
            int newSize = Math.max(values.length << 1, position + 1);
            values = Arrays.copyOf(values, newSize);
        }
    }

    @Override
    public AggregatorStateSerializer<LongArrayState> serializer() {
        return serializer;
    }

    /**
     * Serializes the state as a big-endian long holding the number of positions
     * ({@code largestIndex + 1}) followed by that many big-endian long values.
     */
    static class LongArrayStateSerializer implements AggregatorStateSerializer<LongArrayState> {

        static final int BYTES_SIZE = Long.BYTES;

        // NOTE(review): serialize() writes Long.BYTES + BYTES_SIZE * positions bytes, which is
        // more than this value whenever at least one position is written; confirm how callers use size().
        @Override
        public int size() {
            return BYTES_SIZE;
        }

        private static final VarHandle longHandle = MethodHandles.byteArrayViewVarHandle(long[].class, ByteOrder.BIG_ENDIAN);

        /**
         * Writes the position count and the values into {@code ba}, starting at {@code offset}.
         *
         * @return the number of bytes written
         */
        @Override
        public int serialize(LongArrayState state, byte[] ba, int offset) {
            int positions = state.largestIndex + 1;
            longHandle.set(ba, offset, positions);
            offset += Long.BYTES;
            for (int i = 0; i < positions; i++) {
                longHandle.set(ba, offset, state.values[i]);
                offset += BYTES_SIZE;
            }
            return Long.BYTES + (BYTES_SIZE * positions); // number of bytes written
        }

        /**
         * Replaces the contents of {@code state} with the values read from {@code ba},
         * starting at {@code offset}, and sets {@code largestIndex} to the last position read.
         */
        @Override
        public void deserialize(LongArrayState state, byte[] ba, int offset) {
            Objects.requireNonNull(state);
            int positions = (int) (long) longHandle.get(ba, offset);
            offset += Long.BYTES;
            long[] values = new long[positions];
            for (int i = 0; i < positions; i++) {
                values[i] = (long) longHandle.get(ba, offset);
                offset += BYTES_SIZE;
            }
            state.values = values;
            state.largestIndex = positions - 1;
        }
    }
}
18 changes: 14 additions & 4 deletions server/src/test/java/org/elasticsearch/compute/OperatorTests.java
Expand Up @@ -494,7 +494,8 @@ public void testBasicGroupingOperators() {
new GroupingAggregator(GroupingAggregatorFunction.avg, AggregatorMode.INITIAL, 1),
new GroupingAggregator(GroupingAggregatorFunction.max, AggregatorMode.INITIAL, 1),
new GroupingAggregator(GroupingAggregatorFunction.min, AggregatorMode.INITIAL, 1),
new GroupingAggregator(GroupingAggregatorFunction.sum, AggregatorMode.INITIAL, 1)
new GroupingAggregator(GroupingAggregatorFunction.sum, AggregatorMode.INITIAL, 1),
new GroupingAggregator(GroupingAggregatorFunction.count, AggregatorMode.INITIAL, 1)
),
BigArrays.NON_RECYCLING_INSTANCE
),
Expand All @@ -504,7 +505,8 @@ public void testBasicGroupingOperators() {
new GroupingAggregator(GroupingAggregatorFunction.avg, AggregatorMode.INTERMEDIATE, 1),
new GroupingAggregator(GroupingAggregatorFunction.max, AggregatorMode.INTERMEDIATE, 2),
new GroupingAggregator(GroupingAggregatorFunction.min, AggregatorMode.INTERMEDIATE, 3),
new GroupingAggregator(GroupingAggregatorFunction.sum, AggregatorMode.INTERMEDIATE, 4)
new GroupingAggregator(GroupingAggregatorFunction.sum, AggregatorMode.INTERMEDIATE, 4),
new GroupingAggregator(GroupingAggregatorFunction.count, AggregatorMode.INTERMEDIATE, 5)
),
BigArrays.NON_RECYCLING_INSTANCE
),
Expand All @@ -514,7 +516,8 @@ public void testBasicGroupingOperators() {
new GroupingAggregator(GroupingAggregatorFunction.avg, AggregatorMode.FINAL, 1),
new GroupingAggregator(GroupingAggregatorFunction.max, AggregatorMode.FINAL, 2),
new GroupingAggregator(GroupingAggregatorFunction.min, AggregatorMode.FINAL, 3),
new GroupingAggregator(GroupingAggregatorFunction.sum, AggregatorMode.FINAL, 4)
new GroupingAggregator(GroupingAggregatorFunction.sum, AggregatorMode.FINAL, 4),
new GroupingAggregator(GroupingAggregatorFunction.count, AggregatorMode.FINAL, 5)
),
BigArrays.NON_RECYCLING_INSTANCE
),
Expand All @@ -530,7 +533,7 @@ public void testBasicGroupingOperators() {
driver.run();
assertEquals(1, pageCount.get());
assertEquals(cardinality, rowCount.get());
assertEquals(5, lastPage.get().getBlockCount());
assertEquals(6, lastPage.get().getBlockCount());

final Block groupIdBlock = lastPage.get().getBlock(0);
assertEquals(cardinality, groupIdBlock.getPositionCount());
Expand Down Expand Up @@ -567,6 +570,13 @@ public void testBasicGroupingOperators() {
.collect(toMap(i -> initialGroupId + i, i -> (double) IntStream.range(i * 100, (i * 100) + 100).sum()));
var actualSumValues = IntStream.range(0, cardinality).boxed().collect(toMap(groupIdBlock::getLong, sumValuesBlock::getDouble));
assertEquals(expectedSumValues, actualSumValues);

// assert count
final Block countValuesBlock = lastPage.get().getBlock(5);
assertEquals(cardinality, countValuesBlock.getPositionCount());
var expectedCountValues = IntStream.range(0, cardinality).boxed().collect(toMap(i -> initialGroupId + i, i -> 100L));
var actualCountValues = IntStream.range(0, cardinality).boxed().collect(toMap(groupIdBlock::getLong, countValuesBlock::getLong));
assertEquals(expectedCountValues, actualCountValues);
}

// Tests grouping avg aggregations with multiple intermediate partial blocks.
Expand Down
Expand Up @@ -73,15 +73,15 @@ public void testRow() {
assertEquals(List.of(List.of(value)), response.values());
}

public void testFromStats() {
testFromStatsImpl("from test | stats avg(count)", "avg(count)");
public void testFromStatsAvg() {
testFromStatsAvgImpl("from test | stats avg(count)", "avg(count)");
}

public void testFromStatsWithAlias() {
testFromStatsImpl("from test | stats f1 = avg(count)", "f1");
public void testFromStatsAvgWithAlias() {
testFromStatsAvgImpl("from test | stats f1 = avg(count)", "f1");
}

private void testFromStatsImpl(String command, String expectedFieldName) {
private void testFromStatsAvgImpl(String command, String expectedFieldName) {
EsqlQueryResponse results = run(command);
logger.info(results);
Assert.assertEquals(1, results.columns().size());
Expand All @@ -92,20 +92,39 @@ private void testFromStatsImpl(String command, String expectedFieldName) {
assertEquals(43, (double) results.values().get(0).get(0), 1d);
}

// Ungrouped count without an alias: the result column is named after the expression.
public void testFromStatsCount() {
    final String query = "from test | stats count(data)";
    testFromStatsCountImpl(query, "count(data)");
}

// Ungrouped count with an alias: the result column takes the alias as its name.
public void testFromStatsCountWithAlias() {
    final String query = "from test | stats dataCount = count(data)";
    testFromStatsCountImpl(query, "dataCount");
}

// Runs an ungrouped count query and checks for a single "long" column with the
// expected name, holding a single row whose only value is 40.
public void testFromStatsCountImpl(String command, String expectedFieldName) {
    final EsqlQueryResponse response = run(command);
    logger.info(response);
    Assert.assertEquals(1, response.columns().size());
    Assert.assertEquals(1, response.values().size());

    final ColumnInfo countColumn = response.columns().get(0);
    assertEquals(expectedFieldName, countColumn.name());
    assertEquals("long", countColumn.type());

    final List<Object> row = response.values().get(0);
    assertEquals(1, row.size());
    assertEquals(40L, row.get(0));
}

@AwaitsFix(bugUrl = "line 1:45: Unknown column [data]")
public void testFromStatsGroupingWithSort() { // FIX ME
testFromStatsGroupingImpl("from test | stats avg(count) by data | sort data | limit 2", "avg(count)", "data");
public void testFromStatsGroupingAvgWithSort() { // FIX ME
testFromStatsGroupingAvgImpl("from test | stats avg(count) by data | sort data | limit 2", "avg(count)", "data");
}

public void testFromStatsGrouping() {
testFromStatsGroupingImpl("from test | stats avg(count) by data", "avg(count)", "data");
public void testFromStatsGroupingAvg() {
testFromStatsGroupingAvgImpl("from test | stats avg(count) by data", "avg(count)", "data");
}

public void testFromStatsGroupingWithAliases() {
testFromStatsGroupingImpl("from test | eval g = data | stats f = avg(count) by g", "f", "g");
public void testFromStatsGroupingAvgWithAliases() {
testFromStatsGroupingAvgImpl("from test | eval g = data | stats f = avg(count) by g", "f", "g");
}

private void testFromStatsGroupingImpl(String command, String expectedFieldName, String expectedGroupName) {
private void testFromStatsGroupingAvgImpl(String command, String expectedFieldName, String expectedGroupName) {
EsqlQueryResponse results = run(command);
logger.info(results);
Assert.assertEquals(2, results.columns().size());
Expand Down Expand Up @@ -135,6 +154,44 @@ private void testFromStatsGroupingImpl(String command, String expectedFieldName,
}
}

// Grouped count without aliases: columns are named after the expression and the grouping field.
public void testFromStatsGroupingCount() {
    final String query = "from test | stats count(count) by data";
    testFromStatsGroupingCountImpl(query, "count(count)", "data");
}

// Grouped count where both the aggregate and the grouping key are aliased.
public void testFromStatsGroupingCountWithAliases() {
    final String query = "from test | eval grp = data | stats total = count(count) by grp";
    testFromStatsGroupingCountImpl(query, "total", "grp");
}

// Runs a grouped count query and checks the two result columns (group key first,
// count second, both "long") and that groups 1 and 2 each report a count of 20,
// in whichever order the rows come back.
private void testFromStatsGroupingCountImpl(String command, String expectedFieldName, String expectedGroupName) {
    final EsqlQueryResponse response = run(command);
    logger.info(response);
    Assert.assertEquals(2, response.columns().size());

    // column metadata
    final ColumnInfo keyColumn = response.columns().get(0);
    assertEquals(expectedGroupName, keyColumn.name());
    assertEquals("long", keyColumn.type());
    final ColumnInfo countColumn = response.columns().get(1);
    assertEquals(expectedFieldName, countColumn.name());
    assertEquals("long", countColumn.type());

    // column values
    final List<List<Object>> rows = response.values();
    assertEquals(2, rows.size());
    final List<Object> first = rows.get(0);
    final List<Object> second = rows.get(1);
    // Row order is not fixed, so branch on the first row's group key.
    // TODO: find a declarative way to assert the expected output.
    final long firstGroup = (long) first.get(0);
    if (firstGroup == 1L) {
        assertEquals(20L, first.get(1));
        assertEquals(2L, second.get(0));
        assertEquals(20L, second.get(1));
    } else if (firstGroup == 2L) {
        assertEquals(20L, second.get(1));
        assertEquals(1L, second.get(0));
        assertEquals(20L, first.get(1));
    } else {
        fail("Unexpected group value: " + first.get(0));
    }
}

// Grouping where the groupby field is of a date type.
public void testFromStatsGroupingByDate() {
EsqlQueryResponse results = run("from test | stats avg(count) by time");
Expand Down
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.esql.expression.function;

import org.elasticsearch.xpack.esql.expression.function.aggregate.Avg;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
import org.elasticsearch.xpack.ql.expression.function.FunctionDefinition;
import org.elasticsearch.xpack.ql.expression.function.FunctionRegistry;

Expand All @@ -24,7 +25,8 @@ public EsqlFunctionRegistry() {
}

private FunctionDefinition[][] functions() {
return new FunctionDefinition[][] { new FunctionDefinition[] { def(Avg.class, Avg::new, "avg") } };
return new FunctionDefinition[][] {
new FunctionDefinition[] { def(Avg.class, Avg::new, "avg"), def(Count.class, Count::new, "count") } };
}

@Override
Expand Down

0 comments on commit e0ca1e8

Please sign in to comment.