Skip to content

Commit

Permalink
Small refactor of the aggregate functions (elastic#437)
Browse files Browse the repository at this point in the history
use just one method instead of two
centralize the creation into Aggregator
centralize the ESQL function mapping (for now) into AggregateMapper
  • Loading branch information
costin committed Dec 6, 2022
1 parent 7245f69 commit 8fc9c1e
Show file tree
Hide file tree
Showing 12 changed files with 149 additions and 172 deletions.
Expand Up @@ -13,7 +13,6 @@
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.Page;

import java.util.function.BiFunction;
import java.util.function.Supplier;

@Experimental
Expand All @@ -24,25 +23,28 @@ public class Aggregator {

private final int intermediateChannel;

public record AggregatorFactory(AggregatorFunction.AggregatorFunctionFactory aggCreationFunc, AggregatorMode mode, int inputChannel)
public record AggregatorFactory(AggregatorFunction.Provider provider, AggregatorMode mode, int inputChannel)
implements
Supplier<Aggregator>,
Describable {
@Override
public Aggregator get() {
return new Aggregator(aggCreationFunc, mode, inputChannel);
return new Aggregator(provider, mode, inputChannel);
}

@Override
public String describe() {
return aggCreationFunc.describe();
return provider.describe();
}
}

public Aggregator(BiFunction<AggregatorMode, Integer, AggregatorFunction> aggCreationFunc, AggregatorMode mode, int inputChannel) {
this.aggregatorFunction = aggCreationFunc.apply(mode, inputChannel);
this.mode = mode;
public Aggregator(AggregatorFunction.Provider provider, AggregatorMode mode, int inputChannel) {
assert mode.isInputPartial() || inputChannel >= 0;
// input channel is used both to signal the creation of the page (when the input is not partial)
this.aggregatorFunction = provider.create(mode.isInputPartial() ? -1 : inputChannel);
// and to indicate the page during the intermediate phase
this.intermediateChannel = mode.isInputPartial() ? inputChannel : -1;
this.mode = mode;
}

public void processPage(Page page) {
Expand Down
Expand Up @@ -13,8 +13,6 @@
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.Page;

import java.util.function.BiFunction;

@Experimental
public interface AggregatorFunction {

Expand All @@ -26,72 +24,20 @@ public interface AggregatorFunction {

Block evaluateFinal();

abstract class AggregatorFunctionFactory implements BiFunction<AggregatorMode, Integer, AggregatorFunction>, Describable {

private final String name;

AggregatorFunctionFactory(String name) {
this.name = name;
}
@FunctionalInterface
interface Provider extends Describable {
AggregatorFunction create(int inputChannel);

@Override
public String describe() {
return name;
}
}

AggregatorFunctionFactory doubleAvg = new AggregatorFunctionFactory("doubleAvg") {
@Override
public AggregatorFunction apply(AggregatorMode mode, Integer inputChannel) {
if (mode.isInputPartial()) {
return DoubleAvgAggregator.createIntermediate();
} else {
return DoubleAvgAggregator.create(inputChannel);
}
}
};

AggregatorFunctionFactory longAvg = new AggregatorFunctionFactory("longAvg") {
@Override
public AggregatorFunction apply(AggregatorMode mode, Integer inputChannel) {
if (mode.isInputPartial()) {
return LongAvgAggregator.createIntermediate();
} else {
return LongAvgAggregator.create(inputChannel);
default String describe() {
var description = getClass().getName();
// FooBarAggregator --> fooBar
description = description.substring(0, description.length() - 10);
var startChar = Character.toLowerCase(description.charAt(0));
if (startChar != description.charAt(0)) {
description = startChar + description.substring(1);
}
return description;
}
};

AggregatorFunctionFactory count = new AggregatorFunctionFactory("count") {
@Override
public AggregatorFunction apply(AggregatorMode mode, Integer inputChannel) {
if (mode.isInputPartial()) {
return CountRowsAggregator.createIntermediate();
} else {
return CountRowsAggregator.create(inputChannel);
}
}
};

AggregatorFunctionFactory max = new AggregatorFunctionFactory("max") {
@Override
public AggregatorFunction apply(AggregatorMode mode, Integer inputChannel) {
if (mode.isInputPartial()) {
return MaxAggregator.createIntermediate();
} else {
return MaxAggregator.create(inputChannel);
}
}
};

AggregatorFunctionFactory sum = new AggregatorFunctionFactory("sum") {
@Override
public AggregatorFunction apply(AggregatorMode mode, Integer inputChannel) {
if (mode.isInputPartial()) {
return SumAggregator.createIntermediate();
} else {
return SumAggregator.create(inputChannel);
}
}
};
}
}
@@ -0,0 +1,34 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.compute.aggregation;

public final class AggregatorFunctionProviders {

private AggregatorFunctionProviders() {}

public static AggregatorFunction.Provider avgDouble() {
return DoubleAvgAggregator::create;
}

public static AggregatorFunction.Provider avgLong() {
return LongAvgAggregator::create;
}

public static AggregatorFunction.Provider count() {
return CountRowsAggregator::create;
}

public static AggregatorFunction.Provider max() {
return MaxAggregator::create;
}

public static AggregatorFunction.Provider sum() {
return SumAggregator::create;
}
}
Expand Up @@ -20,17 +20,10 @@ public class CountRowsAggregator implements AggregatorFunction {
private final LongState state;
private final int channel;

static CountRowsAggregator create(int inputChannel) {
if (inputChannel < 0) {
throw new IllegalArgumentException();
}
public static CountRowsAggregator create(int inputChannel) {
return new CountRowsAggregator(inputChannel, new LongState());
}

static CountRowsAggregator createIntermediate() {
return new CountRowsAggregator(-1, new LongState());
}

private CountRowsAggregator(int channel, LongState state) {
this.channel = channel;
this.state = state;
Expand Down
Expand Up @@ -26,16 +26,9 @@ class DoubleAvgAggregator implements AggregatorFunction {
private final int channel;

static DoubleAvgAggregator create(int inputChannel) {
if (inputChannel < 0) {
throw new IllegalArgumentException();
}
return new DoubleAvgAggregator(inputChannel, new AvgState());
}

static DoubleAvgAggregator createIntermediate() {
return new DoubleAvgAggregator(-1, new AvgState());
}

private DoubleAvgAggregator(int channel, AvgState state) {
this.channel = channel;
this.state = state;
Expand Down
Expand Up @@ -26,16 +26,9 @@ class LongAvgAggregator implements AggregatorFunction {
private final int channel;

static LongAvgAggregator create(int inputChannel) {
if (inputChannel < 0) {
throw new IllegalArgumentException();
}
return new LongAvgAggregator(inputChannel, new AvgState());
}

static LongAvgAggregator createIntermediate() {
return new LongAvgAggregator(-1, new AvgState());
}

private LongAvgAggregator(int channel, AvgState state) {
this.channel = channel;
this.state = state;
Expand Down
Expand Up @@ -22,16 +22,9 @@ final class MaxAggregator implements AggregatorFunction {
private final int channel;

static MaxAggregator create(int inputChannel) {
if (inputChannel < 0) {
throw new IllegalArgumentException();
}
return new MaxAggregator(inputChannel, new DoubleState(Double.NEGATIVE_INFINITY));
}

static MaxAggregator createIntermediate() {
return new MaxAggregator(-1, new DoubleState(Double.NEGATIVE_INFINITY));
}

private MaxAggregator(int channel, DoubleState state) {
this.channel = channel;
this.state = state;
Expand Down
Expand Up @@ -22,16 +22,9 @@ final class SumAggregator implements AggregatorFunction {
private final int channel;

static SumAggregator create(int inputChannel) {
if (inputChannel < 0) {
throw new IllegalArgumentException();
}
return new SumAggregator(inputChannel, new DoubleState());
}

static SumAggregator createIntermediate() {
return new SumAggregator(-1, new DoubleState());
}

private SumAggregator(int channel, DoubleState state) {
this.channel = channel;
this.state = state;
Expand Down

0 comments on commit 8fc9c1e

Please sign in to comment.