Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
End-to-end support for grouping count aggregator function (#306)
End-to-end support for grouping count aggregator function, e.g. _from test | stats count(count) by data_.
- Loading branch information
1 parent
8571e52
commit e0ca1e8
Showing
8 changed files
with
361 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
89 changes: 89 additions & 0 deletions
89
server/src/main/java/org/elasticsearch/compute/aggregation/GroupingCountAggregator.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0 and the Server Side Public License, v 1; you may not use this file except | ||
* in compliance with, at your election, the Elastic License 2.0 or the Server | ||
* Side Public License, v 1. | ||
*/ | ||
|
||
package org.elasticsearch.compute.aggregation; | ||
|
||
import org.elasticsearch.compute.Experimental; | ||
import org.elasticsearch.compute.data.AggregatorStateBlock; | ||
import org.elasticsearch.compute.data.Block; | ||
import org.elasticsearch.compute.data.LongArrayBlock; | ||
import org.elasticsearch.compute.data.Page; | ||
|
||
@Experimental | ||
public class GroupingCountAggregator implements GroupingAggregatorFunction { | ||
|
||
private final LongArrayState state; | ||
private final int channel; | ||
|
||
static GroupingCountAggregator create(int inputChannel) { | ||
if (inputChannel < 0) { | ||
throw new IllegalArgumentException(); | ||
} | ||
return new GroupingCountAggregator(inputChannel, new LongArrayState(0)); | ||
} | ||
|
||
static GroupingCountAggregator createIntermediate() { | ||
return new GroupingCountAggregator(-1, new LongArrayState(0)); | ||
} | ||
|
||
private GroupingCountAggregator(int channel, LongArrayState state) { | ||
this.channel = channel; | ||
this.state = state; | ||
} | ||
|
||
@Override | ||
public void addRawInput(Block groupIdBlock, Page page) { | ||
assert channel >= 0; | ||
Block valuesBlock = page.getBlock(channel); | ||
LongArrayState s = this.state; | ||
int len = valuesBlock.getPositionCount(); | ||
for (int i = 0; i < len; i++) { | ||
int groupId = (int) groupIdBlock.getLong(i); | ||
s.set(s.getOrDefault(groupId, 0) + 1, groupId); | ||
} | ||
} | ||
|
||
@Override | ||
public void addIntermediateInput(Block groupIdBlock, Block block) { | ||
assert channel == -1; | ||
if (block instanceof AggregatorStateBlock) { | ||
@SuppressWarnings("unchecked") | ||
AggregatorStateBlock<LongArrayState> blobBlock = (AggregatorStateBlock<LongArrayState>) block; | ||
LongArrayState tmpState = new LongArrayState(0); | ||
blobBlock.get(0, tmpState); | ||
final long[] values = tmpState.getValues(); | ||
final int positions = groupIdBlock.getPositionCount(); | ||
final LongArrayState s = state; | ||
for (int i = 0; i < positions; i++) { | ||
int groupId = (int) groupIdBlock.getLong(i); | ||
s.set(s.getOrDefault(groupId, 0) + values[i], groupId); | ||
} | ||
} else { | ||
throw new RuntimeException("expected AggregatorStateBlock, got:" + block); | ||
} | ||
} | ||
|
||
@Override | ||
public Block evaluateIntermediate() { | ||
AggregatorStateBlock.Builder<AggregatorStateBlock<LongArrayState>, LongArrayState> builder = AggregatorStateBlock | ||
.builderOfAggregatorState(LongArrayState.class); | ||
builder.add(state); | ||
return builder.build(); | ||
} | ||
|
||
@Override | ||
public Block evaluateFinal() { | ||
LongArrayState s = state; | ||
int positions = s.largestIndex + 1; | ||
long[] result = new long[positions]; | ||
for (int i = 0; i < positions; i++) { | ||
result[i] = s.get(i); | ||
} | ||
return new LongArrayBlock(result, positions); | ||
} | ||
} |
107 changes: 107 additions & 0 deletions
107
server/src/main/java/org/elasticsearch/compute/aggregation/LongArrayState.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0 and the Server Side Public License, v 1; you may not use this file except | ||
* in compliance with, at your election, the Elastic License 2.0 or the Server | ||
* Side Public License, v 1. | ||
*/ | ||
|
||
package org.elasticsearch.compute.aggregation; | ||
|
||
import org.elasticsearch.compute.Experimental; | ||
|
||
import java.lang.invoke.MethodHandles; | ||
import java.lang.invoke.VarHandle; | ||
import java.nio.ByteOrder; | ||
import java.util.Arrays; | ||
import java.util.Objects; | ||
|
||
@Experimental | ||
final class LongArrayState implements AggregatorState<LongArrayState> { | ||
|
||
private long[] values; | ||
// total number of groups; <= values.length | ||
int largestIndex; | ||
|
||
private final LongArrayStateSerializer serializer; | ||
|
||
LongArrayState(long... values) { | ||
this.values = values; | ||
this.serializer = new LongArrayStateSerializer(); | ||
} | ||
|
||
long[] getValues() { | ||
return values; | ||
} | ||
|
||
long get(int index) { | ||
// TODO bounds check | ||
return values[index]; | ||
} | ||
|
||
long getOrDefault(int index, long defaultValue) { | ||
if (index > largestIndex) { | ||
return defaultValue; | ||
} else { | ||
return values[index]; | ||
} | ||
} | ||
|
||
void set(long value, int index) { | ||
ensureCapacity(index); | ||
if (index > largestIndex) { | ||
largestIndex = index; | ||
} | ||
values[index] = value; | ||
} | ||
|
||
private void ensureCapacity(int position) { | ||
if (position >= values.length) { | ||
int newSize = values.length << 1; // trivial | ||
values = Arrays.copyOf(values, newSize); | ||
} | ||
} | ||
|
||
@Override | ||
public AggregatorStateSerializer<LongArrayState> serializer() { | ||
return serializer; | ||
} | ||
|
||
static class LongArrayStateSerializer implements AggregatorStateSerializer<LongArrayState> { | ||
|
||
static final int BYTES_SIZE = Long.BYTES; | ||
|
||
@Override | ||
public int size() { | ||
return BYTES_SIZE; | ||
} | ||
|
||
private static final VarHandle longHandle = MethodHandles.byteArrayViewVarHandle(long[].class, ByteOrder.BIG_ENDIAN); | ||
|
||
@Override | ||
public int serialize(LongArrayState state, byte[] ba, int offset) { | ||
int positions = state.largestIndex + 1; | ||
longHandle.set(ba, offset, positions); | ||
offset += Long.BYTES; | ||
for (int i = 0; i < positions; i++) { | ||
longHandle.set(ba, offset, state.values[i]); | ||
offset += BYTES_SIZE; | ||
} | ||
return Long.BYTES + (BYTES_SIZE * positions); // number of bytes written | ||
} | ||
|
||
@Override | ||
public void deserialize(LongArrayState state, byte[] ba, int offset) { | ||
Objects.requireNonNull(state); | ||
int positions = (int) (long) longHandle.get(ba, offset); | ||
offset += Long.BYTES; | ||
long[] values = new long[positions]; | ||
for (int i = 0; i < positions; i++) { | ||
values[i] = (long) longHandle.get(ba, offset); | ||
offset += BYTES_SIZE; | ||
} | ||
state.values = values; | ||
state.largestIndex = positions - 1; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.