From ae5eed6ac64429d7e0802d05021311b29f29d1b0 Mon Sep 17 00:00:00 2001 From: Anton Mushin Date: Fri, 3 Feb 2017 14:06:49 +0400 Subject: [PATCH] [FLINK-4604] Add support for standard deviation/variance add rule for reduce standard deviation/variance functions --- docs/dev/table_api.md | 47 +++- .../flink/table/api/scala/expressionDsl.scala | 28 +++ .../table/expressions/ExpressionParser.scala | 46 +++- .../table/expressions/aggregations.scala | 65 ++++++ .../aggfunctions/Sum0AggFunction.scala | 91 ++++++++ .../table/plan/rules/FlinkRuleSets.scala | 3 + .../rules/dataSet/DataSetAggregateRule.scala | 9 +- .../DataSetAggregateWithNullValuesRule.scala | 8 +- .../table/validate/FunctionCatalog.scala | 10 + .../table/api/java/batch/sql/SqlITCase.java | 49 ++++- .../scala/batch/sql/AggregationsITCase.scala | 200 ++++++++++++++++++ .../batch/table/AggregationsITCase.scala | 76 +++++++ .../batch/utils/TableProgramsTestBase.scala | 5 + 13 files changed, 627 insertions(+), 10 deletions(-) create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/Sum0AggFunction.scala diff --git a/docs/dev/table_api.md b/docs/dev/table_api.md index 2a838c7f1a42d..5f82c5815c8f0 100644 --- a/docs/dev/table_api.md +++ b/docs/dev/table_api.md @@ -1005,7 +1005,7 @@ dataType = "BYTE" | "SHORT" | "INT" | "LONG" | "FLOAT" | "DOUBLE" | "BOOLEAN" | as = composite , ".as(" , fieldReference , ")" ; -aggregation = composite , ( ".sum" | ".min" | ".max" | ".count" | ".avg" | ".start" | ".end" ) , [ "()" ] ; +aggregation = composite , ( ".sum" | ".min" | ".max" | ".count" | ".avg" | ".start" | ".end" | ".stddev_pop" | ".stddev_samp" | ".var_pop" | ".var_samp" ) , [ "()" ] ; if = composite , ".?(" , expression , "," , expression , ")" ; @@ -4772,7 +4772,7 @@ AVG(numeric)

Returns the average (arithmetic mean) of numeric across all input values.

- + {% highlight text %} @@ -4915,6 +4915,49 @@ ELEMENT(ARRAY)

Returns the sole element of an array with a single element. Returns null if the array is empty. Throws an exception if the array has more than one element.

+ + + {% highlight text %} +STDDEV_POP(value) +{% endhighlight %} + + +

Returns the population standard deviation of numeric across all input values.

+ + + + + + {% highlight text %} +STDDEV_SAMP(value) +{% endhighlight %} + + +

Returns the sample standard deviation of numeric across all input values.

+ + + + + + {% highlight text %} +VAR_POP(value) +{% endhighlight %} + + +

Returns the population variance (square of the population standard deviation) of numeric across all input values.

+ + + + + + {% highlight text %} +VAR_SAMP(value) +{% endhighlight %} + + +

Returns the sample variance (square of the sample standard deviation) of numeric across all input values.

+ + diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala index 06d46e3cead4c..e0e4a2eca6283 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala @@ -163,6 +163,12 @@ trait ImplicitExpressionOperations { */ def sum = Sum(expr) + /** + * Returns the sum of the values which go into it like [[Sum]]. + * It differs in that when no non null values are applied zero is returned instead of null. + */ + def sum0 = Sum0(expr) + /** * Returns the minimum value of field across all input values. */ @@ -183,6 +189,28 @@ trait ImplicitExpressionOperations { */ def avg = Avg(expr) + /** + * Returns the population standard deviation of an expression. + * (the square root of [[VarPop]]) + */ + def stddev_pop = StddevPop(expr) + + /** + * Returns the sample standard deviation of an expression. + * (the square root of [[VarSamp]]). + */ + def stddev_samp = StddevSamp(expr) + + /** + * Returns the population standard variance of an expression. + */ + def var_pop = VarPop(expr) + + /** + * Returns the sample variance of a given expression. + */ + def var_samp = VarSamp(expr) + /** * Converts a value to a given type. 
* diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala index ed0b16ea0d6f3..1563bebb7beb4 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala @@ -57,6 +57,11 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { lazy val SUM: Keyword = Keyword("sum") lazy val START: Keyword = Keyword("start") lazy val END: Keyword = Keyword("end") + lazy val SUM0: Keyword = Keyword("sum0") + lazy val STDDEV_POP: Keyword = Keyword("stddev_pop") + lazy val STDDEV_SAMP: Keyword = Keyword("stddev_samp") + lazy val VAR_POP: Keyword = Keyword("var_pop") + lazy val VAR_SAMP: Keyword = Keyword("var_samp") lazy val CAST: Keyword = Keyword("cast") lazy val NULL: Keyword = Keyword("Null") lazy val IF: Keyword = Keyword("?") @@ -89,8 +94,9 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { lazy val FLATTEN: Keyword = Keyword("flatten") def functionIdent: ExpressionParser.Parser[String] = - not(ARRAY) ~ not(AS) ~ not(COUNT) ~ not(AVG) ~ not(MIN) ~ not(MAX) ~ - not(SUM) ~ not(START) ~ not(END)~ not(CAST) ~ not(NULL) ~ + not(ARRAY) ~ not(AS) ~ not(COUNT) ~ not(AVG) ~ not(MIN) ~ not(MAX) ~ not(STDDEV_POP) ~ + not(STDDEV_SAMP) ~ not(VAR_SAMP) ~ not(VAR_POP) ~ + not(SUM) ~ not(START) ~ not(END)~ not(SUM0) ~ not(CAST) ~ not(NULL) ~ not(IF) ~> super.ident // symbols @@ -181,6 +187,9 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { lazy val suffixSum: PackratParser[Expression] = composite <~ "." ~ SUM ~ opt("()") ^^ { e => Sum(e) } + lazy val suffixSum0: PackratParser[Expression] = + composite <~ "." ~ SUM0 ~ opt("()") ^^ { e => Sum0(e) } + lazy val suffixMin: PackratParser[Expression] = composite <~ "." 
~ MIN ~ opt("()") ^^ { e => Min(e) } @@ -199,6 +208,17 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { lazy val suffixEnd: PackratParser[Expression] = composite <~ "." ~ END ~ opt("()") ^^ { e => WindowEnd(e) } + lazy val suffixStddevPop: PackratParser[Expression] = + composite <~ "." ~ STDDEV_POP ~ opt("()") ^^ { e => StddevPop(e) } + + lazy val suffixStddevSamp: PackratParser[Expression] = + composite <~ "." ~ STDDEV_SAMP ~ opt("()") ^^ { e => StddevSamp(e) } + + lazy val suffixVarSamp: PackratParser[Expression] = + composite <~ "." ~ VAR_SAMP ~ opt("()") ^^ { e => VarSamp(e) } + + lazy val suffixVarPop: PackratParser[Expression] = + composite <~ "." ~ VAR_POP ~ opt("()") ^^ { e => VarPop(e) } lazy val suffixCast: PackratParser[Expression] = composite ~ "." ~ CAST ~ "(" ~ dataType ~ ")" ^^ { case e ~ _ ~ _ ~ _ ~ dt ~ _ => Cast(e, dt) @@ -290,8 +310,9 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { composite <~ "." ~ FLATTEN ~ opt("()") ^^ { e => Flattening(e) } lazy val suffixed: PackratParser[Expression] = - suffixTimeInterval | suffixRowInterval | suffixSum | suffixMin | suffixMax | suffixStart | - suffixEnd | suffixCount | suffixAvg | suffixCast | suffixAs | suffixTrim | + suffixTimeInterval | suffixRowInterval | suffixSum | suffixSum0 | suffixMin | suffixMax | + suffixStart | suffixEnd | suffixCount | suffixAvg | suffixCast | suffixAs | suffixTrim | + suffixStddevPop | suffixStddevSamp | suffixVarPop | suffixVarSamp | suffixTrimWithoutArgs | suffixIf | suffixAsc | suffixDesc | suffixToDate | suffixToTimestamp | suffixToTime | suffixExtract | suffixFloor | suffixCeil | suffixGet | suffixFlattening | @@ -305,6 +326,9 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { lazy val prefixSum: PackratParser[Expression] = SUM ~ "(" ~> expression <~ ")" ^^ { e => Sum(e) } + lazy val prefixSum0: PackratParser[Expression] = + SUM0 ~ "(" ~> expression <~ ")" ^^ { e => Sum0(e) } + lazy val prefixMin: 
PackratParser[Expression] = MIN ~ "(" ~> expression <~ ")" ^^ { e => Min(e) } @@ -323,6 +347,17 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { lazy val prefixEnd: PackratParser[Expression] = END ~ "(" ~> expression <~ ")" ^^ { e => WindowEnd(e) } + lazy val prefixStddevPop: PackratParser[Expression] = + STDDEV_POP ~ "(" ~> expression <~ ")" ^^ { e => StddevPop(e) } + + lazy val prefixStddevSamp: PackratParser[Expression] = + STDDEV_SAMP ~ "(" ~> expression <~ ")" ^^ { e => StddevSamp(e) } + + lazy val prefixVarSamp: PackratParser[Expression] = + VAR_SAMP ~ "(" ~> expression <~ ")" ^^ { e => VarSamp(e) } + + lazy val prefixVarPop: PackratParser[Expression] = + VAR_POP ~ "(" ~> expression <~ ")" ^^ { e => VarPop(e) } lazy val prefixCast: PackratParser[Expression] = CAST ~ "(" ~ expression ~ "," ~ dataType ~ ")" ^^ { case _ ~ _ ~ e ~ _ ~ dt ~ _ => Cast(e, dt) @@ -376,7 +411,8 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { FLATTEN ~ "(" ~> composite <~ ")" ^^ { e => Flattening(e) } lazy val prefixed: PackratParser[Expression] = - prefixArray | prefixSum | prefixMin | prefixMax | prefixCount | prefixAvg | + prefixArray | prefixSum | prefixSum0 | prefixMin | prefixMax | prefixCount | prefixAvg | + prefixStddevPop | prefixStddevSamp | prefixVarSamp | prefixVarPop | prefixStart | prefixEnd | prefixCast | prefixAs | prefixTrim | prefixTrimWithoutArgs | prefixIf | prefixExtract | prefixFloor | prefixCeil | prefixGet | prefixFlattening | prefixFunctionCall | prefixFunctionCallOneArg // function call must always be at the end diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala index b2fca883901e8..8ea2fe89f4cdf 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala +++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala @@ -50,6 +50,19 @@ case class Sum(child: Expression) extends Aggregation { TypeCheckUtils.assertNumericExpr(child.resultType, "sum") } +case class Sum0(child: Expression) extends Aggregation { + override def toString = s"sum0($child)" + + override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { + relBuilder.aggregateCall(SqlStdOperatorTable.SUM0, false, null, name, child.toRexNode) + } + + override private[flink] def resultType = child.resultType + + override private[flink] def validateInput = + TypeCheckUtils.assertNumericExpr(child.resultType, "sum0") +} + case class Min(child: Expression) extends Aggregation { override def toString = s"min($child)" @@ -98,3 +111,55 @@ case class Avg(child: Expression) extends Aggregation { override private[flink] def validateInput() = TypeCheckUtils.assertNumericExpr(child.resultType, "avg") } + +case class StddevPop(child: Expression) extends Aggregation { + override def toString = s"stddev_pop($child)" + + override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { + relBuilder.aggregateCall(SqlStdOperatorTable.STDDEV_POP, false, null, name, child.toRexNode) + } + + override private[flink] def resultType = child.resultType + + override private[flink] def validateInput = + TypeCheckUtils.assertNumericExpr(child.resultType, "stddev_pop") +} + +case class StddevSamp(child: Expression) extends Aggregation { + override def toString = s"stddev_samp($child)" + + override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { + relBuilder.aggregateCall(SqlStdOperatorTable.STDDEV_SAMP, false, null, name, child.toRexNode) + } + + override private[flink] def resultType = child.resultType + + override private[flink] def validateInput = + TypeCheckUtils.assertNumericExpr(child.resultType, "stddev_samp") +} + +case class 
VarPop(child: Expression) extends Aggregation { + override def toString = s"var_pop($child)" + + override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { + relBuilder.aggregateCall(SqlStdOperatorTable.VAR_POP, false, null, name, child.toRexNode) + } + + override private[flink] def resultType = child.resultType + + override private[flink] def validateInput = + TypeCheckUtils.assertNumericExpr(child.resultType, "var_pop") +} + +case class VarSamp(child: Expression) extends Aggregation { + override def toString = s"var_samp($child)" + + override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { + relBuilder.aggregateCall(SqlStdOperatorTable.VAR_SAMP, false, null, name, child.toRexNode) + } + + override private[flink] def resultType = child.resultType + + override private[flink] def validateInput = + TypeCheckUtils.assertNumericExpr(child.resultType, "var_samp") +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/Sum0AggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/Sum0AggFunction.scala new file mode 100644 index 0000000000000..6a24fbe821a75 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/Sum0AggFunction.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.functions.aggfunctions + +import java.math.BigDecimal + +import org.apache.flink.api.common.typeinfo.BasicTypeInfo +import org.apache.flink.table.functions.Accumulator + +abstract class Sum0AggFunction[T: Numeric] extends SumAggFunction[T] { + + override def getValue(accumulator: Accumulator): T = { + val a = accumulator.asInstanceOf[SumAccumulator[T]] + if (a.f1) { + a.f0 + } else { + 0.asInstanceOf[T] + } + } +} + +/** + * Built-in Byte Sum0 aggregate function + */ +class ByteSum0AggFunction extends Sum0AggFunction[Byte] { + override def getValueTypeInfo = BasicTypeInfo.BYTE_TYPE_INFO +} + +/** + * Built-in Short Sum0 aggregate function + */ +class ShortSum0AggFunction extends Sum0AggFunction[Short] { + override def getValueTypeInfo = BasicTypeInfo.SHORT_TYPE_INFO +} + +/** + * Built-in Int Sum0 aggregate function + */ +class IntSum0AggFunction extends Sum0AggFunction[Int] { + override def getValueTypeInfo = BasicTypeInfo.INT_TYPE_INFO +} + +/** + * Built-in Long Sum0 aggregate function + */ +class LongSum0AggFunction extends Sum0AggFunction[Long] { + override def getValueTypeInfo = BasicTypeInfo.LONG_TYPE_INFO +} + +/** + * Built-in Float Sum0 aggregate function + */ +class FloatSum0AggFunction extends Sum0AggFunction[Float] { + override def getValueTypeInfo = BasicTypeInfo.FLOAT_TYPE_INFO +} + +/** + * Built-in Double Sum0 aggregate function + */ +class DoubleSum0AggFunction extends Sum0AggFunction[Double] { + override def getValueTypeInfo = BasicTypeInfo.DOUBLE_TYPE_INFO +} + +/** + * Built-in Big Decimal Sum0 
aggregate function + */ +class DecimalSum0AggFunction extends DecimalSumAggFunction { + + override def getValue(accumulator: Accumulator): BigDecimal = { + if (!accumulator.asInstanceOf[DecimalSumAccumulator].f1) { + 0.asInstanceOf[BigDecimal] + } else { + accumulator.asInstanceOf[DecimalSumAccumulator].f0 + } + } +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala index 41f095f032afb..fe80a26cea3b4 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala @@ -89,6 +89,9 @@ object FlinkRuleSets { // expand distinct aggregate to normal aggregate with groupby AggregateExpandDistinctAggregatesRule.JOIN, + //aggregate reduce rule (deviation/variance functions) + AggregateReduceFunctionsRule.INSTANCE, + // remove unnecessary sort rule SortRemoveRule.INSTANCE, diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala index 98d1c13a412d0..d57b057670373 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala @@ -22,6 +22,8 @@ import org.apache.calcite.plan.{Convention, RelOptRule, RelOptRuleCall, RelTrait import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.convert.ConverterRule import org.apache.calcite.rel.logical.LogicalAggregate +import org.apache.calcite.sql.SqlKind +import org.apache.flink.table.api.TableException import org.apache.flink.table.plan.nodes.dataset.{DataSetAggregate, 
DataSetConvention, DataSetUnion} import scala.collection.JavaConversions._ @@ -52,7 +54,12 @@ class DataSetAggregateRule // check if we have distinct aggregates val distinctAggs = agg.getAggCallList.exists(_.isDistinct) - !distinctAggs + val supported = agg.getAggCallList.map(_.getAggregation.getKind).forall { + case SqlKind.STDDEV_POP | SqlKind.STDDEV_SAMP | SqlKind.VAR_POP | SqlKind.VAR_SAMP => false + case _ => true + } + + !distinctAggs && supported } override def convert(rel: RelNode): RelNode = { diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala index aa977b1c022be..b67173282e17c 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala @@ -18,6 +18,7 @@ package org.apache.flink.table.plan.rules.dataSet import org.apache.calcite.plan._ +import org.apache.calcite.sql.SqlKind import scala.collection.JavaConversions._ import com.google.common.collect.ImmutableList @@ -51,7 +52,12 @@ class DataSetAggregateWithNullValuesRule // check if we have distinct aggregates val distinctAggs = agg.getAggCallList.exists(_.isDistinct) - !distinctAggs + val supported = agg.getAggCallList.map(_.getAggregation.getKind).forall { + case SqlKind.STDDEV_POP | SqlKind.STDDEV_SAMP | SqlKind.VAR_POP | SqlKind.VAR_SAMP => false + case _ => true + } + + !distinctAggs && supported } override def convert(rel: RelNode): RelNode = { diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala index 
74b371aedd9fd..519d2c2361bfe 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala @@ -169,6 +169,11 @@ object FunctionCatalog { "max" -> classOf[Max], "min" -> classOf[Min], "sum" -> classOf[Sum], + "sum0" -> classOf[Sum0], + "stddev_pop" -> classOf[StddevPop], + "stddev_samp" -> classOf[StddevSamp], + "var_pop" -> classOf[VarPop], + "var_samp" -> classOf[VarSamp], // string functions "charLength" -> classOf[CharLength], @@ -293,10 +298,15 @@ class BasicOperatorTable extends ReflectiveSqlOperatorTable { SqlStdOperatorTable.GROUPING_ID, // AGGREGATE OPERATORS SqlStdOperatorTable.SUM, + SqlStdOperatorTable.SUM0, SqlStdOperatorTable.COUNT, SqlStdOperatorTable.MIN, SqlStdOperatorTable.MAX, SqlStdOperatorTable.AVG, + SqlStdOperatorTable.STDDEV_POP, + SqlStdOperatorTable.STDDEV_SAMP, + SqlStdOperatorTable.VAR_POP, + SqlStdOperatorTable.VAR_SAMP, // ARRAY OPERATORS SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR, SqlStdOperatorTable.ITEM, diff --git a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/batch/sql/SqlITCase.java b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/batch/sql/SqlITCase.java index 5ba67dd355372..4c9587645d09b 100644 --- a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/batch/sql/SqlITCase.java +++ b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/batch/sql/SqlITCase.java @@ -18,6 +18,7 @@ package org.apache.flink.table.api.java.batch.sql; +import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.table.api.java.BatchTableEnvironment; @@ -31,7 +32,10 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; +import 
scala.collection.JavaConversions; +import scala.collection.mutable.Buffer; +import java.util.Arrays; import java.util.List; @RunWith(Parameterized.class) @@ -128,7 +132,7 @@ public void testJoin() throws Exception { DataSet> ds2 = CollectionDataSets.get5TupleDataSet(env); tableEnv.registerDataSet("t1", ds1, "a, b, c"); - tableEnv.registerDataSet("t2",ds2, "d, e, f, g, h"); + tableEnv.registerDataSet("t2", ds2, "d, e, f, g, h"); String sqlQuery = "SELECT c, g FROM t1, t2 WHERE b = e"; Table result = tableEnv.sql(sqlQuery); @@ -138,4 +142,47 @@ public void testJoin() throws Exception { String expected = "Hi,Hallo\n" + "Hello,Hallo Welt\n" + "Hello world,Hallo Welt\n"; compareResultAsText(results, expected); } + + @Test + public void testDeviationAggregation() throws Exception { + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env, config()); + + DataSet> ds = CollectionDataSets.get3TupleDataSet(env); + tableEnv.registerDataSet("AggTable", ds, "x, y, z"); + + Buffer columnForAgg = JavaConversions.asScalaBuffer(Arrays.asList("x, y".split(","))); + + String sqlQuery = getSelectQueryFromTemplate("AVG(?),STDDEV_POP(?),STDDEV_SAMP(?),VAR_POP(?),VAR_SAMP(?)", columnForAgg, "AggTable"); + Table result = tableEnv.sql(sqlQuery); + + String sqlQuery1 = getSelectQueryFromTemplate("SUM(?)/COUNT(?), " + + "SQRT( (SUM(? * ?) - SUM(?) * SUM(?) / COUNT(?)) / COUNT(?)), " + + "SQRT( (SUM(? * ?) - SUM(?) * SUM(?) / COUNT(?)) / CASE COUNT(?) WHEN 1 THEN NULL ELSE COUNT(?) - 1 END), " + + "(SUM(? * ?) - SUM(?) * SUM(?) / COUNT(?)) / COUNT(?), " + + "(SUM(? * ?) - SUM(?) * SUM(?) / COUNT(?)) / CASE COUNT(?) WHEN 1 THEN NULL ELSE COUNT(?) 
- 1 END", columnForAgg, "AggTable"); + + Table expected = tableEnv.sql(sqlQuery1); + + DataSet resultSet = tableEnv.toDataSet(result, Row.class); + List results = resultSet.collect(); + + DataSet expectedResultSet = tableEnv.toDataSet(expected, Row.class); + String expectedResults = expectedResultSet.map(new MapFunction() { + @Override + public Object map(Row value) throws Exception { + StringBuilder stringBuffer = new StringBuilder(); + + int arityCount = value.getArity(); + + for (int i = 0; i < arityCount; i++) { + Object product = value.getField(i); + stringBuffer.append(Double.valueOf(product.toString()).intValue()).append(","); + } + return stringBuffer.substring(0, stringBuffer.length() - 1); + } + }).collect().get(0).toString(); + + compareResultAsText(results, expectedResults); + } } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsITCase.scala index 600c15be7ba5c..4fbd73475c5bf 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsITCase.scala @@ -295,6 +295,206 @@ class AggregationsITCase( TestBaseUtils.compareResultAsText(results3.asJava, expected3) } + + + @Test + def testSqrtOfAggregatedSet(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds = env.fromElements((1.0f, 1), (2.0f, 2)).toTable(tEnv) + + tEnv.registerTable("MyTable", ds) + + val sqlQuery = "SELECT " + + "SQRT((SUM(a * a) - SUM(a) * SUM(a) / COUNT(a)) / COUNT(a)) " + + "from (select _1 as a from MyTable)" + + val expected = "0.5" + val results = tEnv.sql(sqlQuery).toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + 
@Test + def testStddevPopAggregate(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds = env.fromElements( + (1: Byte, 1 : Short, 1, 1L, 1F, 1D), + (2: Byte, 2 : Short, 2, 2L, 2F, 2D)).toTable(tEnv) + tEnv.registerTable("myTable", ds) + val columns = Array("_1","_2","_3","_4","_5","_6") + + val sqlQuery = getSelectQueryFromTemplate("STDDEV_POP(?)")(columns,"myTable") + + val actualResult = tEnv.sql(sqlQuery).toDataSet[Row].collect() + val expectedResult = "0,0,0,0,0.5,0.5" + TestBaseUtils.compareOrderedResultAsText(actualResult.asJava, expectedResult) + } + + @Test + def testStddevPopAggregateWithOtherAggreagteSUM0(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds = env.fromElements( + (1: Byte, 1 : Short, 1, 1L, 1F, 1D), + (2: Byte, 2 : Short, 2, 2L, 2F, 2D)).toTable(tEnv) + tEnv.registerTable("myTable", ds) + val columns = Array("_1","_2","_3","_4","_5","_6") + + val sqlQuery = getSelectQueryFromTemplate("STDDEV_POP(?), " + + "$sum0(?), " + + "avg(?), " + + "max(?), " + + "min(?), " + + "count(?)" ) (columns,"myTable") + + val actualResult = tEnv.sql(sqlQuery).toDataSet[Row].collect() + + val expectedResult = + "0,3,1,2,1,2," + + "0,3,1,2,1,2," + + "0,3,1,2,1,2," + + "0,3,1,2,1,2," + + "0.5,3.0,1.5,2.0,1.0,2," + + "0.5,3.0,1.5,2.0,1.0,2" + TestBaseUtils.compareOrderedResultAsText(actualResult.asJava, expectedResult) + } + + @Test + def testStddevPopAggregateWithOtherAggreagte(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds = env.fromElements( + (1: Byte, 1 : Short, 1, 1L, 1F, 1D), + (2: Byte, 2 : Short, 2, 2L, 2F, 2D)).toTable(tEnv) + tEnv.registerTable("myTable", ds) + val columns = Array("_1","_2","_3","_4","_5","_6") + + val sqlQuery = getSelectQueryFromTemplate("STDDEV_POP(?), " + + 
"sum(?), " + + "avg(?), " + + "max(?), " + + "min(?), " + + "count(?)" )(columns,"myTable") + + val actualResult = tEnv.sql(sqlQuery).toDataSet[Row].collect() + + val expectedResult = + "0,3,1,2,1,2," + + "0,3,1,2,1,2," + + "0,3,1,2,1,2," + + "0,3,1,2,1,2," + + "0.5,3.0,1.5,2.0,1.0,2," + + "0.5,3.0,1.5,2.0,1.0,2" + TestBaseUtils.compareOrderedResultAsText(actualResult.asJava, expectedResult) + } + + @Test + def testStddevSampAggregate(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds1 = env.fromElements( + (1: Byte, 1 : Short, 1, 1L, 1F, 1D), + (2: Byte, 2 : Short, 2, 2L, 2F, 2D)).toTable(tEnv) + tEnv.registerTable("myTable", ds1) + val columns = Seq("_1","_2","_3","_4","_5","_6") + + val sqlQuery = getSelectQueryFromTemplate("STDDEV_SAMP(?)")(columns,"myTable") + val sqlExpectedQuery = getSelectQueryFromTemplate( + "SQRT((SUM(? * ?) - SUM(?) * SUM(?) / COUNT(?)) / " + + "CASE " + + "COUNT(?) WHEN 1 THEN NULL " + + "ELSE COUNT(?) - 1 " + + "END)")(columns,"myTable") + + val actualResult = tEnv.sql(sqlQuery).toDataSet[Row].collect() + .head + .toString + .split(",").map(x=>"%.5f".format(x.toFloat)) + + val expectedResult = tEnv.sql(sqlExpectedQuery).toDataSet[Row].collect() + .head + .toString + .split(",").map(x=>"%.5f".format(x.toFloat)) + + Assert.assertEquals(expectedResult.mkString(","), actualResult.mkString(",")) + } + + @Test + def testVarPopAggregate(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds = env.fromElements( + (1: Byte, 1 : Short, 1, 1L, 1F, 1D), + (2: Byte, 2 : Short, 2, 2L, 2F, 2D)).toTable(tEnv) + tEnv.registerTable("myTable", ds) + val columns = Seq("_1","_2","_3","_4","_5","_6") + + val sqlQuery = getSelectQueryFromTemplate("var_pop(?)")(columns,"myTable") + val sqlExpectedQuery = getSelectQueryFromTemplate( + "(SUM(? * ?) - SUM(?) * SUM(?) 
/ COUNT(?)) / COUNT(?)")(columns,"myTable") + + val actualResult = tEnv.sql(sqlQuery).toDataSet[Row].collect() + val expectedResult = tEnv.sql(sqlExpectedQuery) + .toDataSet[Row] + .collect().head + .toString + TestBaseUtils.compareOrderedResultAsText(actualResult.asJava, expectedResult) + } + + @Test + def testVarSampAggregate(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds = env.fromElements( + (1: Byte, 1 : Short, 1, 1L, 1F, 1D), + (2: Byte, 2 : Short, 2, 2L, 2F, 2D)).toTable(tEnv) + tEnv.registerTable("myTable", ds) + val columns = Seq("_1","_2","_3","_4","_5","_6") + + val sqlQuery = getSelectQueryFromTemplate("var_samp(?)")(columns,"myTable") + val sqlExpectedQuery = getSelectQueryFromTemplate( + "(SUM(? * ?) - SUM(?) * SUM(?) / COUNT(?)) / CASE COUNT(?) WHEN 1 THEN NULL " + + "ELSE COUNT(?) - 1 END")(columns,"myTable") + + val actualResult = tEnv.sql(sqlQuery).toDataSet[Row].collect() + val expectedResult = tEnv.sql(sqlExpectedQuery) + .toDataSet[Row] + .collect().head + .toString + TestBaseUtils.compareOrderedResultAsText(actualResult.asJava, expectedResult) + } + + @Test + def testSumNullElements(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val sqlQuery = getSelectQueryFromTemplate("$sum0(?)")( + Seq("_1","_2","_3","_4","_5","_6"), + "(select * from MyTable where _1 = 4)" + ) + + val ds = env.fromElements( + (1: Byte, 2L,1D,1F,1,1:Short ), + (2: Byte, 2L,1D,1F,1,1:Short )) + tEnv.registerDataSet("MyTable", ds) + + val result = tEnv.sql(sqlQuery) + + val expected = "null,null,null,null,null,null" + val results = result.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + @Test def testTumbleWindowAggregate(): Unit = { diff --git 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/AggregationsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/AggregationsITCase.scala index 22b7f0f230959..fe43935f40e2c 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/AggregationsITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/AggregationsITCase.scala @@ -339,4 +339,80 @@ class AggregationsITCase( val results = t.toDataSet[Row].collect() TestBaseUtils.compareResultAsText(results.asJava, expected) } + + @Test + def testAnalyticAggregation(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + val t = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv) + .select('_1.stddev_pop, '_1.stddev_samp, '_1.var_pop, '_1.var_samp) + val results = t.toDataSet[Row].collect() + val expected = "6,6,36,38" + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testSQLStyleAnalyticAggregation(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + val t = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + .select( + """stddev_pop(a) as a1, a.stddev_pop as a2, + |stddev_samp (a) as b1, a.stddev_samp as b2, + |var_pop (a) as c1, a.var_pop as c2, + |var_samp (a) as d1, a.var_samp as d2 + """.stripMargin) + val expected = "6,6,6,6,36,36,38,38" + val results = t.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testWorkingAnalyticAggregationDataTypes(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + val ds = env.fromElements( + (1: Byte, 1: Short, 1, 1L, 1.0f, 1.0d), + (2: Byte, 2: Short, 2, 2L, 2.0f, 
2.0d)).toTable(tEnv) + val res = ds.select('_1.stddev_pop, '_2.stddev_pop, '_3.stddev_pop, + '_4.stddev_pop, '_5.stddev_pop, '_6.stddev_pop, + '_1.stddev_samp, '_2.stddev_samp, '_3.stddev_samp, + '_4.stddev_samp, '_5.stddev_samp, '_6.stddev_samp, + '_1.var_pop, '_2.var_pop, '_3.var_pop, + '_4.var_pop, '_5.var_pop, '_6.var_pop, + '_1.var_samp, '_2.var_samp, '_3.var_samp, + '_4.var_samp, '_5.var_samp, '_6.var_samp) + val expected = + "0,0,0," + + "0,0.5,0.5," + + "1,1,1," + + "1,0.70710677,0.7071067811865476," + + "0,0,0," + + "0,0.25,0.25," + + "1,1,1," + + "1,0.5,0.5" + val results = res.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testPojoAnalyticAggregation(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + val input = env.fromElements( + MyWC("hello", 1), + MyWC("hello", 8), + MyWC("ciao", 3), + MyWC("hola", 1), + MyWC("hola", 8)) + val expr = input.toTable(tEnv) + val result = expr + .groupBy('word) + .select('word, 'frequency.stddev_pop) + .toDataSet[MyWC] + val mappedResult = result.map(w => (w.word, w.frequency)).collect() + val expected = "(hola,3)\n(ciao,0)\n(hello,3)" + TestBaseUtils.compareResultAsText(mappedResult.asJava, expected) + } + } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/utils/TableProgramsTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/utils/TableProgramsTestBase.scala index cf9d947f100c7..ee8e1f547b92e 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/utils/TableProgramsTestBase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/utils/TableProgramsTestBase.scala @@ -37,6 +37,11 @@ class TableProgramsTestBase( } conf } + + def getSelectQueryFromTemplate(selectBlock: String) + (columnsName: Seq[String], table :String): 
String = { + s"SELECT ${columnsName.map(x=>selectBlock.replace("?",x)).mkString(",")} FROM $table" + } } object TableProgramsTestBase {