From a13fce98f61867fcb5adb52c80f1cfd3eecfc436 Mon Sep 17 00:00:00 2001 From: mingmxu Date: Mon, 26 Jun 2017 16:03:51 -0700 Subject: [PATCH] UDAF support: - Adds an abstract class BeamSqlUdaf for defining Calcite SQL UDAFs. - Updates built-in COUNT/SUM/AVG/MAX/MIN accumulators to use this new class. --- .../org/apache/beam/dsls/sql/BeamSqlEnv.java | 10 + .../beam/dsls/sql/rel/BeamAggregationRel.java | 2 +- .../beam/dsls/sql/schema/BeamSqlUdaf.java | 72 ++ .../transform/BeamAggregationTransforms.java | 658 ++++-------------- .../transform/BeamBuiltinAggregations.java | 412 +++++++++++ .../BeamAggregationTransformTest.java | 2 +- 6 files changed, 633 insertions(+), 523 deletions(-) create mode 100644 dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlUdaf.java create mode 100644 dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamBuiltinAggregations.java diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSqlEnv.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSqlEnv.java index baa2617d9feee..078d9d34644d7 100644 --- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSqlEnv.java +++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSqlEnv.java @@ -22,6 +22,7 @@ import org.apache.beam.dsls.sql.planner.BeamQueryPlanner; import org.apache.beam.dsls.sql.schema.BaseBeamTable; import org.apache.beam.dsls.sql.schema.BeamSqlRecordType; +import org.apache.beam.dsls.sql.schema.BeamSqlUdaf; import org.apache.beam.dsls.sql.utils.CalciteUtils; import org.apache.calcite.DataContext; import org.apache.calcite.linq4j.Enumerable; @@ -32,6 +33,7 @@ import org.apache.calcite.schema.SchemaPlus; import org.apache.calcite.schema.Statistic; import org.apache.calcite.schema.Statistics; +import org.apache.calcite.schema.impl.AggregateFunctionImpl; import org.apache.calcite.schema.impl.ScalarFunctionImpl; import org.apache.calcite.tools.Frameworks; @@ -57,6 +59,14 @@ public void registerUdf(String functionName, Class clazz, String methodName) schema.add(functionName, ScalarFunctionImpl.create(clazz, methodName)); } + /** + * Register a UDAF function which can be used in GROUP-BY expression. + * See {@link BeamSqlUdaf} on how to implement a UDAF. + */ + public void registerUdaf(String functionName, Class clazz) { + schema.add(functionName, AggregateFunctionImpl.create(clazz)); + } + /** * Registers a {@link BaseBeamTable} which can be used for all subsequent queries. 
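For orientation, the intended call pattern for the new registerUdaf hook mirrors the existing registerUdf: implement a BeamSqlUdaf subclass and register it under a SQL function name. A minimal sketch, assuming an existing BeamSqlEnv instance sqlEnv, a registered table TABLE_A, and a hypothetical SquareSum UDAF (an example implementation is sketched in the BeamSqlUdaf javadoc below; none of these names come from this patch):

    // register the UDAF under the name it will carry in SQL text
    sqlEnv.registerUdaf("SQUARE_SUM", SquareSum.class);
    // it can then appear wherever a built-in aggregation call is allowed, e.g.
    // SELECT f_int1, SQUARE_SUM(f_int2) FROM TABLE_A GROUP BY f_int1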
* diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/rel/BeamAggregationRel.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/rel/BeamAggregationRel.java index 9ec9e9fd8f294..9bb290245423e 100644 --- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/rel/BeamAggregationRel.java +++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/rel/BeamAggregationRel.java @@ -104,7 +104,7 @@ public PCollection buildBeamPipeline(PCollectionTuple inputPCollecti PCollection> aggregatedStream = exCombineByStream.apply( stageName + "combineBy", Combine.perKey( - new BeamAggregationTransforms.AggregationCombineFn(getAggCallList(), + new BeamAggregationTransforms.AggregationAdaptor(getAggCallList(), CalciteUtils.toBeamRecordType(input.getRowType())))) .setCoder(KvCoder.of(keyCoder, aggCoder)); diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlUdaf.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlUdaf.java new file mode 100644 index 0000000000000..9582ffaea8985 --- /dev/null +++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlUdaf.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.dsls.sql.schema; + +import java.io.Serializable; +import java.lang.reflect.ParameterizedType; +import org.apache.beam.sdk.coders.CannotProvideCoderException; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.transforms.Combine.CombineFn; + +/** + * abstract class of aggregation functions in Beam SQL. + * + *

There are several constraints for a UDAF (see the example below):
+ * 1. A constructor with an empty argument list is required;
+ * 2. The type of {@code InputT} and {@code OutputT} can only be Integer/Long/Short/Byte/Double + * /Float/Date/BigDecimal, mapping to the SQL types INTEGER/BIGINT/SMALLINT/TINYINT/DOUBLE/FLOAT + * /TIMESTAMP/DECIMAL;
+ * 3. Keep all intermediate state in {@code AccumT}; do not rely on fields of the UDAF class itself;
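+ *
+ * <p>For illustration, a UDAF meeting these constraints could look like the following sketch
+ * ({@code SquareSum} is hypothetical, not part of this patch):
+ * <pre>{@code
+ * public static class SquareSum extends BeamSqlUdaf<Integer, Integer, Integer> {
+ *   public SquareSum() {}  // constraint 1: empty-argument constructor
+ *
+ *   public Integer init() { return 0; }
+ *
+ *   public Integer add(Integer accumulator, Integer input) {
+ *     return accumulator + input * input;  // constraint 3: all state lives in the accumulator
+ *   }
+ *
+ *   public Integer merge(Iterable<Integer> accumulators) {
+ *     int v = 0;
+ *     for (Integer a : accumulators) {
+ *       v += a;
+ *     }
+ *     return v;
+ *   }
+ *
+ *   public Integer result(Integer accumulator) { return accumulator; }
+ * }
+ * }</pre>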
+ */ +public abstract class BeamSqlUdaf implements Serializable { + public BeamSqlUdaf(){} + + /** + * create an initial aggregation object, equals to {@link CombineFn#createAccumulator()}. + */ + public abstract AccumT init(); + + /** + * add an input value, equals to {@link CombineFn#addInput(Object, Object)}. + */ + public abstract AccumT add(AccumT accumulator, InputT input); + + /** + * merge aggregation objects from parallel tasks, equals to + * {@link CombineFn#mergeAccumulators(Iterable)}. + */ + public abstract AccumT merge(Iterable accumulators); + + /** + * extract output value from aggregation object, equals to + * {@link CombineFn#extractOutput(Object)}. + */ + public abstract OutputT result(AccumT accumulator); + + /** + * get the coder for AccumT which stores the intermediate result. + * By default it's fetched from {@link CoderRegistry}. + */ + public Coder getAccumulatorCoder(CoderRegistry registry) + throws CannotProvideCoderException { + return registry.getCoder( + (Class) ((ParameterizedType) getClass() + .getGenericSuperclass()).getActualTypeArguments()[1]); + } +} diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransforms.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransforms.java index 83d473a442336..9c0b4a37ae7df 100644 --- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransforms.java +++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransforms.java @@ -17,25 +17,35 @@ */ package org.apache.beam.dsls.sql.transform; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; import java.io.Serializable; +import java.math.BigDecimal; import java.util.ArrayList; -import java.util.Arrays; -import java.util.Date; +import java.util.Iterator; import java.util.List; import org.apache.beam.dsls.sql.interpreter.operator.BeamSqlExpression; import org.apache.beam.dsls.sql.interpreter.operator.BeamSqlInputRefExpression; -import org.apache.beam.dsls.sql.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.dsls.sql.schema.BeamSqlRecordType; import org.apache.beam.dsls.sql.schema.BeamSqlRow; +import org.apache.beam.dsls.sql.schema.BeamSqlUdaf; import org.apache.beam.dsls.sql.utils.CalciteUtils; +import org.apache.beam.sdk.coders.BigDecimalCoder; +import org.apache.beam.sdk.coders.CannotProvideCoderException; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderException; +import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.coders.CustomCoder; +import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.transforms.Combine.CombineFn; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.values.KV; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.sql.SqlAggFunction; -import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.schema.impl.AggregateFunctionImpl; +import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction; import org.apache.calcite.util.ImmutableBitSet; import org.joda.time.Instant; @@ -71,9 +81,7 @@ public void processElement(ProcessContext c, BoundedWindow window) { outRecord.addField(aggFieldNames.get(idx), kvRecord.getValue().getFieldValue(idx)); } - // if (c.pane().isLast()) { c.output(outRecord); - // } } } @@ -134,545 +142,153 @@ public Instant 
apply(BeamSqlRow input) { } /** - * Aggregation function which supports COUNT, MAX, MIN, SUM, AVG. - * - *

Multiple aggregation functions are combined together. - * For each aggregation function, it may accept part of all data types:
- * 1). COUNT works for any data type;
- * 2). MAX/MIN works for INT, LONG, FLOAT, DOUBLE, DECIMAL, SMALLINT, TINYINT, TIMESTAMP;
- * 3). SUM/AVG works for INT, LONG, FLOAT, DOUBLE, DECIMAL, SMALLINT, TINYINT;
- * + * An adaptor class to invoke Calcite UDAF instances in Beam {@code CombineFn}. */ - public static class AggregationCombineFn extends CombineFn { - private BeamSqlRecordType aggDataType; + public static class AggregationAdaptor + extends CombineFn { + private List aggregators; + private List sourceFieldExps; + private BeamSqlRecordType finalRecordType; - private int countIndex = -1; - - List aggFunctions; - List aggElementExpressions; - - public AggregationCombineFn(List aggregationCalls, + public AggregationAdaptor(List aggregationCalls, BeamSqlRecordType sourceRowRecordType) { - this.aggFunctions = new ArrayList<>(); - this.aggElementExpressions = new ArrayList<>(); - - boolean hasAvg = false; - boolean hasCount = false; - int countIndex = -1; - List fieldNames = new ArrayList<>(); - List fieldTypes = new ArrayList<>(); - for (int idx = 0; idx < aggregationCalls.size(); ++idx) { - AggregateCall ac = aggregationCalls.get(idx); - //verify it's supported. - verifySupportedAggregation(ac); - - fieldNames.add(ac.name); - fieldTypes.add(CalciteUtils.toJavaType(ac.type.getSqlTypeName())); - - SqlAggFunction aggFn = ac.getAggregation(); - switch (aggFn.getName()) { - case "COUNT": - aggElementExpressions.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 1L)); - hasCount = true; - countIndex = idx; - break; - case "SUM": - case "MAX": - case "MIN": - case "AVG": - int refIndex = ac.getArgList().get(0); - aggElementExpressions.add(new BeamSqlInputRefExpression( - CalciteUtils.getFieldType(sourceRowRecordType, refIndex), refIndex)); - if ("AVG".equals(aggFn.getName())) { - hasAvg = true; - } - break; - - default: + aggregators = new ArrayList<>(); + sourceFieldExps = new ArrayList<>(); + List outFieldsName = new ArrayList<>(); + List outFieldsType = new ArrayList<>(); + for (AggregateCall call : aggregationCalls) { + int refIndex = call.getArgList().size() > 0 ? call.getArgList().get(0) : 0; + BeamSqlExpression sourceExp = new BeamSqlInputRefExpression( + CalciteUtils.getFieldType(sourceRowRecordType, refIndex), refIndex); + sourceFieldExps.add(sourceExp); + + outFieldsName.add(call.name); + int outFieldType = CalciteUtils.toJavaType(call.type.getSqlTypeName()); + outFieldsType.add(outFieldType); + + switch (call.getAggregation().getName()) { + case "COUNT": + aggregators.add(new BeamBuiltinAggregations.Count()); + break; + case "MAX": + aggregators.add(BeamBuiltinAggregations.Max.create(call.type.getSqlTypeName())); + break; + case "MIN": + aggregators.add(BeamBuiltinAggregations.Min.create(call.type.getSqlTypeName())); + break; + case "SUM": + aggregators.add(BeamBuiltinAggregations.Sum.create(call.type.getSqlTypeName())); + break; + case "AVG": + aggregators.add(BeamBuiltinAggregations.Avg.create(call.type.getSqlTypeName())); + break; + default: + if (call.getAggregation() instanceof SqlUserDefinedAggFunction) { + // handle UDAF. 
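+ // Calcite wrapped the user's BeamSqlUdaf subclass in an AggregateFunctionImpl at
+ // registration time; unwrap it and instantiate it through the empty-argument
+ // constructor that the BeamSqlUdaf contract requires.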
+ SqlUserDefinedAggFunction udaf = (SqlUserDefinedAggFunction) call.getAggregation(); + AggregateFunctionImpl fn = (AggregateFunctionImpl) udaf.function; + try { + aggregators.add((BeamSqlUdaf) fn.declaringClass.newInstance()); + } catch (Exception e) { + throw new IllegalStateException(e); + } + } else { + throw new UnsupportedOperationException( + String.format("Aggregator [%s] is not supported", + call.getAggregation().getName())); + } break; } - aggFunctions.add(aggFn.getName()); } - - - // add a COUNT holder if only have AVG - if (hasAvg && !hasCount) { - fieldNames.add("__COUNT"); - fieldTypes.add(CalciteUtils.toJavaType(SqlTypeName.BIGINT)); - - aggFunctions.add("COUNT"); - aggElementExpressions.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 1L)); - - hasCount = true; - countIndex = aggDataType.size() - 1; + finalRecordType = BeamSqlRecordType.create(outFieldsName, outFieldsType); + } + @Override + public AggregationAccumulator createAccumulator() { + AggregationAccumulator initialAccu = new AggregationAccumulator(); + for (BeamSqlUdaf agg : aggregators) { + initialAccu.accumulatorElements.add(agg.init()); } - - this.aggDataType = BeamSqlRecordType.create(fieldNames, fieldTypes); - this.countIndex = countIndex; + return initialAccu; } - - private void verifySupportedAggregation(AggregateCall ac) { - //donot support DISTINCT - if (ac.isDistinct()) { - throw new UnsupportedOperationException("DISTINCT is not supported yet."); + @Override + public AggregationAccumulator addInput(AggregationAccumulator accumulator, BeamSqlRow input) { + AggregationAccumulator deltaAcc = new AggregationAccumulator(); + for (int idx = 0; idx < aggregators.size(); ++idx) { + deltaAcc.accumulatorElements.add( + aggregators.get(idx).add(accumulator.accumulatorElements.get(idx), + sourceFieldExps.get(idx).evaluate(input).getValue())); } - String aggFnName = ac.getAggregation().getName(); - switch (aggFnName) { - case "COUNT": - //COUNT works for any data type; - break; - case "SUM": - // SUM only support for INT, LONG, FLOAT, DOUBLE, DECIMAL, SMALLINT, - // TINYINT now - if (!Arrays - .asList(SqlTypeName.INTEGER, SqlTypeName.BIGINT, SqlTypeName.FLOAT, SqlTypeName.DOUBLE, - SqlTypeName.SMALLINT, SqlTypeName.TINYINT) - .contains(ac.type.getSqlTypeName())) { - throw new UnsupportedOperationException( - "SUM only support for INT, LONG, FLOAT, DOUBLE, SMALLINT, TINYINT"); - } - break; - case "MAX": - case "MIN": - // MAX/MIN only support for INT, LONG, FLOAT, DOUBLE, DECIMAL, SMALLINT, - // TINYINT, TIMESTAMP now - if (!Arrays.asList(SqlTypeName.INTEGER, SqlTypeName.BIGINT, SqlTypeName.FLOAT, - SqlTypeName.DOUBLE, SqlTypeName.SMALLINT, SqlTypeName.TINYINT, - SqlTypeName.TIMESTAMP).contains(ac.type.getSqlTypeName())) { - throw new UnsupportedOperationException("MAX/MIN only support for INT, LONG, FLOAT," - + " DOUBLE, SMALLINT, TINYINT, TIMESTAMP"); - } - break; - case "AVG": - // AVG only support for INT, LONG, FLOAT, DOUBLE, DECIMAL, SMALLINT, - // TINYINT now - if (!Arrays - .asList(SqlTypeName.INTEGER, SqlTypeName.BIGINT, SqlTypeName.FLOAT, SqlTypeName.DOUBLE, - SqlTypeName.SMALLINT, SqlTypeName.TINYINT) - .contains(ac.type.getSqlTypeName())) { - throw new UnsupportedOperationException( - "AVG only support for INT, LONG, FLOAT, DOUBLE, SMALLINT, TINYINT"); + return deltaAcc; + } + @Override + public AggregationAccumulator mergeAccumulators(Iterable accumulators) { + AggregationAccumulator deltaAcc = new AggregationAccumulator(); + for (int idx = 0; idx < aggregators.size(); ++idx) { + List accs = new 
ArrayList<>(); + Iterator ite = accumulators.iterator(); + while (ite.hasNext()) { + accs.add(ite.next().accumulatorElements.get(idx)); } - break; - default: - throw new UnsupportedOperationException( - String.format("[%s] is not supported.", aggFnName)); + deltaAcc.accumulatorElements.add(aggregators.get(idx).merge(accs)); } + return deltaAcc; } - @Override - public BeamSqlRow createAccumulator() { - BeamSqlRow initialRecord = new BeamSqlRow(aggDataType); - for (int idx = 0; idx < aggElementExpressions.size(); ++idx) { - BeamSqlExpression ex = aggElementExpressions.get(idx); - String aggFnName = aggFunctions.get(idx); - switch (aggFnName) { - case "COUNT": - initialRecord.addField(idx, 0L); - break; - case "AVG": - case "SUM": - //for both AVG/SUM, a summary value is hold at first. - switch (ex.getOutputType()) { - case INTEGER: - initialRecord.addField(idx, 0); - break; - case BIGINT: - initialRecord.addField(idx, 0L); - break; - case SMALLINT: - initialRecord.addField(idx, (short) 0); - break; - case TINYINT: - initialRecord.addField(idx, (byte) 0); - break; - case FLOAT: - initialRecord.addField(idx, 0.0f); - break; - case DOUBLE: - initialRecord.addField(idx, 0.0); - break; - default: - break; - } - break; - case "MAX": - switch (ex.getOutputType()) { - case INTEGER: - initialRecord.addField(idx, Integer.MIN_VALUE); - break; - case BIGINT: - initialRecord.addField(idx, Long.MIN_VALUE); - break; - case SMALLINT: - initialRecord.addField(idx, Short.MIN_VALUE); - break; - case TINYINT: - initialRecord.addField(idx, Byte.MIN_VALUE); - break; - case FLOAT: - initialRecord.addField(idx, Float.MIN_VALUE); - break; - case DOUBLE: - initialRecord.addField(idx, Double.MIN_VALUE); - break; - case TIMESTAMP: - initialRecord.addField(idx, new Date(0)); - break; - default: - break; - } - break; - case "MIN": - switch (ex.getOutputType()) { - case INTEGER: - initialRecord.addField(idx, Integer.MAX_VALUE); - break; - case BIGINT: - initialRecord.addField(idx, Long.MAX_VALUE); - break; - case SMALLINT: - initialRecord.addField(idx, Short.MAX_VALUE); - break; - case TINYINT: - initialRecord.addField(idx, Byte.MAX_VALUE); - break; - case FLOAT: - initialRecord.addField(idx, Float.MAX_VALUE); - break; - case DOUBLE: - initialRecord.addField(idx, Double.MAX_VALUE); - break; - case TIMESTAMP: - initialRecord.addField(idx, new Date(Long.MAX_VALUE)); - break; - default: - break; - } - break; - default: - break; - } + public BeamSqlRow extractOutput(AggregationAccumulator accumulator) { + BeamSqlRow result = new BeamSqlRow(finalRecordType); + for (int idx = 0; idx < aggregators.size(); ++idx) { + result.addField(idx, aggregators.get(idx).result(accumulator.accumulatorElements.get(idx))); } - return initialRecord; + return result; } - @Override - public BeamSqlRow addInput(BeamSqlRow accumulator, BeamSqlRow input) { - BeamSqlRow deltaRecord = new BeamSqlRow(aggDataType); - for (int idx = 0; idx < aggElementExpressions.size(); ++idx) { - BeamSqlExpression ex = aggElementExpressions.get(idx); - String aggFnName = aggFunctions.get(idx); - switch (aggFnName) { - case "COUNT": - deltaRecord.addField(idx, 1 + accumulator.getLong(idx)); - break; - case "AVG": - case "SUM": - // for both AVG/SUM, a summary value is hold at first. 
- switch (ex.getOutputType()) { - case INTEGER: - deltaRecord.addField(idx, - ex.evaluate(input).getInteger() + accumulator.getInteger(idx)); - break; - case BIGINT: - deltaRecord.addField(idx, ex.evaluate(input).getLong() + accumulator.getLong(idx)); - break; - case SMALLINT: - deltaRecord.addField(idx, - (short) (ex.evaluate(input).getShort() + accumulator.getShort(idx))); - break; - case TINYINT: - deltaRecord.addField(idx, - (byte) (ex.evaluate(input).getByte() + accumulator.getByte(idx))); - break; - case FLOAT: - deltaRecord.addField(idx, - (float) (ex.evaluate(input).getFloat() + accumulator.getFloat(idx))); - break; - case DOUBLE: - deltaRecord.addField(idx, ex.evaluate(input).getDouble() + accumulator.getDouble(idx)); - break; - default: - break; - } - break; - case "MAX": - switch (ex.getOutputType()) { - case INTEGER: - deltaRecord.addField(idx, - Math.max(ex.evaluate(input).getInteger(), accumulator.getInteger(idx))); - break; - case BIGINT: - deltaRecord.addField(idx, - Math.max(ex.evaluate(input).getLong(), accumulator.getLong(idx))); - break; - case SMALLINT: - deltaRecord.addField(idx, - (short) Math.max(ex.evaluate(input).getShort(), accumulator.getShort(idx))); - break; - case TINYINT: - deltaRecord.addField(idx, - (byte) Math.max(ex.evaluate(input).getByte(), accumulator.getByte(idx))); - break; - case FLOAT: - deltaRecord.addField(idx, - Math.max(ex.evaluate(input).getFloat(), accumulator.getFloat(idx))); - break; - case DOUBLE: - deltaRecord.addField(idx, - Math.max(ex.evaluate(input).getDouble(), accumulator.getDouble(idx))); - break; - case TIMESTAMP: - Date preDate = accumulator.getDate(idx); - Date nowDate = ex.evaluate(input).getDate(); - deltaRecord.addField(idx, preDate.getTime() > nowDate.getTime() ? preDate : nowDate); - break; - default: - break; - } - break; - case "MIN": - switch (ex.getOutputType()) { - case INTEGER: - deltaRecord.addField(idx, - Math.min(ex.evaluate(input).getInteger(), accumulator.getInteger(idx))); - break; - case BIGINT: - deltaRecord.addField(idx, - Math.min(ex.evaluate(input).getLong(), accumulator.getLong(idx))); - break; - case SMALLINT: - deltaRecord.addField(idx, - (short) Math.min(ex.evaluate(input).getShort(), accumulator.getShort(idx))); - break; - case TINYINT: - deltaRecord.addField(idx, - (byte) Math.min(ex.evaluate(input).getByte(), accumulator.getByte(idx))); - break; - case FLOAT: - deltaRecord.addField(idx, - Math.min(ex.evaluate(input).getFloat(), accumulator.getFloat(idx))); - break; - case DOUBLE: - deltaRecord.addField(idx, - Math.min(ex.evaluate(input).getDouble(), accumulator.getDouble(idx))); - break; - case TIMESTAMP: - Date preDate = accumulator.getDate(idx); - Date nowDate = ex.evaluate(input).getDate(); - deltaRecord.addField(idx, preDate.getTime() < nowDate.getTime() ? preDate : nowDate); - break; - default: - break; - } - break; - default: - break; - } + public Coder getAccumulatorCoder( + CoderRegistry registry, Coder inputCoder) + throws CannotProvideCoderException { + registry.registerCoderForClass(BigDecimal.class, BigDecimalCoder.of()); + List aggAccuCoderList = new ArrayList<>(); + for (BeamSqlUdaf udaf : aggregators) { + aggAccuCoderList.add(udaf.getAccumulatorCoder(registry)); } - return deltaRecord; + return new AggregationAccumulatorCoder(aggAccuCoderList); } + } - @Override - public BeamSqlRow mergeAccumulators(Iterable accumulators) { - BeamSqlRow deltaRecord = new BeamSqlRow(aggDataType); + /** + * A class to holder varied accumulator objects. 
+ */ + public static class AggregationAccumulator { + private List<Object> accumulatorElements = new ArrayList<>(); + } + + /** + * Coder for {@link AggregationAccumulator}. 
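+ * <p>The encoded form is the number of accumulator elements (as a VarInt), followed by each
+ * element encoded with the coder supplied for its position.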
+ */ + public static class AggregationAccumulatorCoder extends CustomCoder{ + private VarIntCoder sizeCoder = VarIntCoder.of(); + private List elementCoders; + + public AggregationAccumulatorCoder(List elementCoders) { + this.elementCoders = elementCoders; + } + + @Override + public void encode(AggregationAccumulator value, OutputStream outStream) + throws CoderException, IOException { + sizeCoder.encode(value.accumulatorElements.size(), outStream); + for (int idx = 0; idx < value.accumulatorElements.size(); ++idx) { + elementCoders.get(idx).encode(value.accumulatorElements.get(idx), outStream); } - return deltaRecord; } @Override - public BeamSqlRow extractOutput(BeamSqlRow accumulator) { - BeamSqlRow finalRecord = new BeamSqlRow(aggDataType); - for (int idx = 0; idx < aggElementExpressions.size(); ++idx) { - BeamSqlExpression ex = aggElementExpressions.get(idx); - String aggFnName = aggFunctions.get(idx); - switch (aggFnName) { - case "COUNT": - finalRecord.addField(idx, accumulator.getLong(idx)); - break; - case "AVG": - long count = accumulator.getLong(countIndex); - switch (ex.getOutputType()) { - case INTEGER: - finalRecord.addField(idx, (int) (accumulator.getInteger(idx) / count)); - break; - case BIGINT: - finalRecord.addField(idx, accumulator.getLong(idx) / count); - break; - case SMALLINT: - finalRecord.addField(idx, (short) (accumulator.getShort(idx) / count)); - break; - case TINYINT: - finalRecord.addField(idx, (byte) (accumulator.getByte(idx) / count)); - break; - case FLOAT: - finalRecord.addField(idx, (float) (accumulator.getFloat(idx) / count)); - break; - case DOUBLE: - finalRecord.addField(idx, accumulator.getDouble(idx) / count); - break; - default: - break; - } - break; - case "SUM": - switch (ex.getOutputType()) { - case INTEGER: - finalRecord.addField(idx, accumulator.getInteger(idx)); - break; - case BIGINT: - finalRecord.addField(idx, accumulator.getLong(idx)); - break; - case SMALLINT: - finalRecord.addField(idx, accumulator.getShort(idx)); - break; - case TINYINT: - finalRecord.addField(idx, accumulator.getByte(idx)); - break; - case FLOAT: - finalRecord.addField(idx, accumulator.getFloat(idx)); - break; - case DOUBLE: - finalRecord.addField(idx, accumulator.getDouble(idx)); - break; - default: - break; - } - break; - case "MAX": - case "MIN": - switch (ex.getOutputType()) { - case INTEGER: - finalRecord.addField(idx, accumulator.getInteger(idx)); - break; - case BIGINT: - finalRecord.addField(idx, accumulator.getLong(idx)); - break; - case SMALLINT: - finalRecord.addField(idx, accumulator.getShort(idx)); - break; - case TINYINT: - finalRecord.addField(idx, accumulator.getByte(idx)); - break; - case FLOAT: - finalRecord.addField(idx, accumulator.getFloat(idx)); - break; - case DOUBLE: - finalRecord.addField(idx, accumulator.getDouble(idx)); - break; - case TIMESTAMP: - finalRecord.addField(idx, accumulator.getDate(idx)); - break; - default: - break; - } - break; - default: - break; - } + public AggregationAccumulator decode(InputStream inStream) throws CoderException, IOException { + AggregationAccumulator accu = new AggregationAccumulator(); + int size = sizeCoder.decode(inStream); + for (int idx = 0; idx < size; ++idx) { + accu.accumulatorElements.add(elementCoders.get(idx).decode(inStream)); } - return finalRecord; + return accu; } } } diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamBuiltinAggregations.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamBuiltinAggregations.java new file mode 100644 index 
0000000000000..fab26667e2e9f --- /dev/null +++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamBuiltinAggregations.java @@ -0,0 +1,412 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.dsls.sql.transform; + +import java.math.BigDecimal; +import java.util.Date; +import java.util.Iterator; +import org.apache.beam.dsls.sql.schema.BeamSqlUdaf; +import org.apache.beam.sdk.coders.BigDecimalCoder; +import org.apache.beam.sdk.coders.ByteCoder; +import org.apache.beam.sdk.coders.CannotProvideCoderException; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.coders.DoubleCoder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.SerializableCoder; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.values.KV; +import org.apache.calcite.sql.type.SqlTypeName; + +/** + * Built-in aggregations functions for COUNT/MAX/MIN/SUM/AVG. + */ +class BeamBuiltinAggregations { + /** + * Built-in aggregation for COUNT. + */ + public static final class Count extends BeamSqlUdaf { + public Count() {} + + @Override + public Long init() { + return 0L; + } + + @Override + public Long add(Long accumulator, T input) { + return accumulator + 1; + } + + @Override + public Long merge(Iterable accumulators) { + long v = 0L; + Iterator ite = accumulators.iterator(); + while (ite.hasNext()) { + v += ite.next(); + } + return v; + } + + @Override + public Long result(Long accumulator) { + return accumulator; + } + } + + /** + * Built-in aggregation for MAX. + */ + public static final class Max> extends BeamSqlUdaf { + public static Max create(SqlTypeName fieldType) { + switch (fieldType) { + case INTEGER: + return new BeamBuiltinAggregations.Max(fieldType); + case SMALLINT: + return new BeamBuiltinAggregations.Max(fieldType); + case TINYINT: + return new BeamBuiltinAggregations.Max(fieldType); + case BIGINT: + return new BeamBuiltinAggregations.Max(fieldType); + case FLOAT: + return new BeamBuiltinAggregations.Max(fieldType); + case DOUBLE: + return new BeamBuiltinAggregations.Max(fieldType); + case TIMESTAMP: + return new BeamBuiltinAggregations.Max(fieldType); + case DECIMAL: + return new BeamBuiltinAggregations.Max(fieldType); + default: + throw new UnsupportedOperationException( + String.format("[%s] is not support in MAX", fieldType)); + } + } + + private final SqlTypeName fieldType; + private Max(SqlTypeName fieldType) { + this.fieldType = fieldType; + } + + @Override + public T init() { + return null; + } + + @Override + public T add(T accumulator, T input) { + return (accumulator == null || accumulator.compareTo(input) < 0) ? 
input : accumulator; + } + + @Override + public T merge(Iterable accumulators) { + Iterator ite = accumulators.iterator(); + T mergedV = ite.next(); + while (ite.hasNext()) { + T v = ite.next(); + mergedV = mergedV.compareTo(v) > 0 ? mergedV : v; + } + return mergedV; + } + + @Override + public T result(T accumulator) { + return accumulator; + } + + @Override + public Coder getAccumulatorCoder(CoderRegistry registry) throws CannotProvideCoderException { + return BeamBuiltinAggregations.getSqlTypeCoder(fieldType); + } + } + + /** + * Built-in aggregation for MIN. + */ + public static final class Min> extends BeamSqlUdaf { + public static Min create(SqlTypeName fieldType) { + switch (fieldType) { + case INTEGER: + return new BeamBuiltinAggregations.Min(fieldType); + case SMALLINT: + return new BeamBuiltinAggregations.Min(fieldType); + case TINYINT: + return new BeamBuiltinAggregations.Min(fieldType); + case BIGINT: + return new BeamBuiltinAggregations.Min(fieldType); + case FLOAT: + return new BeamBuiltinAggregations.Min(fieldType); + case DOUBLE: + return new BeamBuiltinAggregations.Min(fieldType); + case TIMESTAMP: + return new BeamBuiltinAggregations.Min(fieldType); + case DECIMAL: + return new BeamBuiltinAggregations.Min(fieldType); + default: + throw new UnsupportedOperationException( + String.format("[%s] is not support in MIN", fieldType)); + } + } + + private final SqlTypeName fieldType; + private Min(SqlTypeName fieldType) { + this.fieldType = fieldType; + } + + @Override + public T init() { + return null; + } + + @Override + public T add(T accumulator, T input) { + return (accumulator == null || accumulator.compareTo(input) > 0) ? input : accumulator; + } + + @Override + public T merge(Iterable accumulators) { + Iterator ite = accumulators.iterator(); + T mergedV = ite.next(); + while (ite.hasNext()) { + T v = ite.next(); + mergedV = mergedV.compareTo(v) < 0 ? mergedV : v; + } + return mergedV; + } + + @Override + public T result(T accumulator) { + return accumulator; + } + + @Override + public Coder getAccumulatorCoder(CoderRegistry registry) throws CannotProvideCoderException { + return BeamBuiltinAggregations.getSqlTypeCoder(fieldType); + } + } + + /** + * Built-in aggregation for SUM. 
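+ * <p>Sums into a {@link BigDecimal} accumulator whatever the declared field type, then
+ * narrows the total back to that type in {@link #result}.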
+ */ + public static final class Sum extends BeamSqlUdaf { + public static Sum create(SqlTypeName fieldType) { + switch (fieldType) { + case INTEGER: + return new BeamBuiltinAggregations.Sum(fieldType); + case SMALLINT: + return new BeamBuiltinAggregations.Sum(fieldType); + case TINYINT: + return new BeamBuiltinAggregations.Sum(fieldType); + case BIGINT: + return new BeamBuiltinAggregations.Sum(fieldType); + case FLOAT: + return new BeamBuiltinAggregations.Sum(fieldType); + case DOUBLE: + return new BeamBuiltinAggregations.Sum(fieldType); + case TIMESTAMP: + return new BeamBuiltinAggregations.Sum(fieldType); + case DECIMAL: + return new BeamBuiltinAggregations.Sum(fieldType); + default: + throw new UnsupportedOperationException( + String.format("[%s] is not support in SUM", fieldType)); + } + } + + private SqlTypeName fieldType; + private Sum(SqlTypeName fieldType) { + this.fieldType = fieldType; + } + + @Override + public BigDecimal init() { + return new BigDecimal(0); + } + + @Override + public BigDecimal add(BigDecimal accumulator, T input) { + return accumulator.add(new BigDecimal(input.toString())); + } + + @Override + public BigDecimal merge(Iterable accumulators) { + BigDecimal v = new BigDecimal(0); + Iterator ite = accumulators.iterator(); + while (ite.hasNext()) { + v = v.add(ite.next()); + } + return v; + } + + @Override + public T result(BigDecimal accumulator) { + Object result = null; + switch (fieldType) { + case INTEGER: + result = accumulator.intValue(); + break; + case BIGINT: + result = accumulator.longValue(); + break; + case SMALLINT: + result = accumulator.shortValue(); + break; + case TINYINT: + result = accumulator.byteValue(); + break; + case DOUBLE: + result = accumulator.doubleValue(); + break; + case FLOAT: + result = accumulator.floatValue(); + break; + case DECIMAL: + result = accumulator; + break; + default: + break; + } + return (T) result; + } + } + + /** + * Built-in aggregation for AVG. 
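+ * <p>Tracks a running sum and count in a {@code KV<BigDecimal, Long>} accumulator;
+ * {@link #result} divides the sum by the count and narrows the quotient to the declared
+ * field type.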
+ */ + public static final class Avg extends BeamSqlUdaf, T> { + public static Avg create(SqlTypeName fieldType) { + switch (fieldType) { + case INTEGER: + return new BeamBuiltinAggregations.Avg(fieldType); + case SMALLINT: + return new BeamBuiltinAggregations.Avg(fieldType); + case TINYINT: + return new BeamBuiltinAggregations.Avg(fieldType); + case BIGINT: + return new BeamBuiltinAggregations.Avg(fieldType); + case FLOAT: + return new BeamBuiltinAggregations.Avg(fieldType); + case DOUBLE: + return new BeamBuiltinAggregations.Avg(fieldType); + case TIMESTAMP: + return new BeamBuiltinAggregations.Avg(fieldType); + case DECIMAL: + return new BeamBuiltinAggregations.Avg(fieldType); + default: + throw new UnsupportedOperationException( + String.format("[%s] is not support in AVG", fieldType)); + } + } + + private SqlTypeName fieldType; + private Avg(SqlTypeName fieldType) { + this.fieldType = fieldType; + } + + @Override + public KV init() { + return KV.of(new BigDecimal(0), 0L); + } + + @Override + public KV add(KV accumulator, T input) { + return KV.of( + accumulator.getKey().add(new BigDecimal(input.toString())), + accumulator.getValue() + 1); + } + + @Override + public KV merge(Iterable> accumulators) { + BigDecimal v = new BigDecimal(0); + long s = 0; + Iterator> ite = accumulators.iterator(); + while (ite.hasNext()) { + KV r = ite.next(); + v = v.add(r.getKey()); + s += r.getValue(); + } + return KV.of(v, s); + } + + @Override + public T result(KV accumulator) { + BigDecimal decimalAvg = accumulator.getKey().divide( + new BigDecimal(accumulator.getValue())); + Object result = null; + switch (fieldType) { + case INTEGER: + result = decimalAvg.intValue(); + break; + case BIGINT: + result = decimalAvg.longValue(); + break; + case SMALLINT: + result = decimalAvg.shortValue(); + break; + case TINYINT: + result = decimalAvg.byteValue(); + break; + case DOUBLE: + result = decimalAvg.doubleValue(); + break; + case FLOAT: + result = decimalAvg.floatValue(); + break; + case DECIMAL: + result = decimalAvg; + break; + default: + break; + } + return (T) result; + } + + @Override + public Coder> getAccumulatorCoder(CoderRegistry registry) + throws CannotProvideCoderException { + return KvCoder.of(BigDecimalCoder.of(), VarLongCoder.of()); + } + } + + /** + * Find {@link Coder} for Beam SQL field types. + */ + private static Coder getSqlTypeCoder(SqlTypeName sqlType) { + switch (sqlType) { + case INTEGER: + return VarIntCoder.of(); + case SMALLINT: + return SerializableCoder.of(Short.class); + case TINYINT: + return ByteCoder.of(); + case BIGINT: + return VarLongCoder.of(); + case FLOAT: + return SerializableCoder.of(Float.class); + case DOUBLE: + return DoubleCoder.of(); + case TIMESTAMP: + return SerializableCoder.of(Date.class); + case DECIMAL: + return BigDecimalCoder.of(); + default: + throw new UnsupportedOperationException( + String.format("Cannot find a Coder for data type [%s]", sqlType)); + } + } +} diff --git a/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamAggregationTransformTest.java b/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamAggregationTransformTest.java index 388a34485ab34..2b01254d041f2 100644 --- a/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamAggregationTransformTest.java +++ b/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamAggregationTransformTest.java @@ -117,7 +117,7 @@ public void testCountPerElementBasic() throws ParseException { //3. 
run aggregation functions PCollection> aggregatedStream = groupedStream.apply("aggregation", Combine.groupedValues( - new BeamAggregationTransforms.AggregationCombineFn(aggCalls, inputRowType))) + new BeamAggregationTransforms.AggregationAdaptor(aggCalls, inputRowType))) .setCoder(KvCoder.of(keyCoder, aggCoder)); //4. flat KV to a single record