From cda5dd74a32dae2fae922a14fd219abea7e686dd Mon Sep 17 00:00:00 2001 From: Alexander Chermenin Date: Sat, 10 Dec 2016 23:09:25 +0300 Subject: [PATCH 01/18] [FLINK-2980] Base implementation of grouping sets. --- .../org/apache/flink/table/api/table.scala | 188 ++++++++++++++++++ .../flink/table/plan/logical/operators.scala | 79 ++++++++ .../apache/flink/table/AggregationTest.scala | 40 ++++ 3 files changed, 307 insertions(+) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala index 957f4c5b050b8..54816baa9130d 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala @@ -18,6 +18,10 @@ package org.apache.flink.table.api import org.apache.calcite.rel.RelNode +import org.apache.calcite.rex.RexNode +import org.apache.calcite.sql.SqlKind +import org.apache.calcite.sql.validate.SqlValidatorUtil +import org.apache.calcite.util.ImmutableBitSet import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.operators.join.JoinType import org.apache.flink.table.calcite.FlinkTypeFactory @@ -246,6 +250,96 @@ class Table( groupBy(fieldsExpr: _*) } + /** + * Groups the elements on some grouping sets. Use this before a selection with aggregations + * to perform the aggregation on a per-group basis. Similar to a SQL GROUP BY GROUPING SETS + * statement. + * + * Example: + * + * {{{ + * tab.groupingSets('key).select('key, 'value.avg) + * }}} + */ + def groupingSets(groups: Seq[Expression]*): GroupingSetsTable = { + new GroupingSetsTable(this, groups, SqlKind.GROUPING_SETS) + } + + /** + * Groups the elements on some grouping keys. Use this before a selection with aggregations + * to perform the aggregation on a per-group basis. Similar to a SQL GROUP BY statement. + * + * Example: + * + * {{{ + * tab.groupingSets("key").select("key, value.avg") + * }}} + */ + def groupingSets(groups: String): GroupingSetsTable = { + val fieldsExpr = groups.split("|").map(ExpressionParser.parseExpressionList) + groupingSets(fieldsExpr: _*) + } + + /** + * Groups the elements on some grouping sets. Use this before a selection with aggregations + * to perform the aggregation on a per-group basis. Similar to a SQL GROUP BY GROUPING SETS + * statement. + * + * Example: + * + * {{{ + * tab.groupingSets('key).select('key, 'value.avg) + * }}} + */ + def cube(groups: Seq[Expression]*): GroupingSetsTable = { + new GroupingSetsTable(this, groups, SqlKind.CUBE) + } + + /** + * Groups the elements on some grouping keys. Use this before a selection with aggregations + * to perform the aggregation on a per-group basis. Similar to a SQL GROUP BY statement. + * + * Example: + * + * {{{ + * tab.groupingSets("key").select("key, value.avg") + * }}} + */ + def cube(groups: String): GroupingSetsTable = { + val fieldsExpr = groups.split("|").map(ExpressionParser.parseExpressionList) + cube(fieldsExpr: _*) + } + + /** + * Groups the elements on some grouping sets. Use this before a selection with aggregations + * to perform the aggregation on a per-group basis. Similar to a SQL GROUP BY GROUPING SETS + * statement. + * + * Example: + * + * {{{ + * tab.groupingSets('key).select('key, 'value.avg) + * }}} + */ + def rollup(groups: Seq[Expression]*): GroupingSetsTable = { + new GroupingSetsTable(this, groups, SqlKind.ROLLUP) + } + + /** + * Groups the elements on some grouping keys. Use this before a selection with aggregations + * to perform the aggregation on a per-group basis. Similar to a SQL GROUP BY statement. + * + * Example: + * + * {{{ + * tab.groupingSets("key").select("key, value.avg") + * }}} + */ + def rollup(groups: String): GroupingSetsTable = { + val fieldsExpr = groups.split("|").map(ExpressionParser.parseExpressionList) + cube(fieldsExpr: _*) + } + /** * Removes duplicate values and returns only distinct (different) values. * @@ -930,3 +1024,97 @@ class GroupWindowedTable( } } + +/** + * A table that has been grouped on several sets of grouping keys. + */ +class GroupingSetsTable( + private[flink] val table: Table, + private[flink] val groups: Seq[Seq[Expression]], + private[flink] val sqlKind: SqlKind) { + + /** + * Performs a selection operation on a grouped table. Similar to an SQL SELECT statement. + * The field expressions can contain complex expressions and aggregations. + * + * Example: + * + * {{{ + * tab.groupingSets('key).select('key, 'value.avg + " The average" as 'average) + * }}} + */ + def select(fields: Expression*): Table = { + + val (projection, aggs, props) = extractAggregationsAndProperties(fields, table.tableEnv) + + if (props.nonEmpty) { + throw ValidationException("Window properties can only be used on windowed tables.") + } + + val groupingSets = sqlKind match { + case SqlKind.CUBE => cube(groups) + case SqlKind.ROLLUP => rollup(groups) + case _ => groups + } + + val logical = + Project( + projection, + Grouping( + groupingSets, + aggs, + table.logicalPlan, + sqlKind + ).validate(table.tableEnv) + ).validate(table.tableEnv) + + new Table(table.tableEnv, logical) + } + + /** + * Performs a selection operation on a grouped table. Similar to an SQL SELECT statement. + * The field expressions can contain complex expressions and aggregations. + * + * Example: + * + * {{{ + * tab.groupBy("key").select("key, value.avg + ' The average' as average") + * }}} + */ + def select(fields: String): Table = { + val fieldExprs = ExpressionParser.parseExpressionList(fields) + select(fieldExprs: _*) + } + + /** Computes the rollup of bit sets. + * + *

For example, rollup({0}, {1}) + * returns ({0, 1}, {0}, {}). + * + *

Bit sets are not necessarily singletons: + * rollup({0, 2}, {3, 5}) + * returns ({0, 2, 3, 5}, {0, 2}, {}). */ + private def rollup(groups: Seq[Seq[Expression]]): Seq[Seq[Expression]] = { + val originalBitSet = for (i <- groups.indices) yield { + ImmutableBitSet.builder().set(i).build() + } + val rollupBitSets = SqlValidatorUtil.rollup(originalBitSet.asJava) + rollupBitSets.asScala.map(_.asScala.flatMap(i => groups(i)).toSeq) + } + + /** Computes the cube of bit sets. + * + *

For example, rollup({0}, {1}) + * returns ({0, 1}, {0}, {}). + * + *

Bit sets are not necessarily singletons: + * rollup({0, 2}, {3, 5}) + * returns ({0, 2, 3, 5}, {0, 2}, {}). */ + private def cube(groups: Seq[Seq[Expression]]): Seq[Seq[Expression]] = { + val originalBitSet = for (i <- groups.indices) yield { + ImmutableBitSet.builder().set(i).build() + } + val cubeBitSets = SqlValidatorUtil.cube(originalBitSet.asJava) + cubeBitSets.asScala.map(_.asScala.flatMap(i => groups(i)).toSeq) + } +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala index 743bdfe1ec09b..b580f2ff3f024 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala @@ -20,12 +20,16 @@ package org.apache.flink.table.plan.logical import java.lang.reflect.Method import java.util +import com.google.common.collect.{ImmutableList, Sets} import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.CorrelationId import org.apache.calcite.rel.logical.{LogicalProject, LogicalTableFunctionScan} import org.apache.calcite.rex.{RexInputRef, RexNode} +import org.apache.calcite.sql.SqlKind +import org.apache.calcite.sql.validate.SqlValidatorUtil import org.apache.calcite.tools.RelBuilder +import org.apache.calcite.util.{ImmutableBitSet, ImmutableIntList} import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.operators.join.JoinType @@ -41,6 +45,7 @@ import org.apache.flink.table.typeutils.TypeConverter import org.apache.flink.table.validate.{ValidationFailure, ValidationSuccess} import scala.collection.JavaConverters._ +import scala.collection.immutable.BitSet import scala.collection.mutable case class Project(projectList: Seq[NamedExpression], child: LogicalNode) extends UnaryNode { @@ -263,6 +268,80 @@ case class Aggregate( resolvedAggregate } } +case class Grouping( + groupingExpressions: Seq[Seq[Expression]], + aggregateExpressions: Seq[NamedExpression], + child: LogicalNode, + sqlKind: SqlKind + ) extends UnaryNode { + + override def output: Seq[Attribute] = { + (groupingExpressions.flatten.distinct ++ aggregateExpressions) map { + case ne: NamedExpression => ne.toAttribute + case e => Alias(e, e.toString).toAttribute + } + } + + override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = { + child.construct(relBuilder) + val groupingSets = groupingExpressions.map(_.map(_.toRexNode(relBuilder)).toList).toList + relBuilder.aggregate( + relBuilder.groupKey( + groupingExpressions.head.map(_.toRexNode(relBuilder)).asJava, + true, groupingSets.map(_.asJava).asJava + ), + aggregateExpressions.map { + case Alias(agg: Aggregation, name, _) => agg.toAggCall(name)(relBuilder) + case _ => throw new RuntimeException("This should never happen.") + }.asJava) + } + + override def validate(tableEnv: TableEnvironment): LogicalNode = { + if (tableEnv.isInstanceOf[StreamTableEnvironment]) { + failValidation(s"Aggregate on stream tables is currently not supported.") + } + + val resolvedAggregate = super.validate(tableEnv).asInstanceOf[Grouping] + val groupingExprs = resolvedAggregate.groupingExpressions + val aggregateExprs = resolvedAggregate.aggregateExpressions + val resolvedGroupingExprs = groupingExprs.map(_.map { + case u @ UnresolvedFieldReference(name) => + resolveReference(tableEnv, name).getOrElse(u) + case x => x + }) + resolvedGroupingExprs.flatten.foreach(validateGroupingExpression) + aggregateExprs.foreach(validateAggregateExpression) + + def validateAggregateExpression(expr: Expression): Unit = expr match { + // check no nested aggregation exists. + case aggExpr: Aggregation => + aggExpr.children.foreach { child => + child.preOrderVisit { + case agg: Aggregation => + failValidation( + "It's not allowed to use an aggregate function as " + + "input of another aggregate function") + case _ => // OK + } + } + case a: Attribute if !groupingExprs.flatten.exists(_.checkEquals(a)) => + failValidation( + s"expression '$a' is invalid because it is neither" + + " present in group by nor an aggregate function") + case e if groupingExprs.flatten.exists(_.checkEquals(e)) => // OK + case e => e.children.foreach(validateAggregateExpression) + } + + def validateGroupingExpression(expr: Expression): Unit = { + if (!expr.resultType.isKeyType) { + failValidation( + s"expression $expr cannot be used as a grouping expression " + + "because it's not a valid key type which must be hashable and comparable") + } + } + Grouping(resolvedGroupingExprs, resolvedAggregate.aggregateExpressions, child, sqlKind) + } +} case class Minus(left: LogicalNode, right: LogicalNode, all: Boolean) extends BinaryNode { override def output: Seq[Attribute] = left.output diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala index 708e00766a0e7..5ecd5b4d16962 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala @@ -19,8 +19,12 @@ package org.apache.flink.table import org.apache.flink.api.scala._ import org.apache.flink.table.api.scala._ +import org.apache.flink.api.scala.util.CollectionDataSets +import org.apache.flink.table.api.{TableConfig, TableEnvironment} +import org.apache.flink.table.expressions.Expression import org.apache.flink.table.utils.TableTestBase import org.apache.flink.table.utils.TableTestUtil._ +import org.apache.flink.types.Row import org.junit.Test /** @@ -258,4 +262,40 @@ class AggregationTest extends TableTestBase { util.verifyTable(resultTable, expected) } + + @Test + def testGroupingSetsTableApi(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, new TableConfig) + val ds = CollectionDataSets.get3TupleDataSet(env) + val table = tEnv.fromDataSet(ds, 'a, 'b, 'c) + val result = table + .groupingSets(Seq[Expression]('a, 'b), Seq[Expression]('a), Seq[Expression]()) + .select('a, 'b) + result.toDataSet[Row].print() + } + + @Test + def testCubeTableApi(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, new TableConfig) + val ds = CollectionDataSets.get3TupleDataSet(env) + val table = tEnv.fromDataSet(ds, 'a, 'b, 'c) + val result = table + .cube(Seq[Expression]('a), Seq[Expression]('b), Seq[Expression]('c)) + .select('a, 'b) + result.toDataSet[Row].print() + } + + @Test + def testRollupTableApi(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, new TableConfig) + val ds = CollectionDataSets.get3TupleDataSet(env) + val table = tEnv.fromDataSet(ds, 'a, 'b, 'c) + val result = table + .rollup(Seq[Expression]('a), Seq[Expression]('b)) + .select('a, 'b) + result.toDataSet[Row].print() + } } From 489b3c97732598d4265e89fdb6588f82332c7788 Mon Sep 17 00:00:00 2001 From: Aleksandr Chermenin Date: Mon, 12 Dec 2016 11:58:17 +0300 Subject: [PATCH 02/18] [FLINK-2980] Implemented grouped expressions. --- .../flink/table/api/scala/expressionDsl.scala | 60 +++++++++++++++- .../flink/table/api/scala/package.scala | 2 +- .../org/apache/flink/table/api/table.scala | 68 +++++++++++-------- .../flink/table/expressions/Expression.scala | 17 +++++ .../apache/flink/table/AggregationTest.scala | 14 ++-- 5 files changed, 122 insertions(+), 39 deletions(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala index 0634f0b825f0f..80f5381d138bb 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala @@ -17,6 +17,7 @@ */ package org.apache.flink.table.api.scala +import java.math.{BigDecimal => JBigDecimal} import java.sql.{Date, Time, Timestamp} import org.apache.calcite.avatica.util.DateTimeUtils._ @@ -24,7 +25,6 @@ import org.apache.flink.api.common.typeinfo.{SqlTimeTypeInfo, TypeInformation} import org.apache.flink.table.expressions.ExpressionUtils.{convertArray, toMilliInterval, toMonthInterval, toRowInterval} import org.apache.flink.table.expressions.TimeIntervalUnit.TimeIntervalUnit import org.apache.flink.table.expressions._ -import java.math.{BigDecimal => JBigDecimal} import scala.language.implicitConversions @@ -574,6 +574,64 @@ trait ImplicitExpressionConversions { implicit def array2ArrayConstructor(array: Array[_]): Expression = convertArray(array) } +/** + * Implicit conversions from Scala Tuples of Literals to GroupedExpression. + */ +trait ImplicitGroupedOperations { + private[flink] def expr: GroupedExpression + + + + implicit class UnitAsGroupedExpression(unit: Unit) extends ImplicitGroupedOperations { + override private[flink] def expr = new GroupedExpression(Seq()) + } + + implicit class ExpressionAsGroupedExpression(expression: Expression) + extends ImplicitGroupedOperations { + override private[flink] def expr = new GroupedExpression(Seq(expression)) + } + + implicit class ProductAsGroupedExpression(product: Product) + extends ImplicitGroupedOperations { + override private[flink] def expr = + if (product.productArity == product.productIterator.count(_.isInstanceOf[Expression])) { + new GroupedExpression(product.productIterator.map(_.asInstanceOf[Expression]).toSeq) + } else if (product.productArity == product.productIterator.count(_.isInstanceOf[Symbol])) { + new GroupedExpression( + product.productIterator + .map(_.asInstanceOf[Symbol]) + .map(s => UnresolvedFieldReference(s.name)) + .toSeq + ) + } else { + throw new IllegalArgumentException() + } + } +} + +trait ImplicitGroupedConversions { + + implicit def unitToGroupedExpression(unit: Unit): GroupedExpression = + new GroupedExpression(Seq()) + + implicit def expressionToGroupedExpression(expression: Expression): GroupedExpression = + new GroupedExpression(Seq(expression)) + + implicit def productToGroupedExpression(product: Product): GroupedExpression = + if (product.productArity == product.productIterator.count(_.isInstanceOf[Expression])) { + new GroupedExpression(product.productIterator.map(_.asInstanceOf[Expression]).toSeq) + } else if (product.productArity == product.productIterator.count(_.isInstanceOf[Symbol])) { + new GroupedExpression( + product.productIterator + .map(_.asInstanceOf[Symbol]) + .map(s => UnresolvedFieldReference(s.name)) + .toSeq + ) + } else { + throw new IllegalArgumentException() + } +} + // ------------------------------------------------------------------------------------------------ // Expressions with no parameters // ------------------------------------------------------------------------------------------------ diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/package.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/package.scala index cd341cbe5ba48..72dfaeb29cd49 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/package.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/package.scala @@ -66,7 +66,7 @@ import _root_.scala.reflect.ClassTag * }}} * */ -package object scala extends ImplicitExpressionConversions { +package object scala extends ImplicitExpressionConversions with ImplicitGroupedConversions { implicit def table2TableConversions(table: Table): TableConversions = { new TableConversions(table) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala index 54816baa9130d..4536d446c6abb 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala @@ -26,7 +26,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.operators.join.JoinType import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.plan.logical.Minus -import org.apache.flink.table.expressions.{Alias, Asc, Call, Expression, ExpressionParser, Ordering, TableFunctionCall, UnresolvedAlias} +import org.apache.flink.table.expressions.{Alias, Asc, Call, Expression, ExpressionParser, GroupedExpression, Ordering, TableFunctionCall, UnresolvedAlias} import org.apache.flink.table.plan.ProjectionTranslator._ import org.apache.flink.table.plan.logical._ import org.apache.flink.table.sinks.TableSink @@ -258,86 +258,100 @@ class Table( * Example: * * {{{ - * tab.groupingSets('key).select('key, 'value.avg) + * tab.groupingSets(('a, 'b), ('a), ()).select('a, 'b, 'c.avg) * }}} */ - def groupingSets(groups: Seq[Expression]*): GroupingSetsTable = { + def groupingSets(fields: Expression*): GroupingSetsTable = { + val groups = fields.map { + case g: GroupedExpression => g.children + case x => Seq(x) + } new GroupingSetsTable(this, groups, SqlKind.GROUPING_SETS) } /** * Groups the elements on some grouping keys. Use this before a selection with aggregations - * to perform the aggregation on a per-group basis. Similar to a SQL GROUP BY statement. + * to perform the aggregation on a per-group basis. Similar to a SQL GROUP BY GROUPING SETS + * statement. * * Example: * * {{{ - * tab.groupingSets("key").select("key, value.avg") + * tab.groupingSets("(a, b), (a), ()").select("a, b, c.avg") * }}} */ def groupingSets(groups: String): GroupingSetsTable = { - val fieldsExpr = groups.split("|").map(ExpressionParser.parseExpressionList) - groupingSets(fieldsExpr: _*) +// val fieldsExpr = groups.split("|").map(ExpressionParser.parseExpressionList) +// groupingSets(fieldsExpr: _*) + ??? } /** - * Groups the elements on some grouping sets. Use this before a selection with aggregations - * to perform the aggregation on a per-group basis. Similar to a SQL GROUP BY GROUPING SETS - * statement. + * Groups the elements on cube grouping sets. Use this before a selection with aggregations + * to perform the aggregation on a per-group basis. Similar to a SQL GROUP BY CUBE statement. * * Example: * * {{{ - * tab.groupingSets('key).select('key, 'value.avg) + * tab.cube('a, 'b).select('a, 'b, 'c.avg) * }}} */ - def cube(groups: Seq[Expression]*): GroupingSetsTable = { + def cube(fields: Expression*): GroupingSetsTable = { + val groups = fields.map { + case g: GroupedExpression => g.children + case x => Seq(x) + } new GroupingSetsTable(this, groups, SqlKind.CUBE) } /** - * Groups the elements on some grouping keys. Use this before a selection with aggregations - * to perform the aggregation on a per-group basis. Similar to a SQL GROUP BY statement. + * Groups the elements on cube grouping sets. Use this before a selection with aggregations + * to perform the aggregation on a per-group basis. Similar to a SQL GROUP BY CUBE statement. * * Example: * * {{{ - * tab.groupingSets("key").select("key, value.avg") + * tab.cube("a, b").select("a, b, c.avg") * }}} */ def cube(groups: String): GroupingSetsTable = { - val fieldsExpr = groups.split("|").map(ExpressionParser.parseExpressionList) - cube(fieldsExpr: _*) +// val fieldsExpr = groups.split("|").map(ExpressionParser.parseExpressionList) +// cube(fieldsExpr: _*) + ??? } /** - * Groups the elements on some grouping sets. Use this before a selection with aggregations - * to perform the aggregation on a per-group basis. Similar to a SQL GROUP BY GROUPING SETS - * statement. + * Groups the elements on rollup grouping sets. Use this before a selection with aggregations + * to perform the aggregation on a per-group basis. Similar to a SQL GROUP BY ROLLUP statement. * * Example: * * {{{ - * tab.groupingSets('key).select('key, 'value.avg) + * tab.rollup('a, 'b).select('a, 'b, 'c.avg) * }}} */ - def rollup(groups: Seq[Expression]*): GroupingSetsTable = { + def rollup(fields: Expression*): GroupingSetsTable = { + val groups = fields.map { + case g: GroupedExpression => g.children + case x => Seq(x) + } new GroupingSetsTable(this, groups, SqlKind.ROLLUP) } /** - * Groups the elements on some grouping keys. Use this before a selection with aggregations - * to perform the aggregation on a per-group basis. Similar to a SQL GROUP BY statement. + * Groups the elements on rollup grouping sets. Use this before a selection with aggregations + * to perform the aggregation on a per-group basis. Similar to a SQL GROUP BY ROLLUP statement. * * Example: * * {{{ - * tab.groupingSets("key").select("key, value.avg") + * tab.rollup("a, b").select("a, b, c.avg") * }}} */ def rollup(groups: String): GroupingSetsTable = { - val fieldsExpr = groups.split("|").map(ExpressionParser.parseExpressionList) - cube(fieldsExpr: _*) +// val fieldsExpr = groups.split("|").map(ExpressionParser.parseExpressionList) +// rollup(fieldsExpr: _*) + ??? } /** diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/Expression.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/Expression.scala index 14d899daba892..56dabaa385b4c 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/Expression.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/Expression.scala @@ -86,3 +86,20 @@ abstract class UnaryExpression extends Expression { abstract class LeafExpression extends Expression { private[flink] val children = Nil } + +class GroupedExpression( + private[flink] val children: Seq[Expression] + ) extends Expression { + + /** + * Returns the [[TypeInformation]] for evaluating this expression. + * It's not applicable for grouped expressions. + */ + override private[flink] def resultType = ??? + + override def productElement(n: Int): Expression = children(n) + + override def productArity: Int = children.length + + override def canEqual(that: Any) = false +} diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala index 5ecd5b4d16962..2789b28dc4098 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala @@ -21,7 +21,7 @@ import org.apache.flink.api.scala._ import org.apache.flink.table.api.scala._ import org.apache.flink.api.scala.util.CollectionDataSets import org.apache.flink.table.api.{TableConfig, TableEnvironment} -import org.apache.flink.table.expressions.Expression +import org.apache.flink.table.expressions.{Expression, GroupedExpression} import org.apache.flink.table.utils.TableTestBase import org.apache.flink.table.utils.TableTestUtil._ import org.apache.flink.types.Row @@ -269,9 +269,7 @@ class AggregationTest extends TableTestBase { val tEnv = TableEnvironment.getTableEnvironment(env, new TableConfig) val ds = CollectionDataSets.get3TupleDataSet(env) val table = tEnv.fromDataSet(ds, 'a, 'b, 'c) - val result = table - .groupingSets(Seq[Expression]('a, 'b), Seq[Expression]('a), Seq[Expression]()) - .select('a, 'b) + val result = table.groupingSets(('a, 'b), 'b, ()).select('a, 'b) result.toDataSet[Row].print() } @@ -281,9 +279,7 @@ class AggregationTest extends TableTestBase { val tEnv = TableEnvironment.getTableEnvironment(env, new TableConfig) val ds = CollectionDataSets.get3TupleDataSet(env) val table = tEnv.fromDataSet(ds, 'a, 'b, 'c) - val result = table - .cube(Seq[Expression]('a), Seq[Expression]('b), Seq[Expression]('c)) - .select('a, 'b) + val result = table.cube('a, 'b).select('a, 'b) result.toDataSet[Row].print() } @@ -293,9 +289,7 @@ class AggregationTest extends TableTestBase { val tEnv = TableEnvironment.getTableEnvironment(env, new TableConfig) val ds = CollectionDataSets.get3TupleDataSet(env) val table = tEnv.fromDataSet(ds, 'a, 'b, 'c) - val result = table - .rollup(Seq[Expression]('a), Seq[Expression]('b)) - .select('a, 'b) + val result = table.rollup('a, 'b).select('a, 'b) result.toDataSet[Row].print() } } From c156ab6a597057463ffd817ce6c0a295da39e052 Mon Sep 17 00:00:00 2001 From: Aleksandr Chermenin Date: Mon, 12 Dec 2016 13:18:42 +0300 Subject: [PATCH 03/18] [FLINK-2980] Added grouping functions. --- .../api/table/expressions/groupings.scala | 64 ++++++++++++++ .../flink/table/api/scala/expressionDsl.scala | 83 +++++++++++-------- .../org/apache/flink/table/api/table.scala | 2 +- .../flink/table/expressions/Expression.scala | 25 +++++- .../table/plan/ProjectionTranslator.scala | 1 + .../apache/flink/table/AggregationTest.scala | 2 +- 6 files changed, 139 insertions(+), 38 deletions(-) create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/groupings.scala diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/groupings.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/groupings.scala new file mode 100644 index 0000000000000..f7055226066f8 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/groupings.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.table.expressions + +import org.apache.calcite.sql.fun.SqlStdOperatorTable +import org.apache.calcite.tools.RelBuilder +import org.apache.flink.api.common.typeinfo.BasicTypeInfo + +abstract sealed class GroupFunction extends Expression { + + override def toString = s"GroupFunction($children)" +} + +case class GroupId() extends GroupFunction { + + override private[flink] def resultType = BasicTypeInfo.LONG_TYPE_INFO + + override private[flink] def children = Nil + + override private[flink] def toRexNode(implicit relBuilder: RelBuilder) = { + relBuilder.call(SqlStdOperatorTable.GROUP_ID) + } +} + +case class Grouping(expression: Expression) extends GroupFunction { + + override private[flink] def resultType = BasicTypeInfo.LONG_TYPE_INFO + + + override private[flink] def children = Seq(expression) + + override private[flink] def toRexNode(implicit relBuilder: RelBuilder) = { + relBuilder.call(SqlStdOperatorTable.GROUPING, expression.toRexNode) + } +} + +case class GroupingId(expressions: Expression*) extends GroupFunction { + + override private[flink] def resultType = BasicTypeInfo.LONG_TYPE_INFO + + + override private[flink] def children = expressions + + override private[flink] def toRexNode(implicit relBuilder: RelBuilder) = { + relBuilder.call(SqlStdOperatorTable.GROUPING_ID, expressions.map(_.toRexNode): _*) + } +} + diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala index 80f5381d138bb..f5fd8bc3d1a79 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala @@ -485,6 +485,19 @@ trait ImplicitExpressionOperations { * @return the first and only element of an array with a single element */ def element() = ArrayElement(expr) + + /** + * Grouping function. Similar to a SQL GROUPING_ID function. + */ + def groupingId(): Expression = expr match { + case g: GroupedExpression => GroupingId(g.flatChildren: _*) + case x => GroupingId(x) + } + + /** + * Grouping function. Similar to a SQL GROUPING function. + */ + def grouping(): Expression = Grouping(expr) } /** @@ -580,32 +593,13 @@ trait ImplicitExpressionConversions { trait ImplicitGroupedOperations { private[flink] def expr: GroupedExpression - - implicit class UnitAsGroupedExpression(unit: Unit) extends ImplicitGroupedOperations { override private[flink] def expr = new GroupedExpression(Seq()) } - implicit class ExpressionAsGroupedExpression(expression: Expression) - extends ImplicitGroupedOperations { - override private[flink] def expr = new GroupedExpression(Seq(expression)) - } - implicit class ProductAsGroupedExpression(product: Product) extends ImplicitGroupedOperations { - override private[flink] def expr = - if (product.productArity == product.productIterator.count(_.isInstanceOf[Expression])) { - new GroupedExpression(product.productIterator.map(_.asInstanceOf[Expression]).toSeq) - } else if (product.productArity == product.productIterator.count(_.isInstanceOf[Symbol])) { - new GroupedExpression( - product.productIterator - .map(_.asInstanceOf[Symbol]) - .map(s => UnresolvedFieldReference(s.name)) - .toSeq - ) - } else { - throw new IllegalArgumentException() - } + override private[flink] def expr = new GroupedExpression(product) } } @@ -614,22 +608,8 @@ trait ImplicitGroupedConversions { implicit def unitToGroupedExpression(unit: Unit): GroupedExpression = new GroupedExpression(Seq()) - implicit def expressionToGroupedExpression(expression: Expression): GroupedExpression = - new GroupedExpression(Seq(expression)) - implicit def productToGroupedExpression(product: Product): GroupedExpression = - if (product.productArity == product.productIterator.count(_.isInstanceOf[Expression])) { - new GroupedExpression(product.productIterator.map(_.asInstanceOf[Expression]).toSeq) - } else if (product.productArity == product.productIterator.count(_.isInstanceOf[Symbol])) { - new GroupedExpression( - product.productIterator - .map(_.asInstanceOf[Symbol]) - .map(s => UnresolvedFieldReference(s.name)) - .toSeq - ) - } else { - throw new IllegalArgumentException() - } + new GroupedExpression(product) } // ------------------------------------------------------------------------------------------------ @@ -746,4 +726,37 @@ object array { } } +/** + * Grouping function. Similar to a SQL GROUP_ID function. + */ +object groupId { + + /** + * Return evaluated result of the function. + */ + def apply(): Expression = GroupId() +} + +/** + * Grouping function. Similar to a SQL GROUPING function. + */ +object grouping { + + /** + * Return evaluated result of the function. + */ + def apply(expression: Expression): Expression = Grouping(expression) +} + +/** + * Grouping function. Similar to a SQL GROUPING function. + */ +object groupingId { + + /** + * Return evaluated result of the function. + */ + def apply(expression: Expression*): Expression = GroupingId(expression: _*) +} + // scalastyle:on object.name diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala index 4536d446c6abb..b8c7860db3998 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala @@ -263,7 +263,7 @@ class Table( */ def groupingSets(fields: Expression*): GroupingSetsTable = { val groups = fields.map { - case g: GroupedExpression => g.children + case g: GroupedExpression => g.flatChildren case x => Seq(x) } new GroupingSetsTable(this, groups, SqlKind.GROUPING_SETS) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/Expression.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/Expression.scala index 56dabaa385b4c..0fc484f477a57 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/Expression.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/Expression.scala @@ -19,7 +19,6 @@ package org.apache.flink.table.expressions import org.apache.calcite.rex.RexNode import org.apache.calcite.tools.RelBuilder - import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.table.plan.TreeNode import org.apache.flink.table.validate.{ValidationResult, ValidationSuccess} @@ -91,12 +90,36 @@ class GroupedExpression( private[flink] val children: Seq[Expression] ) extends Expression { + def this(product: Product) { + this( + product.productIterator + .map { + case e: Expression => e + case s: Symbol => UnresolvedFieldReference(s.name) + case p: Product => new GroupedExpression(p) + case _ => throw new IllegalArgumentException() + }.toSeq + ) + } + + def flatChildren: Seq[Expression] = { + children.flatMap { + case g: GroupedExpression => g.flatChildren + case x => Seq(x) + } + } + /** * Returns the [[TypeInformation]] for evaluating this expression. * It's not applicable for grouped expressions. */ override private[flink] def resultType = ??? + /** + * Grouping function. Similar to a SQL GROUPING_ID function. + */ + def groupingId(): Expression = GroupingId(children: _*) + override def productElement(n: Int): Expression = children(n) override def productArity: Int = children.length diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala index ed6cf7b4d0e30..1f0205b3ec846 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala @@ -125,6 +125,7 @@ object ProjectionTranslator { case prop: WindowProperty => val name = propNames(prop) Alias(UnresolvedFieldReference(name), tableEnv.createUniqueAttributeName()) + case g: GroupFunction => g case n @ Alias(agg: Aggregation, name, _) => val aName = aggNames(agg) Alias(UnresolvedFieldReference(aName), name) diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala index 2789b28dc4098..f02d36c6b9fd0 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala @@ -269,7 +269,7 @@ class AggregationTest extends TableTestBase { val tEnv = TableEnvironment.getTableEnvironment(env, new TableConfig) val ds = CollectionDataSets.get3TupleDataSet(env) val table = tEnv.fromDataSet(ds, 'a, 'b, 'c) - val result = table.groupingSets(('a, 'b), 'b, ()).select('a, 'b) + val result = table.groupingSets(('a, 'b), 'b, ()).select('a, 'b, groupId() as 'g) result.toDataSet[Row].print() } From eb2250ffa43a1217c3e351824e5a596a1a3f3464 Mon Sep 17 00:00:00 2001 From: Aleksandr Chermenin Date: Mon, 12 Dec 2016 14:39:09 +0300 Subject: [PATCH 04/18] [FLINK-2980] Improved expressions parser. --- .../org/apache/flink/table/api/table.scala | 23 +++++++------- .../table/expressions/ExpressionParser.scala | 10 +++++-- .../flink/table/plan/logical/operators.scala | 6 ++-- .../apache/flink/table/AggregationTest.scala | 30 +++++++++++++++++++ 4 files changed, 51 insertions(+), 18 deletions(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala index b8c7860db3998..9268111c81d58 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala @@ -280,10 +280,9 @@ class Table( * tab.groupingSets("(a, b), (a), ()").select("a, b, c.avg") * }}} */ - def groupingSets(groups: String): GroupingSetsTable = { -// val fieldsExpr = groups.split("|").map(ExpressionParser.parseExpressionList) -// groupingSets(fieldsExpr: _*) - ??? + def groupingSets(fields: String): GroupingSetsTable = { + val fieldsExpr = ExpressionParser.parseExpressionList(fields) + groupingSets(fieldsExpr: _*) } /** @@ -314,10 +313,9 @@ class Table( * tab.cube("a, b").select("a, b, c.avg") * }}} */ - def cube(groups: String): GroupingSetsTable = { -// val fieldsExpr = groups.split("|").map(ExpressionParser.parseExpressionList) -// cube(fieldsExpr: _*) - ??? + def cube(fields: String): GroupingSetsTable = { + val fieldsExpr = ExpressionParser.parseExpressionList(fields) + cube(fieldsExpr: _*) } /** @@ -348,10 +346,9 @@ class Table( * tab.rollup("a, b").select("a, b, c.avg") * }}} */ - def rollup(groups: String): GroupingSetsTable = { -// val fieldsExpr = groups.split("|").map(ExpressionParser.parseExpressionList) -// rollup(fieldsExpr: _*) - ??? + def rollup(fields: String): GroupingSetsTable = { + val fieldsExpr = ExpressionParser.parseExpressionList(fields) + rollup(fieldsExpr: _*) } /** @@ -1074,7 +1071,7 @@ class GroupingSetsTable( val logical = Project( projection, - Grouping( + GroupingAggregation( groupingSets, aggs, table.logicalPlan, diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala index d85540a98e4b9..425cfa92c09f5 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala @@ -176,6 +176,12 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { lazy val atom: PackratParser[Expression] = ( "(" ~> expression <~ ")" ) | literalExpr | fieldReference + lazy val grouped: PackratParser[Expression] = + "(" ~> expressionList <~ ")" ^^ { l => new GroupedExpression(l.toSeq) } + + lazy val unit: PackratParser[Expression] = + "()" ^^ { _ => new GroupedExpression(Seq()) } + // suffix operators lazy val suffixSum: PackratParser[Expression] = @@ -383,7 +389,7 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { // suffix/prefix composite - lazy val composite: PackratParser[Expression] = suffixed | prefixed | atom | + lazy val composite: PackratParser[Expression] = suffixed | prefixed | atom | grouped | failure("Composite expression expected.") // unary ops @@ -455,7 +461,7 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { case e ~ _ ~ _ ~ names ~ _ => Alias(e, names.head.name, names.tail.map(_.name)) } | logic - lazy val expression: PackratParser[Expression] = alias | + lazy val expression: PackratParser[Expression] = alias | grouped | unit | failure("Invalid expression.") lazy val expressionList: Parser[List[Expression]] = rep1sep(expression, ",") diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala index b580f2ff3f024..755221c298cec 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala @@ -268,7 +268,7 @@ case class Aggregate( resolvedAggregate } } -case class Grouping( +case class GroupingAggregation( groupingExpressions: Seq[Seq[Expression]], aggregateExpressions: Seq[NamedExpression], child: LogicalNode, @@ -301,7 +301,7 @@ case class Grouping( failValidation(s"Aggregate on stream tables is currently not supported.") } - val resolvedAggregate = super.validate(tableEnv).asInstanceOf[Grouping] + val resolvedAggregate = super.validate(tableEnv).asInstanceOf[GroupingAggregation] val groupingExprs = resolvedAggregate.groupingExpressions val aggregateExprs = resolvedAggregate.aggregateExpressions val resolvedGroupingExprs = groupingExprs.map(_.map { @@ -339,7 +339,7 @@ case class Grouping( "because it's not a valid key type which must be hashable and comparable") } } - Grouping(resolvedGroupingExprs, resolvedAggregate.aggregateExpressions, child, sqlKind) + GroupingAggregation(resolvedGroupingExprs, resolvedAggregate.aggregateExpressions, child, sqlKind) } } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala index f02d36c6b9fd0..a272daf8ec81c 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala @@ -273,6 +273,16 @@ class AggregationTest extends TableTestBase { result.toDataSet[Row].print() } + @Test + def testStringGroupingSetsTableApi(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, new TableConfig) + val ds = CollectionDataSets.get3TupleDataSet(env) + val table = tEnv.fromDataSet(ds) + val result = table.groupingSets("(_1, _2), (_1), ()").select("_1, _2") + result.toDataSet[Row].print() + } + @Test def testCubeTableApi(): Unit = { val env = ExecutionEnvironment.getExecutionEnvironment @@ -283,6 +293,16 @@ class AggregationTest extends TableTestBase { result.toDataSet[Row].print() } + @Test + def testStringCubeTableApi(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, new TableConfig) + val ds = CollectionDataSets.get3TupleDataSet(env) + val table = tEnv.fromDataSet(ds) + val result = table.cube("(_1, _3), _2").select("_1, _2") + result.toDataSet[Row].print() + } + @Test def testRollupTableApi(): Unit = { val env = ExecutionEnvironment.getExecutionEnvironment @@ -292,4 +312,14 @@ class AggregationTest extends TableTestBase { val result = table.rollup('a, 'b).select('a, 'b) result.toDataSet[Row].print() } + + @Test + def testStringRollupTableApi(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, new TableConfig) + val ds = CollectionDataSets.get3TupleDataSet(env) + val table = tEnv.fromDataSet(ds) + val result = table.rollup("(_1, _3), _2").select("_1, _2") + result.toDataSet[Row].print() + } } From 6a7edb27cb61aa35006e7a7fcfcd7e45f4d99688 Mon Sep 17 00:00:00 2001 From: Aleksandr Chermenin Date: Mon, 12 Dec 2016 15:05:25 +0300 Subject: [PATCH 05/18] [FLINK-2980] Added support for grouping functions. --- .../table/expressions/ExpressionParser.scala | 29 +++++++++++++++++-- .../table/expressions/groupings.scala | 2 +- .../apache/flink/table/AggregationTest.scala | 2 +- 3 files changed, 28 insertions(+), 5 deletions(-) rename flink-libraries/flink-table/src/main/scala/org/apache/flink/{api => }/table/expressions/groupings.scala (97%) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala index 425cfa92c09f5..4d55e79227a4e 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala @@ -87,6 +87,9 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { lazy val STAR: Keyword = Keyword("*") lazy val GET: Keyword = Keyword("get") lazy val FLATTEN: Keyword = Keyword("flatten") + lazy val GROUP_ID: Keyword = Keyword("groupId") + lazy val GROUPING: Keyword = Keyword("grouping") + lazy val GROUPING_ID: Keyword = Keyword("groupingId") def functionIdent: ExpressionParser.Parser[String] = not(ARRAY) ~ not(AS) ~ not(COUNT) ~ not(AVG) ~ not(MIN) ~ not(MAX) ~ @@ -295,12 +298,21 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { lazy val suffixFlattening: PackratParser[Expression] = composite <~ "." ~ FLATTEN ~ opt("()") ^^ { e => Flattening(e) } + lazy val suffixGrouping: PackratParser[Expression] = + composite <~ "." ~ GROUPING ~ opt("()") ^^ { e => Grouping(e) } + + lazy val suffixGroupingId: PackratParser[Expression] = + composite <~ "." ~ GROUPING_ID ~ opt("()") ^^ { + case g: GroupedExpression => GroupingId(g.flatChildren: _*) + case e => GroupingId(e) + } + lazy val suffixed: PackratParser[Expression] = suffixTimeInterval | suffixRowInterval | suffixSum | suffixMin | suffixMax | suffixStart | suffixEnd | suffixCount | suffixAvg | suffixCast | suffixAs | suffixTrim | suffixTrimWithoutArgs | suffixIf | suffixAsc | suffixDesc | suffixToDate | suffixToTimestamp | suffixToTime | suffixExtract | suffixFloor | suffixCeil | - suffixGet | suffixFlattening | + suffixGet | suffixFlattening | suffixGroupingId | suffixGrouping | suffixFunctionCall | suffixFunctionCallOneArg // function call must always be at the end // prefix operators @@ -381,10 +393,21 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { lazy val prefixFlattening: PackratParser[Expression] = FLATTEN ~ "(" ~> composite <~ ")" ^^ { e => Flattening(e) } + lazy val prefixGrouping: PackratParser[Expression] = + GROUPING ~ "(" ~> composite <~ ")" ^^ { e => Grouping(e) } + + lazy val prefixGroupingId: PackratParser[Expression] = + GROUPING_ID ~ "(" ~> repsep(expression, ",") <~ ")" ^^ { l => GroupingId(l: _*) } + + lazy val prefixGroupId: PackratParser[Expression] = + GROUP_ID ~ opt("()") ^^ { _ => GroupId() } + lazy val prefixed: PackratParser[Expression] = prefixArray | prefixSum | prefixMin | prefixMax | prefixCount | prefixAvg | - prefixStart | prefixEnd | prefixCast | prefixAs | prefixTrim | prefixTrimWithoutArgs | - prefixIf | prefixExtract | prefixFloor | prefixCeil | prefixGet | prefixFlattening | + prefixStart | prefixEnd | + prefixCast | prefixAs | prefixTrim | prefixTrimWithoutArgs | prefixIf | prefixExtract | + prefixFloor | prefixCeil | prefixGet | prefixFlattening | + prefixGroupingId | prefixGrouping | prefixGroupId | prefixFunctionCall | prefixFunctionCallOneArg // function call must always be at the end // suffix/prefix composite diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/groupings.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/groupings.scala similarity index 97% rename from flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/groupings.scala rename to flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/groupings.scala index f7055226066f8..356b83b2e88be 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/groupings.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/groupings.scala @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.api.table.expressions +package org.apache.flink.table.expressions import org.apache.calcite.sql.fun.SqlStdOperatorTable import org.apache.calcite.tools.RelBuilder diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala index a272daf8ec81c..cea252649a178 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala @@ -279,7 +279,7 @@ class AggregationTest extends TableTestBase { val tEnv = TableEnvironment.getTableEnvironment(env, new TableConfig) val ds = CollectionDataSets.get3TupleDataSet(env) val table = tEnv.fromDataSet(ds) - val result = table.groupingSets("(_1, _2), (_1), ()").select("_1, _2") + val result = table.groupingSets("(_1, _2), (_1), ()").select("_1, _2, groupId") result.toDataSet[Row].print() } From 12336d53566121280001a2b3c5930561477da219 Mon Sep 17 00:00:00 2001 From: Aleksandr Chermenin Date: Mon, 12 Dec 2016 15:54:19 +0300 Subject: [PATCH 06/18] [FLINK-2980] Small fixes. --- .../src/main/scala/org/apache/flink/table/api/table.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala index 9268111c81d58..d2cf0057ad160 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala @@ -297,7 +297,7 @@ class Table( */ def cube(fields: Expression*): GroupingSetsTable = { val groups = fields.map { - case g: GroupedExpression => g.children + case g: GroupedExpression => g.flatChildren case x => Seq(x) } new GroupingSetsTable(this, groups, SqlKind.CUBE) @@ -330,7 +330,7 @@ class Table( */ def rollup(fields: Expression*): GroupingSetsTable = { val groups = fields.map { - case g: GroupedExpression => g.children + case g: GroupedExpression => g.flatChildren case x => Seq(x) } new GroupingSetsTable(this, groups, SqlKind.ROLLUP) From 0907054592415465666982e5baf72fe42439abac Mon Sep 17 00:00:00 2001 From: Aleksandr Chermenin Date: Mon, 12 Dec 2016 16:32:57 +0300 Subject: [PATCH 07/18] [FLINK-2980] Windowed table with grouping sets. --- .../org/apache/flink/table/api/table.scala | 105 ++++++++++----- .../table/expressions/ExpressionUtils.scala | 35 +++++ .../flink/table/plan/logical/operators.scala | 123 ++++++++++++++++-- 3 files changed, 224 insertions(+), 39 deletions(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala index d2cf0057ad160..5fd8cb26417f3 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala @@ -15,13 +15,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.flink.table.api import org.apache.calcite.rel.RelNode -import org.apache.calcite.rex.RexNode import org.apache.calcite.sql.SqlKind -import org.apache.calcite.sql.validate.SqlValidatorUtil -import org.apache.calcite.util.ImmutableBitSet +import org.apache.calcite.tools.RelBuilder import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.operators.join.JoinType import org.apache.flink.table.calcite.FlinkTypeFactory @@ -71,7 +70,7 @@ class Table( logicalPlan.output.map(_.name).toArray, logicalPlan.output.map(_.resultType).toArray) - def relBuilder = tableEnv.getRelBuilder + def relBuilder: RelBuilder = tableEnv.getRelBuilder def getRelNode: RelNode = logicalPlan.toRelNode(relBuilder) @@ -1033,7 +1032,6 @@ class GroupWindowedTable( val fieldExprs = ExpressionParser.parseExpressionList(fields) select(fieldExprs: _*) } - } /** @@ -1063,8 +1061,8 @@ class GroupingSetsTable( } val groupingSets = sqlKind match { - case SqlKind.CUBE => cube(groups) - case SqlKind.ROLLUP => rollup(groups) + case SqlKind.CUBE => ExpressionUtils.cube(groups) + case SqlKind.ROLLUP => ExpressionUtils.rollup(groups) case _ => groups } @@ -1074,8 +1072,7 @@ class GroupingSetsTable( GroupingAggregation( groupingSets, aggs, - table.logicalPlan, - sqlKind + table.logicalPlan ).validate(table.tableEnv) ).validate(table.tableEnv) @@ -1097,35 +1094,81 @@ class GroupingSetsTable( select(fieldExprs: _*) } - /** Computes the rollup of bit sets. + /** + * Groups the records of a table by assigning them to windows defined by a time or row interval. + * + * For streaming tables of infinite size, grouping into windows is required to define finite + * groups on which group-based aggregates can be computed. * - *

For example, rollup({0}, {1}) - * returns ({0, 1}, {0}, {}). + * For batch tables of finite size, windowing essentially provides shortcuts for time-based + * groupBy. * - *

Bit sets are not necessarily singletons: - * rollup({0, 2}, {3, 5}) - * returns ({0, 2, 3, 5}, {0, 2}, {}). */ - private def rollup(groups: Seq[Seq[Expression]]): Seq[Seq[Expression]] = { - val originalBitSet = for (i <- groups.indices) yield { - ImmutableBitSet.builder().set(i).build() + * @param groupWindow group-window that specifies how elements are grouped. + * @return A windowed table. + */ + def window(groupWindow: GroupWindow): GroupingSetsWindowedTable = { + if (table.tableEnv.isInstanceOf[BatchTableEnvironment]) { + throw new ValidationException(s"Windows on batch tables are currently not supported.") } - val rollupBitSets = SqlValidatorUtil.rollup(originalBitSet.asJava) - rollupBitSets.asScala.map(_.asScala.flatMap(i => groups(i)).toSeq) + new GroupingSetsWindowedTable(table, groups, sqlKind, groupWindow) } +} + +class GroupingSetsWindowedTable( + private[flink] val table: Table, + private[flink] val groups: Seq[Seq[Expression]], + private[flink] val sqlKind: SqlKind, + private[flink] val window: GroupWindow) { - /** Computes the cube of bit sets. + /** + * Performs a selection operation on a windowed table. Similar to an SQL SELECT statement. + * The field expressions can contain complex expressions and aggregations. * - *

For example, rollup({0}, {1}) - * returns ({0, 1}, {0}, {}). + * Example: * - *

Bit sets are not necessarily singletons: - * rollup({0, 2}, {3, 5}) - * returns ({0, 2, 3, 5}, {0, 2}, {}). */ - private def cube(groups: Seq[Seq[Expression]]): Seq[Seq[Expression]] = { - val originalBitSet = for (i <- groups.indices) yield { - ImmutableBitSet.builder().set(i).build() + * {{{ + * groupWindowTable.select('key, 'window.start, 'value.avg + " The average" as 'average) + * }}} + */ + def select(fields: Expression*): Table = { + + val (projection, aggs, props) = extractAggregationsAndProperties(fields, table.tableEnv) + + val groupWindow = window.toLogicalWindow + + val groupingSets = sqlKind match { + case SqlKind.CUBE => ExpressionUtils.cube(groups) + case SqlKind.ROLLUP => ExpressionUtils.rollup(groups) + case _ => groups } - val cubeBitSets = SqlValidatorUtil.cube(originalBitSet.asJava) - cubeBitSets.asScala.map(_.asScala.flatMap(i => groups(i)).toSeq) + + val logical = + Project( + projection, + GroupingWindowAggregate( + groupWindow, + groupingSets, + props, + aggs, + table.logicalPlan + ).validate(table.tableEnv) + ).validate(table.tableEnv) + + new Table(table.tableEnv, logical) + } + + /** + * Performs a selection operation on a group-windows table. Similar to an SQL SELECT statement. + * The field expressions can contain complex expressions and aggregations. + * + * Example: + * + * {{{ + * groupWindowTable.select("key, window.start, value.avg + ' The average' as average") + * }}} + */ + def select(fields: String): Table = { + val fieldExprs = ExpressionParser.parseExpressionList(fields) + select(fieldExprs: _*) } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionUtils.scala index 4b5781ff77000..0fcaf97b0fa10 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionUtils.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionUtils.scala @@ -26,10 +26,14 @@ import org.apache.calcite.avatica.util.TimeUnit import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rex.{RexBuilder, RexNode} import org.apache.calcite.sql.fun.SqlStdOperatorTable +import org.apache.calcite.sql.validate.SqlValidatorUtil +import org.apache.calcite.util.ImmutableBitSet import org.apache.flink.api.common.typeinfo.BasicTypeInfo import org.apache.flink.table.api.ValidationException import org.apache.flink.table.typeutils.{RowIntervalTypeInfo, TimeIntervalTypeInfo} +import scala.collection.JavaConverters._ + object ExpressionUtils { private[flink] def toMonthInterval(expr: Expression, multiplier: Int): Expression = expr match { @@ -151,4 +155,35 @@ object ExpressionUtils { rexBuilder.makeExactLiteral(value)) } + /** Computes the rollup of bit sets. + * + *

For example, rollup({0}, {1}) + * returns ({0, 1}, {0}, {}). + * + *

Bit sets are not necessarily singletons: + * rollup({0, 2}, {3, 5}) + * returns ({0, 2, 3, 5}, {0, 2}, {}). */ + private[flink] def rollup(groups: Seq[Seq[Expression]]): Seq[Seq[Expression]] = { + val originalBitSet = for (i <- groups.indices) yield { + ImmutableBitSet.builder().set(i).build() + } + val rollupBitSets = SqlValidatorUtil.rollup(originalBitSet.asJava) + rollupBitSets.asScala.map(_.asScala.flatMap(i => groups(i)).toSeq) + } + + /** Computes the cube of bit sets. + * + *

For example, rollup({0}, {1}) + * returns ({0, 1}, {0}, {}). + * + *

Bit sets are not necessarily singletons: + * rollup({0, 2}, {3, 5}) + * returns ({0, 2, 3, 5}, {0, 2}, {}). */ + private[flink] def cube(groups: Seq[Seq[Expression]]): Seq[Seq[Expression]] = { + val originalBitSet = for (i <- groups.indices) yield { + ImmutableBitSet.builder().set(i).build() + } + val cubeBitSets = SqlValidatorUtil.cube(originalBitSet.asJava) + cubeBitSets.asScala.map(_.asScala.flatMap(i => groups(i)).toSeq) + } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala index 755221c298cec..09b5f29ed7ae6 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala @@ -20,16 +20,12 @@ package org.apache.flink.table.plan.logical import java.lang.reflect.Method import java.util -import com.google.common.collect.{ImmutableList, Sets} import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.CorrelationId import org.apache.calcite.rel.logical.{LogicalProject, LogicalTableFunctionScan} import org.apache.calcite.rex.{RexInputRef, RexNode} -import org.apache.calcite.sql.SqlKind -import org.apache.calcite.sql.validate.SqlValidatorUtil import org.apache.calcite.tools.RelBuilder -import org.apache.calcite.util.{ImmutableBitSet, ImmutableIntList} import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.operators.join.JoinType @@ -45,7 +41,6 @@ import org.apache.flink.table.typeutils.TypeConverter import org.apache.flink.table.validate.{ValidationFailure, ValidationSuccess} import scala.collection.JavaConverters._ -import scala.collection.immutable.BitSet import scala.collection.mutable case class Project(projectList: Seq[NamedExpression], child: LogicalNode) extends UnaryNode { @@ -268,11 +263,11 @@ case class Aggregate( resolvedAggregate } } + case class GroupingAggregation( groupingExpressions: Seq[Seq[Expression]], aggregateExpressions: Seq[NamedExpression], - child: LogicalNode, - sqlKind: SqlKind + child: LogicalNode ) extends UnaryNode { override def output: Seq[Attribute] = { @@ -339,7 +334,7 @@ case class GroupingAggregation( "because it's not a valid key type which must be hashable and comparable") } } - GroupingAggregation(resolvedGroupingExprs, resolvedAggregate.aggregateExpressions, child, sqlKind) + GroupingAggregation(resolvedGroupingExprs, resolvedAggregate.aggregateExpressions, child) } } @@ -695,6 +690,118 @@ case class WindowAggregate( } } +case class GroupingWindowAggregate( + window: LogicalWindow, + groupingExpressions: Seq[Seq[Expression]], + propertyExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[NamedExpression], + child: LogicalNode + ) extends UnaryNode { + + override def output: Seq[Attribute] = { + (groupingExpressions.flatten.distinct ++ aggregateExpressions ++ propertyExpressions) map { + case ne: NamedExpression => ne.toAttribute + case e => Alias(e, e.toString).toAttribute + } + } + + // resolve references of this operator's parameters + override def resolveReference( + tableEnv: TableEnvironment, + name: String) + : Option[NamedExpression] = tableEnv match { + // resolve reference to rowtime attribute in a streaming environment + case _: StreamTableEnvironment if name == "rowtime" => + Some(RowtimeAttribute()) + case _ => + window.alias match { + // resolve reference to this window's alias + case Some(UnresolvedFieldReference(alias)) if name == alias => + // check if reference can already be resolved by input fields + val found = super.resolveReference(tableEnv, name) + if (found.isDefined) { + failValidation(s"Reference $name is ambiguous.") + } else { + Some(WindowReference(name)) + } + case _ => + // resolve references as usual + super.resolveReference(tableEnv, name) + } + } + + override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = { + val flinkRelBuilder = relBuilder.asInstanceOf[FlinkRelBuilder] + child.construct(flinkRelBuilder) + val groupingSets = groupingExpressions.map(_.map(_.toRexNode(relBuilder)).toList).toList + flinkRelBuilder.aggregate( + window, + relBuilder.groupKey( + groupingExpressions.head.map(_.toRexNode(relBuilder)).asJava, + true, groupingSets.map(_.asJava).asJava + ), + propertyExpressions.map { + case Alias(prop: WindowProperty, name, _) => prop.toNamedWindowProperty(name)(relBuilder) + case _ => throw new RuntimeException("This should never happen.") + }, + aggregateExpressions.map { + case Alias(agg: Aggregation, name, _) => agg.toAggCall(name)(relBuilder) + case _ => throw new RuntimeException("This should never happen.") + }.asJava) + } + + override def validate(tableEnv: TableEnvironment): LogicalNode = { + val resolvedWindowAggregate = super.validate(tableEnv).asInstanceOf[GroupingWindowAggregate] + val groupingExprs = resolvedWindowAggregate.groupingExpressions + val aggregateExprs = resolvedWindowAggregate.aggregateExpressions + val resolvedGroupingExprs = groupingExprs.map(_.map { + case u @ UnresolvedFieldReference(name) => + resolveReference(tableEnv, name).getOrElse(u) + case x => x + }) + resolvedGroupingExprs.flatten.foreach(validateGroupingExpression) + aggregateExprs.foreach(validateAggregateExpression) + + def validateAggregateExpression(expr: Expression): Unit = expr match { + // check no nested aggregation exists. + case aggExpr: Aggregation => + aggExpr.children.foreach { child => + child.preOrderVisit { + case agg: Aggregation => + failValidation( + "It's not allowed to use an aggregate function as " + + "input of another aggregate function") + case _ => // ok + } + } + case a: Attribute if !groupingExprs.flatten.exists(_.checkEquals(a)) => + failValidation( + s"Expression '$a' is invalid because it is neither" + + " present in group by nor an aggregate function") + case e if groupingExprs.flatten.exists(_.checkEquals(e)) => // ok + case e => e.children.foreach(validateAggregateExpression) + } + + def validateGroupingExpression(expr: Expression): Unit = { + if (!expr.resultType.isKeyType) { + failValidation( + s"Expression $expr cannot be used as a grouping expression " + + "because it's not a valid key type which must be hashable and comparable") + } + } + + // validate window + resolvedWindowAggregate.window.validate(tableEnv) match { + case ValidationFailure(msg) => + failValidation(s"$window is invalid: $msg") + case ValidationSuccess => // ok + } + + GroupingWindowAggregate(window, resolvedGroupingExprs, propertyExpressions, + resolvedWindowAggregate.aggregateExpressions, child) + } +} + /** * LogicalNode for calling a user-defined table functions. * From 24b0e7a8d89ccfe06de89c281ca849c9f62b7f1e Mon Sep 17 00:00:00 2001 From: Aleksandr Chermenin Date: Thu, 15 Dec 2016 13:13:45 +0300 Subject: [PATCH 08/18] [FLINK-2980] Added tests. --- .../scala/batch/table/GroupingSetsTest.scala | 140 ++++++++++++++++++ .../apache/flink/table/AggregationTest.scala | 60 -------- .../apache/flink/test/util/TestBaseUtils.java | 16 +- 3 files changed, 153 insertions(+), 63 deletions(-) create mode 100644 flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/GroupingSetsTest.scala diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/GroupingSetsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/GroupingSetsTest.scala new file mode 100644 index 0000000000000..5db8668133d93 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/GroupingSetsTest.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.scala.batch.table + + +import org.apache.flink.api.scala._ +import org.apache.flink.api.scala.table._ +import org.apache.flink.api.scala.util.CollectionDataSets +import org.apache.flink.api.table.{Row, Table, TableConfig, TableEnvironment} +import org.apache.flink.test.util.TestBaseUtils +import org.junit._ + +import scala.collection.JavaConverters._ + +class GroupingSetsTest { + + private var tableEnv: BatchTableEnvironment = _ + private var table: Table = _ + private var tableWithNulls: Table = _ + + @Before + def setup(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + tableEnv = TableEnvironment.getTableEnvironment(env, new TableConfig()) + + val dataSet = CollectionDataSets.get3TupleDataSet(env) + table = dataSet.toTable(tableEnv, 'a, 'b, 'c) + + val dataSetWithNulls = dataSet.map(value => value match { + case (x, y, s) => (x, y, if (s.toLowerCase().contains("world")) null else s) + }) + tableWithNulls = dataSetWithNulls.toTable(tableEnv, 'a, 'b, 'c) + } + + @Test + def testGroupingSets() = { + val t = table + .groupingSets('b, 'c) + .select('b, 'c, 'a.avg as 'a, groupId() as 'g) + + val expected = + "6,null,18,1\n5,null,13,1\n4,null,8,1\n3,null,5,1\n2,null,2,1\n1,null,1,1\n" + + "null,Luke Skywalker,6,2\nnull,I am fine.,5,2\nnull,Hi,1,2\n" + + "null,Hello world, how are you?,4,2\nnull,Hello world,3,2\nnull,Hello,2,2\n" + + "null,Comment#9,15,2\nnull,Comment#8,14,2\nnull,Comment#7,13,2\n" + + "null,Comment#6,12,2\nnull,Comment#5,11,2\nnull,Comment#4,10,2\n" + + "null,Comment#3,9,2\nnull,Comment#2,8,2\nnull,Comment#15,21,2\n" + + "null,Comment#14,20,2\nnull,Comment#13,19,2\nnull,Comment#12,18,2\n" + + "null,Comment#11,17,2\nnull,Comment#10,16,2\nnull,Comment#1,7,2" + + val results = t.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testGroupingSetsWithNulls() = { + val t = tableWithNulls + .groupingSets('b, 'c) + .select('b, 'c, 'a.avg as 'a, groupId() as 'g) + + val expected = + "6,null,18,1\n5,null,13,1\n4,null,8,1\n3,null,5,1\n2,null,2,1\n1,null,1,1\n" + + "null,Luke Skywalker,6,2\nnull,I am fine.,5,2\nnull,Hi,1,2\n" + + "null,null,3,2\nnull,Hello,2,2\nnull,Comment#9,15,2\nnull,Comment#8,14,2\n" + + "null,Comment#7,13,2\nnull,Comment#6,12,2\nnull,Comment#5,11,2\n" + + "null,Comment#4,10,2\nnull,Comment#3,9,2\nnull,Comment#2,8,2\n" + + "null,Comment#15,21,2\nnull,Comment#14,20,2\nnull,Comment#13,19,2\n" + + "null,Comment#12,18,2\nnull,Comment#11,17,2\nnull,Comment#10,16,2\n" + + "null,Comment#1,7,2" + + val results = t.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testCubeAsGroupingSets() = { + val t1 = table + .cube('b, 'c) + .select( + 'b, 'c, 'a.avg as 'a, groupId() as 'g, + 'b.grouping() as 'gb, grouping('c) as 'gc, + 'b.groupingId() as 'gib, groupingId('c) as 'gic, + ('b, 'c).groupingId() as 'gid + ) + + val t2 = table + .groupingSets(('b, 'c), 'b, 'c, ()) + .select( + 'b, 'c, 'a.avg as 'a, groupId() as 'g, + grouping('b) as 'gb, 'c.grouping() as 'gc, + groupingId('b) as 'gib, 'c.groupingId() as 'gic, + groupingId('b, 'c) as 'gid + ) + + val results1 = t1.toDataSet[Row].map(_.toString).collect() + val results2 = t2.toDataSet[Row].map(_.toString).collect() + TestBaseUtils.compareResultCollections(results1.asJava, results2.asJava) + } + + @Test + def testRollupAsGroupingSets() = { + val t1 = table + .rollup('b, 'c) + .select( + 'b, 'c, 'a.avg as 'a, groupId() as 'g, + 'b.grouping() as 'gb, grouping('c) as 'gc, + 'b.groupingId() as 'gib, groupingId('c) as 'gic, + ('b, 'c).groupingId() as 'gid + ) + + val t2 = table + .groupingSets(('b, 'c), 'b, ()) + .select( + 'b, 'c, 'a.avg as 'a, groupId() as 'g, + grouping('b) as 'gb, 'c.grouping() as 'gc, + groupingId('b) as 'gib, 'c.groupingId() as 'gic, + groupingId('b, 'c) as 'gid + ) + + val results1 = t1.toDataSet[Row].map(_.toString).collect() + val results2 = t2.toDataSet[Row].map(_.toString).collect() + TestBaseUtils.compareResultCollections(results1.asJava, results2.asJava) + } +} diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala index cea252649a178..ce74d00d8d13c 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala @@ -262,64 +262,4 @@ class AggregationTest extends TableTestBase { util.verifyTable(resultTable, expected) } - - @Test - def testGroupingSetsTableApi(): Unit = { - val env = ExecutionEnvironment.getExecutionEnvironment - val tEnv = TableEnvironment.getTableEnvironment(env, new TableConfig) - val ds = CollectionDataSets.get3TupleDataSet(env) - val table = tEnv.fromDataSet(ds, 'a, 'b, 'c) - val result = table.groupingSets(('a, 'b), 'b, ()).select('a, 'b, groupId() as 'g) - result.toDataSet[Row].print() - } - - @Test - def testStringGroupingSetsTableApi(): Unit = { - val env = ExecutionEnvironment.getExecutionEnvironment - val tEnv = TableEnvironment.getTableEnvironment(env, new TableConfig) - val ds = CollectionDataSets.get3TupleDataSet(env) - val table = tEnv.fromDataSet(ds) - val result = table.groupingSets("(_1, _2), (_1), ()").select("_1, _2, groupId") - result.toDataSet[Row].print() - } - - @Test - def testCubeTableApi(): Unit = { - val env = ExecutionEnvironment.getExecutionEnvironment - val tEnv = TableEnvironment.getTableEnvironment(env, new TableConfig) - val ds = CollectionDataSets.get3TupleDataSet(env) - val table = tEnv.fromDataSet(ds, 'a, 'b, 'c) - val result = table.cube('a, 'b).select('a, 'b) - result.toDataSet[Row].print() - } - - @Test - def testStringCubeTableApi(): Unit = { - val env = ExecutionEnvironment.getExecutionEnvironment - val tEnv = TableEnvironment.getTableEnvironment(env, new TableConfig) - val ds = CollectionDataSets.get3TupleDataSet(env) - val table = tEnv.fromDataSet(ds) - val result = table.cube("(_1, _3), _2").select("_1, _2") - result.toDataSet[Row].print() - } - - @Test - def testRollupTableApi(): Unit = { - val env = ExecutionEnvironment.getExecutionEnvironment - val tEnv = TableEnvironment.getTableEnvironment(env, new TableConfig) - val ds = CollectionDataSets.get3TupleDataSet(env) - val table = tEnv.fromDataSet(ds, 'a, 'b, 'c) - val result = table.rollup('a, 'b).select('a, 'b) - result.toDataSet[Row].print() - } - - @Test - def testStringRollupTableApi(): Unit = { - val env = ExecutionEnvironment.getExecutionEnvironment - val tEnv = TableEnvironment.getTableEnvironment(env, new TableConfig) - val ds = CollectionDataSets.get3TupleDataSet(env) - val table = tEnv.fromDataSet(ds) - val result = table.rollup("(_1, _3), _2").select("_1, _2") - result.toDataSet[Row].print() - } } diff --git a/flink-test-utils-parent/flink-test-utils/src/main/java/org/apache/flink/test/util/TestBaseUtils.java b/flink-test-utils-parent/flink-test-utils/src/main/java/org/apache/flink/test/util/TestBaseUtils.java index 84312263ccd23..53efc01d49881 100644 --- a/flink-test-utils-parent/flink-test-utils/src/main/java/org/apache/flink/test/util/TestBaseUtils.java +++ b/flink-test-utils-parent/flink-test-utils/src/main/java/org/apache/flink/test/util/TestBaseUtils.java @@ -322,8 +322,8 @@ public static void compareResultsByLinesInMemory( String msg = String.format( "Different elements in arrays: expected %d elements and received %d\n" + "files: %s\n expected: %s\n received: %s", - expected.length, result.length, - Arrays.toString(getAllInvolvedFiles(resultPath, excludePrefixes)), + expected.length, result.length, + Arrays.toString(getAllInvolvedFiles(resultPath, excludePrefixes)), Arrays.toString(expected), Arrays.toString(result)); fail(msg); } @@ -408,10 +408,20 @@ public static void compareResultCollections(List expected, List actual } } + public static > void compareResultCollections(List expected, List actual) { + Assert.assertEquals(expected.size(), actual.size()); + + Collections.sort(expected); + Collections.sort(actual); + + for (int i = 0; i < expected.size(); i++) { + Assert.assertEquals(expected.get(i), actual.get(i)); + } + } + private static File[] getAllInvolvedFiles(String resultPath, final String[] excludePrefixes) { final File result = asFile(resultPath); assertTrue("Result file was not written", result.exists()); - if (result.isDirectory()) { return result.listFiles(new FilenameFilter() { From f55f47f415c37082311ed6eb3bcedd4ba06f4b1a Mon Sep 17 00:00:00 2001 From: Aleksandr Chermenin Date: Thu, 15 Dec 2016 13:27:44 +0300 Subject: [PATCH 09/18] [FLINK-2980] Small fixes. --- .../scala/org/apache/flink/table/plan/logical/operators.scala | 2 +- .../apache/flink/api/scala/batch/table/GroupingSetsTest.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala index 09b5f29ed7ae6..e3800fa3882ac 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala @@ -282,7 +282,7 @@ case class GroupingAggregation( val groupingSets = groupingExpressions.map(_.map(_.toRexNode(relBuilder)).toList).toList relBuilder.aggregate( relBuilder.groupKey( - groupingExpressions.head.map(_.toRexNode(relBuilder)).asJava, + groupingExpressions.flatten.distinct.map(_.toRexNode(relBuilder)).asJava, true, groupingSets.map(_.asJava).asJava ), aggregateExpressions.map { diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/GroupingSetsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/GroupingSetsTest.scala index 5db8668133d93..8cdd6f5d2b2e2 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/GroupingSetsTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/GroupingSetsTest.scala @@ -43,7 +43,7 @@ class GroupingSetsTest { table = dataSet.toTable(tableEnv, 'a, 'b, 'c) val dataSetWithNulls = dataSet.map(value => value match { - case (x, y, s) => (x, y, if (s.toLowerCase().contains("world")) null else s) + case (x, y, s: String) => (x, y, if (s.toLowerCase().contains("world")) null else s) }) tableWithNulls = dataSetWithNulls.toTable(tableEnv, 'a, 'b, 'c) } From 4507405342eb49466fd3b1babb1519414f02d662 Mon Sep 17 00:00:00 2001 From: Aleksandr Chermenin Date: Fri, 16 Dec 2016 11:12:07 +0300 Subject: [PATCH 10/18] [FLINK-2980] Improved documentation. --- docs/dev/table_api.md | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/docs/dev/table_api.md b/docs/dev/table_api.md index acabfcf300cf8..8c4bfc40064fb 100644 --- a/docs/dev/table_api.md +++ b/docs/dev/table_api.md @@ -780,7 +780,7 @@ val result = in.where('b === "red"); GroupBy -

Similar to a SQL GROUPBY clause. Groups rows on the grouping keys, with a following aggregation +

Similar to a SQL GROUP BY clause. Groups rows on the grouping keys, with a following aggregation operator to aggregate rows group-wise.

{% highlight scala %} val in = ds.toTable(tableEnv, 'a, 'b, 'c); @@ -789,6 +789,40 @@ val result = in.groupBy('a).select('a, 'b.sum as 'd); + + GroupingSets + +

Similar to a SQL GROUP BY GROUPING SETS clause. A GROUPING SETS expression allows to selectively specify the + set of groups that you want to create within a GROUP BY clause.

+{% highlight scala %} +val in = ds.toTable(tableEnv, 'a, 'b, 'c); +val result = in.groupingSets(('a, 'b), 'b, ()).select('a, 'b, 'c.sum as 'd); +{% endhighlight %} + + + + + Cube + +

Similar to a SQL GROUP BY CUBE clause. A CUBE expression will generate subtotals for all combinations of the dimensions specified.

+{% highlight scala %} +val in = ds.toTable(tableEnv, 'a, 'b, 'c); +val result = in.cube('a, 'b).select('a, 'b, 'c.sum as 'd); +{% endhighlight %} + + + + + Cube + +

Similar to a SQL GROUP BY ROLLUP clause. A ROLLUP expression produces group subtotals from right to left and a grand total.

+{% highlight scala %} +val in = ds.toTable(tableEnv, 'a, 'b, 'c); +val result = in.rollup('a, 'b).select('a, 'b, 'c.sum as 'd); +{% endhighlight %} + + + Join From 97d621b5d09b4a18af2ab73a27c691dbcc764fa0 Mon Sep 17 00:00:00 2001 From: Aleksandr Chermenin Date: Fri, 16 Dec 2016 12:42:30 +0300 Subject: [PATCH 11/18] [FLINK-2980] Small docs improvements. --- docs/dev/table_api.md | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/docs/dev/table_api.md b/docs/dev/table_api.md index 8c4bfc40064fb..b6e9692a9d0e5 100644 --- a/docs/dev/table_api.md +++ b/docs/dev/table_api.md @@ -2748,6 +2748,42 @@ ARRAY.element() + + + {% highlight scala %} +groupId() +{% endhighlight %} + + +

Returns an integer that uniquely identifies the combination of grouping keys.

+ + + + + + {% highlight scala %} +ANY.grouping() +grouping(ANY) +{% endhighlight %} + + +

Returns 1 if expression is rolled up in the current row’s grouping set, 0 otherwise.

+ + + + + + {% highlight scala %} +ANY.groupingId() +(ANY [, ANY ]*).groupingId() +groupingId(ANY [, ANY ]*) +{% endhighlight %} + + +

Returns a bit vector of the given grouping expressions.

+ + + From b3604aac41150d5486860c428c120fdacae932ce Mon Sep 17 00:00:00 2001 From: Aleksandr Chermenin Date: Wed, 21 Dec 2016 13:29:13 +0300 Subject: [PATCH 12/18] [FLINK-2980] Some fixes after rebase. --- .../org/apache/flink/table/api/table.scala | 50 +++++++++++-------- .../flink/table/plan/logical/operators.scala | 4 +- .../scala/batch/table/GroupingSetsTest.scala | 6 +-- .../apache/flink/table/AggregationTest.scala | 4 -- 4 files changed, 34 insertions(+), 30 deletions(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala index 5fd8cb26417f3..08207fbe554ad 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala @@ -24,8 +24,8 @@ import org.apache.calcite.tools.RelBuilder import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.operators.join.JoinType import org.apache.flink.table.calcite.FlinkTypeFactory -import org.apache.flink.table.plan.logical.Minus -import org.apache.flink.table.expressions.{Alias, Asc, Call, Expression, ExpressionParser, GroupedExpression, Ordering, TableFunctionCall, UnresolvedAlias} +import org.apache.flink.table.plan.logical.{Minus => MinusNode} +import org.apache.flink.table.expressions._ import org.apache.flink.table.plan.ProjectionTranslator._ import org.apache.flink.table.plan.logical._ import org.apache.flink.table.sinks.TableSink @@ -550,7 +550,7 @@ class Table( throw new ValidationException("Only tables from the same TableEnvironment can be " + "subtracted.") } - new Table(tableEnv, Minus(logicalPlan, right.logicalPlan, all = false) + new Table(tableEnv, MinusNode(logicalPlan, right.logicalPlan, all = false) .validate(tableEnv)) } @@ -575,7 +575,7 @@ class Table( throw new ValidationException("Only tables from the same TableEnvironment can be " + "subtracted.") } - new Table(tableEnv, Minus(logicalPlan, right.logicalPlan, all = true) + new Table(tableEnv, MinusNode(logicalPlan, right.logicalPlan, all = true) .validate(tableEnv)) } @@ -1054,9 +1054,9 @@ class GroupingSetsTable( */ def select(fields: Expression*): Table = { - val (projection, aggs, props) = extractAggregationsAndProperties(fields, table.tableEnv) + val (aggNames, propNames) = extractAggregationsAndProperties(fields, table.tableEnv) - if (props.nonEmpty) { + if (propNames.nonEmpty) { throw ValidationException("Window properties can only be used on windowed tables.") } @@ -1066,13 +1066,14 @@ class GroupingSetsTable( case _ => groups } + val projectsOnAgg = replaceAggregationsAndProperties( + fields, table.tableEnv, aggNames, propNames) + val projectFields = extractFieldReferences(fields ++ groupingSets.flatten.distinct) + val logical = - Project( - projection, - GroupingAggregation( - groupingSets, - aggs, - table.logicalPlan + Project(projectsOnAgg, + GroupingAggregation(groupingSets, aggNames.map(a => Alias(a._1, a._2)).toSeq, + Project(projectFields, table.logicalPlan).validate(table.tableEnv) ).validate(table.tableEnv) ).validate(table.tableEnv) @@ -1132,9 +1133,9 @@ class GroupingSetsWindowedTable( */ def select(fields: Expression*): Table = { - val (projection, aggs, props) = extractAggregationsAndProperties(fields, table.tableEnv) - - val groupWindow = window.toLogicalWindow + val (aggNames, propNames) = extractAggregationsAndProperties(fields, table.tableEnv) + val projectsOnAgg = replaceAggregationsAndProperties( + fields, table.tableEnv, aggNames, propNames) val groupingSets = sqlKind match { case SqlKind.CUBE => ExpressionUtils.cube(groups) @@ -1142,15 +1143,22 @@ class GroupingSetsWindowedTable( case _ => groups } + val projectFields = (table.tableEnv, window) match { + // event time can be arbitrary field in batch environment + case (_: BatchTableEnvironment, w: EventTimeWindow) => + extractFieldReferences(fields ++ groupingSets.flatten.distinct ++ Seq(w.timeField)) + case (_, _) => + extractFieldReferences(fields ++ groupingSets.flatten.distinct) + } + val logical = - Project( - projection, + Project(projectsOnAgg, GroupingWindowAggregate( - groupWindow, groupingSets, - props, - aggs, - table.logicalPlan + window.toLogicalWindow, + propNames.map(a => Alias(a._1, a._2)).toSeq, + aggNames.map(a => Alias(a._1, a._2)).toSeq, + Project(projectFields, table.logicalPlan).validate(table.tableEnv) ).validate(table.tableEnv) ).validate(table.tableEnv) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala index e3800fa3882ac..fddfd5674e7c6 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala @@ -691,8 +691,8 @@ case class WindowAggregate( } case class GroupingWindowAggregate( - window: LogicalWindow, groupingExpressions: Seq[Seq[Expression]], + window: LogicalWindow, propertyExpressions: Seq[NamedExpression], aggregateExpressions: Seq[NamedExpression], child: LogicalNode @@ -797,7 +797,7 @@ case class GroupingWindowAggregate( case ValidationSuccess => // ok } - GroupingWindowAggregate(window, resolvedGroupingExprs, propertyExpressions, + GroupingWindowAggregate(resolvedGroupingExprs, window, propertyExpressions, resolvedWindowAggregate.aggregateExpressions, child) } } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/GroupingSetsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/GroupingSetsTest.scala index 8cdd6f5d2b2e2..4f2bd88b22b75 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/GroupingSetsTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/GroupingSetsTest.scala @@ -18,12 +18,12 @@ package org.apache.flink.api.scala.batch.table - import org.apache.flink.api.scala._ -import org.apache.flink.api.scala.table._ import org.apache.flink.api.scala.util.CollectionDataSets -import org.apache.flink.api.table.{Row, Table, TableConfig, TableEnvironment} +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.api.{Table, TableConfig, TableEnvironment} import org.apache.flink.test.util.TestBaseUtils +import org.apache.flink.types.Row import org.junit._ import scala.collection.JavaConverters._ diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala index ce74d00d8d13c..708e00766a0e7 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala @@ -19,12 +19,8 @@ package org.apache.flink.table import org.apache.flink.api.scala._ import org.apache.flink.table.api.scala._ -import org.apache.flink.api.scala.util.CollectionDataSets -import org.apache.flink.table.api.{TableConfig, TableEnvironment} -import org.apache.flink.table.expressions.{Expression, GroupedExpression} import org.apache.flink.table.utils.TableTestBase import org.apache.flink.table.utils.TableTestUtil._ -import org.apache.flink.types.Row import org.junit.Test /** From b7e1150f152c67d11ebb3f9496bb206b09ad929b Mon Sep 17 00:00:00 2001 From: Aleksandr Chermenin Date: Thu, 19 Jan 2017 12:15:35 +0300 Subject: [PATCH 13/18] [FLINK-2980] Fixed grouping sets for Table API. --- .../flink/table/expressions/groupings.scala | 94 +++++++++++++++++-- .../flink/table/plan/logical/operators.scala | 22 ++++- .../scala/batch/table/GroupingSetsTest.scala | 28 ++++-- 3 files changed, 122 insertions(+), 22 deletions(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/groupings.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/groupings.scala index 356b83b2e88be..d61dca39690e4 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/groupings.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/groupings.scala @@ -18,13 +18,88 @@ package org.apache.flink.table.expressions -import org.apache.calcite.sql.fun.SqlStdOperatorTable import org.apache.calcite.tools.RelBuilder import org.apache.flink.api.common.typeinfo.BasicTypeInfo abstract sealed class GroupFunction extends Expression { override def toString = s"GroupFunction($children)" + + private[flink] def replaceExpression( + relBuilder: RelBuilder, + groupExpressions: Option[Seq[Expression]], + children: Seq[Attribute] = Seq(), + indicator: Boolean = false): Expression = { + + if (groupExpressions.isDefined) { + val expressions = groupExpressions.get + if (!indicator) { + Cast( + Minus(Power(Literal(2), Literal(getEffectiveArgCount(expressions))), Literal(1)), + BasicTypeInfo.LONG_TYPE_INFO + ) + } else { + val operands = getOperands(expressions) + val internalFieldsMap = getInternalFields(children) + var shift = operands.size + var expression: Option[Expression] = None + operands.foreach(x => { + shift -= 1 + expression = bitValue(relBuilder, expression, x, shift, expressions, internalFieldsMap) + }) + Cast(expression.get, BasicTypeInfo.LONG_TYPE_INFO) + } + } else { + this + } + } + + private def getInternalFields(children: Seq[Attribute]) = { + val inputFields = children.map(_.name) + inputFields.map(inputFieldName => { + val base = "i$" + inputFieldName + var name = base + var i = 0 + while (inputFields.contains(name)) { + name = base + "_" + i // if i$XXX is already a field it will be suffixed by _NUMBER + i = i + 1 + } + inputFieldName -> name + }).toMap + } + + private def bitValue(relBuilder: RelBuilder, + expression: Option[Expression], operand: Int, + shift: Int, expressions: Seq[Expression], + internalFieldsMap: Map[String, String] + ): Option[Expression] = { + + val fieldName = expressions(operand) match { + case ne: NamedExpression => ne.name + case _ => "" + } + + var nextExpression: Expression = + If(IsTrue(ResolvedFieldReference( + internalFieldsMap(fieldName), BasicTypeInfo.BOOLEAN_TYPE_INFO)), + Literal(1), Literal(0)) + + if (shift > 0) { + nextExpression = Mul(nextExpression, Power(Literal(2), Literal(shift))) + } + + if (expression.isDefined) { + nextExpression = Plus(expression.get, nextExpression) + } + + Some(nextExpression) + } + + protected def getEffectiveArgCount(groupExpressions: Seq[Expression]): Int + + protected def getOperands(groupExpressions: Seq[Expression]): Seq[Int] = { + children.map(e => groupExpressions.indexOf(e)) + } } case class GroupId() extends GroupFunction { @@ -33,32 +108,31 @@ case class GroupId() extends GroupFunction { override private[flink] def children = Nil - override private[flink] def toRexNode(implicit relBuilder: RelBuilder) = { - relBuilder.call(SqlStdOperatorTable.GROUP_ID) + override protected def getEffectiveArgCount(groupExpressions: Seq[Expression]): Int = { + groupExpressions.size } + + override protected def getOperands(groupExpressions: Seq[Expression]): Seq[Int] = + groupExpressions.indices } case class Grouping(expression: Expression) extends GroupFunction { override private[flink] def resultType = BasicTypeInfo.LONG_TYPE_INFO - override private[flink] def children = Seq(expression) - override private[flink] def toRexNode(implicit relBuilder: RelBuilder) = { - relBuilder.call(SqlStdOperatorTable.GROUPING, expression.toRexNode) - } + override protected def getEffectiveArgCount(groupExpressions: Seq[Expression]): Int = 1 } case class GroupingId(expressions: Expression*) extends GroupFunction { override private[flink] def resultType = BasicTypeInfo.LONG_TYPE_INFO - override private[flink] def children = expressions - override private[flink] def toRexNode(implicit relBuilder: RelBuilder) = { - relBuilder.call(SqlStdOperatorTable.GROUPING_ID, expressions.map(_.toRexNode): _*) + override protected def getEffectiveArgCount(groupExpressions: Seq[Expression]): Int = { + expressions.size } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala index fddfd5674e7c6..5999e4f69a191 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala @@ -94,7 +94,24 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalNode) extend override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = { child.construct(relBuilder) relBuilder.project( - projectList.map(_.toRexNode(relBuilder)).asJava, + projectList.map { + case Alias(gf: GroupFunction, name, extraNames) => + children.head match { + + case GroupingAggregation(grpExps, _, node) => + Alias(gf.replaceExpression(relBuilder, + Some(grpExps.flatten.distinct), node.output, indicator = true), + name, extraNames).toRexNode(relBuilder) + + case Aggregate(grpExps, _, node) => + Alias(gf.replaceExpression(relBuilder, Some(grpExps), node.output), + name, extraNames).toRexNode(relBuilder) + + case _ => + Alias(gf.replaceExpression(relBuilder, None), name, extraNames).toRexNode(relBuilder) + } + case x => x.toRexNode(relBuilder) + }.asJava, projectList.map(_.name).asJava, true) } @@ -279,11 +296,10 @@ case class GroupingAggregation( override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = { child.construct(relBuilder) - val groupingSets = groupingExpressions.map(_.map(_.toRexNode(relBuilder)).toList).toList relBuilder.aggregate( relBuilder.groupKey( groupingExpressions.flatten.distinct.map(_.toRexNode(relBuilder)).asJava, - true, groupingSets.map(_.asJava).asJava + true, groupingExpressions.map(_.map(_.toRexNode(relBuilder)).asJava).asJava ), aggregateExpressions.map { case Alias(agg: Aggregation, name, _) => agg.toAggCall(name)(relBuilder) diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/GroupingSetsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/GroupingSetsTest.scala index 4f2bd88b22b75..ac549335cd77a 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/GroupingSetsTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/GroupingSetsTest.scala @@ -52,17 +52,27 @@ class GroupingSetsTest { def testGroupingSets() = { val t = table .groupingSets('b, 'c) - .select('b, 'c, 'a.avg as 'a, groupId() as 'g) + .select( + 'b, 'c, 'a.avg as 'a, groupId() as 'g, + 'b.grouping() as 'gb, grouping('c) as 'gc, + 'b.groupingId() as 'gib, groupingId('c) as 'gic, + ('b, 'c).groupingId() as 'gid + ) val expected = - "6,null,18,1\n5,null,13,1\n4,null,8,1\n3,null,5,1\n2,null,2,1\n1,null,1,1\n" + - "null,Luke Skywalker,6,2\nnull,I am fine.,5,2\nnull,Hi,1,2\n" + - "null,Hello world, how are you?,4,2\nnull,Hello world,3,2\nnull,Hello,2,2\n" + - "null,Comment#9,15,2\nnull,Comment#8,14,2\nnull,Comment#7,13,2\n" + - "null,Comment#6,12,2\nnull,Comment#5,11,2\nnull,Comment#4,10,2\n" + - "null,Comment#3,9,2\nnull,Comment#2,8,2\nnull,Comment#15,21,2\n" + - "null,Comment#14,20,2\nnull,Comment#13,19,2\nnull,Comment#12,18,2\n" + - "null,Comment#11,17,2\nnull,Comment#10,16,2\nnull,Comment#1,7,2" + "1,null,1,1,0,1,0,1,1\n" + "6,null,18,1,0,1,0,1,1\n" + "2,null,2,1,0,1,0,1,1\n" + + "4,null,8,1,0,1,0,1,1\n" + "5,null,13,1,0,1,0,1,1\n" + "3,null,5,1,0,1,0,1,1\n" + + "null,Comment#11,17,2,1,0,1,0,2\n" + "null,Comment#8,14,2,1,0,1,0,2\n" + + "null,Comment#2,8,2,1,0,1,0,2\n" + "null,Comment#1,7,2,1,0,1,0,2\n" + + "null,Comment#14,20,2,1,0,1,0,2\n" + "null,Comment#7,13,2,1,0,1,0,2\n" + + "null,Comment#6,12,2,1,0,1,0,2\n" + "null,Comment#3,9,2,1,0,1,0,2\n" + + "null,Comment#12,18,2,1,0,1,0,2\n" + "null,Comment#5,11,2,1,0,1,0,2\n" + + "null,Comment#15,21,2,1,0,1,0,2\n" + "null,Comment#4,10,2,1,0,1,0,2\n" + + "null,Hi,1,2,1,0,1,0,2\n" + "null,Comment#10,16,2,1,0,1,0,2\n" + + "null,Hello world,3,2,1,0,1,0,2\n" + "null,I am fine.,5,2,1,0,1,0,2\n" + + "null,Hello world, how are you?,4,2,1,0,1,0,2\n" + "null,Comment#9,15,2,1,0,1,0,2\n" + + "null,Comment#13,19,2,1,0,1,0,2\n" + "null,Luke Skywalker,6,2,1,0,1,0,2\n" + + "null,Hello,2,2,1,0,1,0,2" val results = t.toDataSet[Row].collect() TestBaseUtils.compareResultAsText(results.asJava, expected) From 100c92ada4c1577fb8c6a257346e7cb02cd50a4f Mon Sep 17 00:00:00 2001 From: Aleksandr Chermenin Date: Wed, 25 Jan 2017 11:43:14 +0300 Subject: [PATCH 14/18] [FLINK-2980] Restored TestBaseUtils. --- .../api/scala/batch/table/GroupingSetsTest.scala | 15 +++++++++++++-- .../org/apache/flink/test/util/TestBaseUtils.java | 12 +----------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/GroupingSetsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/GroupingSetsTest.scala index ac549335cd77a..88c6fffd28096 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/GroupingSetsTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/GroupingSetsTest.scala @@ -120,7 +120,7 @@ class GroupingSetsTest { val results1 = t1.toDataSet[Row].map(_.toString).collect() val results2 = t2.toDataSet[Row].map(_.toString).collect() - TestBaseUtils.compareResultCollections(results1.asJava, results2.asJava) + compareResultCollections(results1, results2) } @Test @@ -145,6 +145,17 @@ class GroupingSetsTest { val results1 = t1.toDataSet[Row].map(_.toString).collect() val results2 = t2.toDataSet[Row].map(_.toString).collect() - TestBaseUtils.compareResultCollections(results1.asJava, results2.asJava) + compareResultCollections(results1, results2) + } + + def compareResultCollections(expected: Seq[String], actual: Seq[String]): Unit = { + Assert.assertEquals(expected.size, actual.size) + val expectedSorted = expected.sorted + val actualSorted = actual.sorted + expectedSorted.zip(actualSorted) + .foreach { + case (expectedString, actualString) => + Assert.assertEquals(expectedString, actualString) + } } } diff --git a/flink-test-utils-parent/flink-test-utils/src/main/java/org/apache/flink/test/util/TestBaseUtils.java b/flink-test-utils-parent/flink-test-utils/src/main/java/org/apache/flink/test/util/TestBaseUtils.java index 53efc01d49881..3726197690195 100644 --- a/flink-test-utils-parent/flink-test-utils/src/main/java/org/apache/flink/test/util/TestBaseUtils.java +++ b/flink-test-utils-parent/flink-test-utils/src/main/java/org/apache/flink/test/util/TestBaseUtils.java @@ -408,20 +408,10 @@ public static void compareResultCollections(List expected, List actual } } - public static > void compareResultCollections(List expected, List actual) { - Assert.assertEquals(expected.size(), actual.size()); - - Collections.sort(expected); - Collections.sort(actual); - - for (int i = 0; i < expected.size(); i++) { - Assert.assertEquals(expected.get(i), actual.get(i)); - } - } - private static File[] getAllInvolvedFiles(String resultPath, final String[] excludePrefixes) { final File result = asFile(resultPath); assertTrue("Result file was not written", result.exists()); + if (result.isDirectory()) { return result.listFiles(new FilenameFilter() { From 7d77b2ad35892260927c9f4cedb11cd763aca19b Mon Sep 17 00:00:00 2001 From: Aleksandr Chermenin Date: Wed, 25 Jan 2017 12:21:18 +0300 Subject: [PATCH 15/18] [FLINK-2980] Merged grouping aggregations. --- .../org/apache/flink/table/api/table.scala | 13 +- .../flink/table/plan/logical/operators.scala | 202 ++---------------- 2 files changed, 21 insertions(+), 194 deletions(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala index 08207fbe554ad..7db50cd984d98 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala @@ -24,10 +24,9 @@ import org.apache.calcite.tools.RelBuilder import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.operators.join.JoinType import org.apache.flink.table.calcite.FlinkTypeFactory -import org.apache.flink.table.plan.logical.{Minus => MinusNode} import org.apache.flink.table.expressions._ import org.apache.flink.table.plan.ProjectionTranslator._ -import org.apache.flink.table.plan.logical._ +import org.apache.flink.table.plan.logical.{Minus => MinusNode, _} import org.apache.flink.table.sinks.TableSink import _root_.scala.collection.JavaConverters._ @@ -939,7 +938,7 @@ class GroupedTable( new Table(table.tableEnv, Project(projectsOnAgg, - Aggregate(groupKey, aggNames.map(a => Alias(a._1, a._2)).toSeq, + Aggregate(Seq(groupKey), aggNames.map(a => Alias(a._1, a._2)).toSeq, Project(projectFields, table.logicalPlan).validate(table.tableEnv) ).validate(table.tableEnv) ).validate(table.tableEnv)) @@ -1009,7 +1008,7 @@ class GroupWindowedTable( Project( projectsOnAgg, WindowAggregate( - groupKey, + Seq(groupKey), window.toLogicalWindow, propNames.map(a => Alias(a._1, a._2)).toSeq, aggNames.map(a => Alias(a._1, a._2)).toSeq, @@ -1072,8 +1071,8 @@ class GroupingSetsTable( val logical = Project(projectsOnAgg, - GroupingAggregation(groupingSets, aggNames.map(a => Alias(a._1, a._2)).toSeq, - Project(projectFields, table.logicalPlan).validate(table.tableEnv) + Aggregate(groupingSets, aggNames.map(a => Alias(a._1, a._2)).toSeq, + Project(projectFields, table.logicalPlan).validate(table.tableEnv) ).validate(table.tableEnv) ).validate(table.tableEnv) @@ -1153,7 +1152,7 @@ class GroupingSetsWindowedTable( val logical = Project(projectsOnAgg, - GroupingWindowAggregate( + WindowAggregate( groupingSets, window.toLogicalWindow, propNames.map(a => Alias(a._1, a._2)).toSeq, diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala index 5999e4f69a191..763c93490faed 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala @@ -23,13 +23,12 @@ import java.util import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.CorrelationId -import org.apache.calcite.rel.logical.{LogicalProject, LogicalTableFunctionScan} +import org.apache.calcite.rel.logical.LogicalTableFunctionScan import org.apache.calcite.rex.{RexInputRef, RexNode} import org.apache.calcite.tools.RelBuilder import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.operators.join.JoinType -import org.apache.flink.table._ import org.apache.flink.table.api.{StreamTableEnvironment, TableEnvironment, UnresolvedException} import org.apache.flink.table.calcite.{FlinkRelBuilder, FlinkTypeFactory} import org.apache.flink.table.expressions._ @@ -98,13 +97,8 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalNode) extend case Alias(gf: GroupFunction, name, extraNames) => children.head match { - case GroupingAggregation(grpExps, _, node) => - Alias(gf.replaceExpression(relBuilder, - Some(grpExps.flatten.distinct), node.output, indicator = true), - name, extraNames).toRexNode(relBuilder) - case Aggregate(grpExps, _, node) => - Alias(gf.replaceExpression(relBuilder, Some(grpExps), node.output), + Alias(gf.replaceExpression(relBuilder, Some(grpExps.flatten.distinct), node.output), name, extraNames).toRexNode(relBuilder) case _ => @@ -218,74 +212,9 @@ case class Filter(condition: Expression, child: LogicalNode) extends UnaryNode { } case class Aggregate( - groupingExpressions: Seq[Expression], - aggregateExpressions: Seq[NamedExpression], - child: LogicalNode) extends UnaryNode { - - override def output: Seq[Attribute] = { - (groupingExpressions ++ aggregateExpressions) map { - case ne: NamedExpression => ne.toAttribute - case e => Alias(e, e.toString).toAttribute - } - } - - override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = { - child.construct(relBuilder) - relBuilder.aggregate( - relBuilder.groupKey(groupingExpressions.map(_.toRexNode(relBuilder)).asJava), - aggregateExpressions.map { - case Alias(agg: Aggregation, name, _) => agg.toAggCall(name)(relBuilder) - case _ => throw new RuntimeException("This should never happen.") - }.asJava) - } - - override def validate(tableEnv: TableEnvironment): LogicalNode = { - if (tableEnv.isInstanceOf[StreamTableEnvironment]) { - failValidation(s"Aggregate on stream tables is currently not supported.") - } - - val resolvedAggregate = super.validate(tableEnv).asInstanceOf[Aggregate] - val groupingExprs = resolvedAggregate.groupingExpressions - val aggregateExprs = resolvedAggregate.aggregateExpressions - aggregateExprs.foreach(validateAggregateExpression) - groupingExprs.foreach(validateGroupingExpression) - - def validateAggregateExpression(expr: Expression): Unit = expr match { - // check no nested aggregation exists. - case aggExpr: Aggregation => - aggExpr.children.foreach { child => - child.preOrderVisit { - case agg: Aggregation => - failValidation( - "It's not allowed to use an aggregate function as " + - "input of another aggregate function") - case _ => // OK - } - } - case a: Attribute if !groupingExprs.exists(_.checkEquals(a)) => - failValidation( - s"expression '$a' is invalid because it is neither" + - " present in group by nor an aggregate function") - case e if groupingExprs.exists(_.checkEquals(e)) => // OK - case e => e.children.foreach(validateAggregateExpression) - } - - def validateGroupingExpression(expr: Expression): Unit = { - if (!expr.resultType.isKeyType) { - failValidation( - s"expression $expr cannot be used as a grouping expression " + - "because it's not a valid key type which must be hashable and comparable") - } - } - resolvedAggregate - } -} - -case class GroupingAggregation( groupingExpressions: Seq[Seq[Expression]], aggregateExpressions: Seq[NamedExpression], - child: LogicalNode - ) extends UnaryNode { + child: LogicalNode) extends UnaryNode { override def output: Seq[Attribute] = { (groupingExpressions.flatten.distinct ++ aggregateExpressions) map { @@ -299,7 +228,8 @@ case class GroupingAggregation( relBuilder.aggregate( relBuilder.groupKey( groupingExpressions.flatten.distinct.map(_.toRexNode(relBuilder)).asJava, - true, groupingExpressions.map(_.map(_.toRexNode(relBuilder)).asJava).asJava + groupingExpressions.size > 1, + groupingExpressions.map(_.map(_.toRexNode(relBuilder)).asJava).asJava ), aggregateExpressions.map { case Alias(agg: Aggregation, name, _) => agg.toAggCall(name)(relBuilder) @@ -312,7 +242,7 @@ case class GroupingAggregation( failValidation(s"Aggregate on stream tables is currently not supported.") } - val resolvedAggregate = super.validate(tableEnv).asInstanceOf[GroupingAggregation] + val resolvedAggregate = super.validate(tableEnv).asInstanceOf[Aggregate] val groupingExprs = resolvedAggregate.groupingExpressions val aggregateExprs = resolvedAggregate.aggregateExpressions val resolvedGroupingExprs = groupingExprs.map(_.map { @@ -350,7 +280,7 @@ case class GroupingAggregation( "because it's not a valid key type which must be hashable and comparable") } } - GroupingAggregation(resolvedGroupingExprs, resolvedAggregate.aggregateExpressions, child) + Aggregate(resolvedGroupingExprs, resolvedAggregate.aggregateExpressions, child) } } @@ -605,7 +535,7 @@ case class LogicalRelNode( } case class WindowAggregate( - groupingExpressions: Seq[Expression], + groupingExpressions: Seq[Seq[Expression]], window: LogicalWindow, propertyExpressions: Seq[NamedExpression], aggregateExpressions: Seq[NamedExpression], @@ -613,7 +543,7 @@ case class WindowAggregate( extends UnaryNode { override def output: Seq[Attribute] = { - (groupingExpressions ++ aggregateExpressions ++ propertyExpressions) map { + (groupingExpressions.flatten.distinct ++ aggregateExpressions ++ propertyExpressions) map { case ne: NamedExpression => ne.toAttribute case e => Alias(e, e.toString).toAttribute } @@ -647,114 +577,12 @@ case class WindowAggregate( override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = { val flinkRelBuilder = relBuilder.asInstanceOf[FlinkRelBuilder] child.construct(flinkRelBuilder) - flinkRelBuilder.aggregate( - window, - relBuilder.groupKey(groupingExpressions.map(_.toRexNode(relBuilder)).asJava), - propertyExpressions.map { - case Alias(prop: WindowProperty, name, _) => prop.toNamedWindowProperty(name)(relBuilder) - case _ => throw new RuntimeException("This should never happen.") - }, - aggregateExpressions.map { - case Alias(agg: Aggregation, name, _) => agg.toAggCall(name)(relBuilder) - case _ => throw new RuntimeException("This should never happen.") - }.asJava) - } - - override def validate(tableEnv: TableEnvironment): LogicalNode = { - val resolvedWindowAggregate = super.validate(tableEnv).asInstanceOf[WindowAggregate] - val groupingExprs = resolvedWindowAggregate.groupingExpressions - val aggregateExprs = resolvedWindowAggregate.aggregateExpressions - aggregateExprs.foreach(validateAggregateExpression) - groupingExprs.foreach(validateGroupingExpression) - - def validateAggregateExpression(expr: Expression): Unit = expr match { - // check no nested aggregation exists. - case aggExpr: Aggregation => - aggExpr.children.foreach { child => - child.preOrderVisit { - case agg: Aggregation => - failValidation( - "It's not allowed to use an aggregate function as " + - "input of another aggregate function") - case _ => // ok - } - } - case a: Attribute if !groupingExprs.exists(_.checkEquals(a)) => - failValidation( - s"Expression '$a' is invalid because it is neither" + - " present in group by nor an aggregate function") - case e if groupingExprs.exists(_.checkEquals(e)) => // ok - case e => e.children.foreach(validateAggregateExpression) - } - - def validateGroupingExpression(expr: Expression): Unit = { - if (!expr.resultType.isKeyType) { - failValidation( - s"Expression $expr cannot be used as a grouping expression " + - "because it's not a valid key type which must be hashable and comparable") - } - } - - // validate window - resolvedWindowAggregate.window.validate(tableEnv) match { - case ValidationFailure(msg) => - failValidation(s"$window is invalid: $msg") - case ValidationSuccess => // ok - } - - resolvedWindowAggregate - } -} - -case class GroupingWindowAggregate( - groupingExpressions: Seq[Seq[Expression]], - window: LogicalWindow, - propertyExpressions: Seq[NamedExpression], - aggregateExpressions: Seq[NamedExpression], - child: LogicalNode - ) extends UnaryNode { - - override def output: Seq[Attribute] = { - (groupingExpressions.flatten.distinct ++ aggregateExpressions ++ propertyExpressions) map { - case ne: NamedExpression => ne.toAttribute - case e => Alias(e, e.toString).toAttribute - } - } - - // resolve references of this operator's parameters - override def resolveReference( - tableEnv: TableEnvironment, - name: String) - : Option[NamedExpression] = tableEnv match { - // resolve reference to rowtime attribute in a streaming environment - case _: StreamTableEnvironment if name == "rowtime" => - Some(RowtimeAttribute()) - case _ => - window.alias match { - // resolve reference to this window's alias - case Some(UnresolvedFieldReference(alias)) if name == alias => - // check if reference can already be resolved by input fields - val found = super.resolveReference(tableEnv, name) - if (found.isDefined) { - failValidation(s"Reference $name is ambiguous.") - } else { - Some(WindowReference(name)) - } - case _ => - // resolve references as usual - super.resolveReference(tableEnv, name) - } - } - - override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = { - val flinkRelBuilder = relBuilder.asInstanceOf[FlinkRelBuilder] - child.construct(flinkRelBuilder) - val groupingSets = groupingExpressions.map(_.map(_.toRexNode(relBuilder)).toList).toList flinkRelBuilder.aggregate( window, relBuilder.groupKey( - groupingExpressions.head.map(_.toRexNode(relBuilder)).asJava, - true, groupingSets.map(_.asJava).asJava + groupingExpressions.flatten.distinct.map(_.toRexNode(relBuilder)).asJava, + groupingExpressions.size > 1, + groupingExpressions.map(_.map(_.toRexNode(relBuilder)).asJava).asJava ), propertyExpressions.map { case Alias(prop: WindowProperty, name, _) => prop.toNamedWindowProperty(name)(relBuilder) @@ -767,7 +595,7 @@ case class GroupingWindowAggregate( } override def validate(tableEnv: TableEnvironment): LogicalNode = { - val resolvedWindowAggregate = super.validate(tableEnv).asInstanceOf[GroupingWindowAggregate] + val resolvedWindowAggregate = super.validate(tableEnv).asInstanceOf[WindowAggregate] val groupingExprs = resolvedWindowAggregate.groupingExpressions val aggregateExprs = resolvedWindowAggregate.aggregateExpressions val resolvedGroupingExprs = groupingExprs.map(_.map { @@ -813,8 +641,8 @@ case class GroupingWindowAggregate( case ValidationSuccess => // ok } - GroupingWindowAggregate(resolvedGroupingExprs, window, propertyExpressions, - resolvedWindowAggregate.aggregateExpressions, child) + WindowAggregate(resolvedGroupingExprs, window, propertyExpressions, + resolvedWindowAggregate.aggregateExpressions, child) } } From 183dd2cb0ecc9cd6b5f416965a324d54ec553c9b Mon Sep 17 00:00:00 2001 From: Aleksandr Chermenin Date: Thu, 26 Jan 2017 09:12:42 +0300 Subject: [PATCH 16/18] [FLINK-2980] Fixed GroupFunction. --- .../flink/table/expressions/groupings.scala | 79 +++++++++++-------- .../flink/table/plan/logical/operators.scala | 14 +--- 2 files changed, 48 insertions(+), 45 deletions(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/groupings.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/groupings.scala index d61dca39690e4..5224646e63220 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/groupings.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/groupings.scala @@ -18,34 +18,60 @@ package org.apache.flink.table.expressions +import org.apache.calcite.rel.core.Aggregate +import org.apache.calcite.rex.RexNode import org.apache.calcite.tools.RelBuilder +import org.apache.calcite.util.ImmutableBitSet import org.apache.flink.api.common.typeinfo.BasicTypeInfo +import scala.collection.JavaConversions._ + abstract sealed class GroupFunction extends Expression { override def toString = s"GroupFunction($children)" + override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + val child = relBuilder.peek() + child match { + case a: Aggregate => + val groupSet = a.getGroupSet + + val inputFields = a.getInput.getRowType.getFieldList.toList.map(_.getName) + val outputFields = a.getRowType.getFieldList.toList.map(_.getName) + + val internalFields = + getInternalFieldNames(inputFields) + .filter { + case (_, v) => outputFields.contains(v) + } + + replaceExpression(relBuilder, Some(groupSet), internalFields, a.indicator) + .toRexNode(relBuilder) + case _ => + replaceExpression(relBuilder, None).toRexNode(relBuilder) + } + } + private[flink] def replaceExpression( relBuilder: RelBuilder, - groupExpressions: Option[Seq[Expression]], - children: Seq[Attribute] = Seq(), + groupSet: Option[ImmutableBitSet], + internalFields: Map[String, String] = Map(), indicator: Boolean = false): Expression = { - if (groupExpressions.isDefined) { - val expressions = groupExpressions.get + if (groupSet.isDefined) { + val groups = groupSet.get if (!indicator) { Cast( - Minus(Power(Literal(2), Literal(getEffectiveArgCount(expressions))), Literal(1)), + Minus(Power(Literal(2), Literal(getEffectiveArgCount(groups))), Literal(1)), BasicTypeInfo.LONG_TYPE_INFO ) } else { - val operands = getOperands(expressions) - val internalFieldsMap = getInternalFields(children) + val operands = getOperands(internalFields) var shift = operands.size var expression: Option[Expression] = None operands.foreach(x => { shift -= 1 - expression = bitValue(relBuilder, expression, x, shift, expressions, internalFieldsMap) + expression = bitValue(relBuilder, expression, x, shift, internalFields) }) Cast(expression.get, BasicTypeInfo.LONG_TYPE_INFO) } @@ -54,8 +80,7 @@ abstract sealed class GroupFunction extends Expression { } } - private def getInternalFields(children: Seq[Attribute]) = { - val inputFields = children.map(_.name) + private def getInternalFieldNames(inputFields: List[String]) = { inputFields.map(inputFieldName => { val base = "i$" + inputFieldName var name = base @@ -70,18 +95,13 @@ abstract sealed class GroupFunction extends Expression { private def bitValue(relBuilder: RelBuilder, expression: Option[Expression], operand: Int, - shift: Int, expressions: Seq[Expression], - internalFieldsMap: Map[String, String] + shift: Int, internalFieldsMap: Map[String, String] ): Option[Expression] = { - val fieldName = expressions(operand) match { - case ne: NamedExpression => ne.name - case _ => "" - } + val fieldName = internalFieldsMap.values.toList.get(operand) var nextExpression: Expression = - If(IsTrue(ResolvedFieldReference( - internalFieldsMap(fieldName), BasicTypeInfo.BOOLEAN_TYPE_INFO)), + If(IsTrue(ResolvedFieldReference(fieldName, BasicTypeInfo.BOOLEAN_TYPE_INFO)), Literal(1), Literal(0)) if (shift > 0) { @@ -95,10 +115,13 @@ abstract sealed class GroupFunction extends Expression { Some(nextExpression) } - protected def getEffectiveArgCount(groupExpressions: Seq[Expression]): Int + protected def getEffectiveArgCount(groupSet: ImmutableBitSet): Int = { + groupSet.toList.size() + } - protected def getOperands(groupExpressions: Seq[Expression]): Seq[Int] = { - children.map(e => groupExpressions.indexOf(e)) + protected def getOperands(fields: Map[String, String]): Seq[Int] = { + val keys = fields.keys.toList + children.map(e => keys.indexOf(e.asInstanceOf[NamedExpression].name)) } } @@ -108,12 +131,9 @@ case class GroupId() extends GroupFunction { override private[flink] def children = Nil - override protected def getEffectiveArgCount(groupExpressions: Seq[Expression]): Int = { - groupExpressions.size + override protected def getOperands(fields: Map[String, String]): Seq[Int] = { + fields.values.toList.indices } - - override protected def getOperands(groupExpressions: Seq[Expression]): Seq[Int] = - groupExpressions.indices } case class Grouping(expression: Expression) extends GroupFunction { @@ -122,7 +142,7 @@ case class Grouping(expression: Expression) extends GroupFunction { override private[flink] def children = Seq(expression) - override protected def getEffectiveArgCount(groupExpressions: Seq[Expression]): Int = 1 + override protected def getEffectiveArgCount(groupSet: ImmutableBitSet): Int = 1 } case class GroupingId(expressions: Expression*) extends GroupFunction { @@ -130,9 +150,4 @@ case class GroupingId(expressions: Expression*) extends GroupFunction { override private[flink] def resultType = BasicTypeInfo.LONG_TYPE_INFO override private[flink] def children = expressions - - override protected def getEffectiveArgCount(groupExpressions: Seq[Expression]): Int = { - expressions.size - } } - diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala index 763c93490faed..cd31bb88ae70a 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala @@ -93,19 +93,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalNode) extend override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = { child.construct(relBuilder) relBuilder.project( - projectList.map { - case Alias(gf: GroupFunction, name, extraNames) => - children.head match { - - case Aggregate(grpExps, _, node) => - Alias(gf.replaceExpression(relBuilder, Some(grpExps.flatten.distinct), node.output), - name, extraNames).toRexNode(relBuilder) - - case _ => - Alias(gf.replaceExpression(relBuilder, None), name, extraNames).toRexNode(relBuilder) - } - case x => x.toRexNode(relBuilder) - }.asJava, + projectList.map(_.toRexNode(relBuilder)).asJava, projectList.map(_.name).asJava, true) } From 53e8f95dfa54b34f727d6d3b167c2621a15d7c54 Mon Sep 17 00:00:00 2001 From: Aleksandr Chermenin Date: Mon, 30 Jan 2017 15:01:39 +0300 Subject: [PATCH 17/18] [FLINK-2980] Some fixes. --- .../flink/table/api/scala/expressionDsl.scala | 18 ++-- .../flink/table/api/scala/package.scala | 2 +- .../org/apache/flink/table/api/table.scala | 92 +------------------ .../flink/table/expressions/Expression.scala | 20 ++-- .../table/expressions/ExpressionParser.scala | 31 ++----- .../flink/table/expressions/groupings.scala | 44 ++++----- .../table/validate/FunctionCatalog.scala | 5 + .../scala/batch/table/GroupingSetsTest.scala | 6 +- 8 files changed, 53 insertions(+), 165 deletions(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala index f5fd8bc3d1a79..a189e4e6fef32 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala @@ -585,6 +585,12 @@ trait ImplicitExpressionConversions { implicit def sqlTimestamp2Literal(sqlTimestamp: Timestamp): Expression = Literal(sqlTimestamp) implicit def array2ArrayConstructor(array: Array[_]): Expression = convertArray(array) + + implicit def unitToGroupedExpression(unit: Unit): GroupedExpression = + new GroupedExpression(Seq()) + + implicit def productToGroupedExpression(product: Product): GroupedExpression = + new GroupedExpression(product) } /** @@ -593,7 +599,8 @@ trait ImplicitExpressionConversions { trait ImplicitGroupedOperations { private[flink] def expr: GroupedExpression - implicit class UnitAsGroupedExpression(unit: Unit) extends ImplicitGroupedOperations { + implicit class UnitAsGroupedExpression(unit: Unit) + extends ImplicitGroupedOperations { override private[flink] def expr = new GroupedExpression(Seq()) } @@ -603,15 +610,6 @@ trait ImplicitGroupedOperations { } } -trait ImplicitGroupedConversions { - - implicit def unitToGroupedExpression(unit: Unit): GroupedExpression = - new GroupedExpression(Seq()) - - implicit def productToGroupedExpression(product: Product): GroupedExpression = - new GroupedExpression(product) -} - // ------------------------------------------------------------------------------------------------ // Expressions with no parameters // ------------------------------------------------------------------------------------------------ diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/package.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/package.scala index 72dfaeb29cd49..cd341cbe5ba48 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/package.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/package.scala @@ -66,7 +66,7 @@ import _root_.scala.reflect.ClassTag * }}} * */ -package object scala extends ImplicitExpressionConversions with ImplicitGroupedConversions { +package object scala extends ImplicitExpressionConversions { implicit def table2TableConversions(table: Table): TableConversions = { new TableConversions(table) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala index 7db50cd984d98..912b739e1756c 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala @@ -24,9 +24,9 @@ import org.apache.calcite.tools.RelBuilder import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.operators.join.JoinType import org.apache.flink.table.calcite.FlinkTypeFactory -import org.apache.flink.table.expressions._ +import org.apache.flink.table.expressions.{Minus => _, _} import org.apache.flink.table.plan.ProjectionTranslator._ -import org.apache.flink.table.plan.logical.{Minus => MinusNode, _} +import org.apache.flink.table.plan.logical._ import org.apache.flink.table.sinks.TableSink import _root_.scala.collection.JavaConverters._ @@ -549,7 +549,7 @@ class Table( throw new ValidationException("Only tables from the same TableEnvironment can be " + "subtracted.") } - new Table(tableEnv, MinusNode(logicalPlan, right.logicalPlan, all = false) + new Table(tableEnv, Minus(logicalPlan, right.logicalPlan, all = false) .validate(tableEnv)) } @@ -574,7 +574,7 @@ class Table( throw new ValidationException("Only tables from the same TableEnvironment can be " + "subtracted.") } - new Table(tableEnv, MinusNode(logicalPlan, right.logicalPlan, all = true) + new Table(tableEnv, Minus(logicalPlan, right.logicalPlan, all = true) .validate(tableEnv)) } @@ -1093,89 +1093,5 @@ class GroupingSetsTable( val fieldExprs = ExpressionParser.parseExpressionList(fields) select(fieldExprs: _*) } - - /** - * Groups the records of a table by assigning them to windows defined by a time or row interval. - * - * For streaming tables of infinite size, grouping into windows is required to define finite - * groups on which group-based aggregates can be computed. - * - * For batch tables of finite size, windowing essentially provides shortcuts for time-based - * groupBy. - * - * @param groupWindow group-window that specifies how elements are grouped. - * @return A windowed table. - */ - def window(groupWindow: GroupWindow): GroupingSetsWindowedTable = { - if (table.tableEnv.isInstanceOf[BatchTableEnvironment]) { - throw new ValidationException(s"Windows on batch tables are currently not supported.") - } - new GroupingSetsWindowedTable(table, groups, sqlKind, groupWindow) - } } -class GroupingSetsWindowedTable( - private[flink] val table: Table, - private[flink] val groups: Seq[Seq[Expression]], - private[flink] val sqlKind: SqlKind, - private[flink] val window: GroupWindow) { - - /** - * Performs a selection operation on a windowed table. Similar to an SQL SELECT statement. - * The field expressions can contain complex expressions and aggregations. - * - * Example: - * - * {{{ - * groupWindowTable.select('key, 'window.start, 'value.avg + " The average" as 'average) - * }}} - */ - def select(fields: Expression*): Table = { - - val (aggNames, propNames) = extractAggregationsAndProperties(fields, table.tableEnv) - val projectsOnAgg = replaceAggregationsAndProperties( - fields, table.tableEnv, aggNames, propNames) - - val groupingSets = sqlKind match { - case SqlKind.CUBE => ExpressionUtils.cube(groups) - case SqlKind.ROLLUP => ExpressionUtils.rollup(groups) - case _ => groups - } - - val projectFields = (table.tableEnv, window) match { - // event time can be arbitrary field in batch environment - case (_: BatchTableEnvironment, w: EventTimeWindow) => - extractFieldReferences(fields ++ groupingSets.flatten.distinct ++ Seq(w.timeField)) - case (_, _) => - extractFieldReferences(fields ++ groupingSets.flatten.distinct) - } - - val logical = - Project(projectsOnAgg, - WindowAggregate( - groupingSets, - window.toLogicalWindow, - propNames.map(a => Alias(a._1, a._2)).toSeq, - aggNames.map(a => Alias(a._1, a._2)).toSeq, - Project(projectFields, table.logicalPlan).validate(table.tableEnv) - ).validate(table.tableEnv) - ).validate(table.tableEnv) - - new Table(table.tableEnv, logical) - } - - /** - * Performs a selection operation on a group-windows table. Similar to an SQL SELECT statement. - * The field expressions can contain complex expressions and aggregations. - * - * Example: - * - * {{{ - * groupWindowTable.select("key, window.start, value.avg + ' The average' as average") - * }}} - */ - def select(fields: String): Table = { - val fieldExprs = ExpressionParser.parseExpressionList(fields) - select(fieldExprs: _*) - } -} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/Expression.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/Expression.scala index 0fc484f477a57..edac4fccea243 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/Expression.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/Expression.scala @@ -20,6 +20,7 @@ package org.apache.flink.table.expressions import org.apache.calcite.rex.RexNode import org.apache.calcite.tools.RelBuilder import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.table.api.TableException import org.apache.flink.table.plan.TreeNode import org.apache.flink.table.validate.{ValidationResult, ValidationSuccess} @@ -86,7 +87,7 @@ abstract class LeafExpression extends Expression { private[flink] val children = Nil } -class GroupedExpression( +case class GroupedExpression( private[flink] val children: Seq[Expression] ) extends Expression { @@ -97,7 +98,8 @@ class GroupedExpression( case e: Expression => e case s: Symbol => UnresolvedFieldReference(s.name) case p: Product => new GroupedExpression(p) - case _ => throw new IllegalArgumentException() + case x => throw new TableException( + "Unexpected field '" + x + "' in group of expressions " + product.toString) }.toSeq ) } @@ -113,16 +115,6 @@ class GroupedExpression( * Returns the [[TypeInformation]] for evaluating this expression. * It's not applicable for grouped expressions. */ - override private[flink] def resultType = ??? - - /** - * Grouping function. Similar to a SQL GROUPING_ID function. - */ - def groupingId(): Expression = GroupingId(children: _*) - - override def productElement(n: Int): Expression = children(n) - - override def productArity: Int = children.length - - override def canEqual(that: Any) = false + override private[flink] def resultType = throw new UnsupportedOperationException( + "Result type can not be resolved from grouped expressions.") } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala index 4d55e79227a4e..bf231f58f2a8b 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala @@ -180,10 +180,10 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { ( "(" ~> expression <~ ")" ) | literalExpr | fieldReference lazy val grouped: PackratParser[Expression] = - "(" ~> expressionList <~ ")" ^^ { l => new GroupedExpression(l.toSeq) } + "(" ~> expressionList <~ ")" ^^ { l => GroupedExpression(l) } - lazy val unit: PackratParser[Expression] = - "()" ^^ { _ => new GroupedExpression(Seq()) } + lazy val empty: PackratParser[Expression] = + "()" ^^ { _ => GroupedExpression(Seq()) } // suffix operators @@ -298,21 +298,12 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { lazy val suffixFlattening: PackratParser[Expression] = composite <~ "." ~ FLATTEN ~ opt("()") ^^ { e => Flattening(e) } - lazy val suffixGrouping: PackratParser[Expression] = - composite <~ "." ~ GROUPING ~ opt("()") ^^ { e => Grouping(e) } - - lazy val suffixGroupingId: PackratParser[Expression] = - composite <~ "." ~ GROUPING_ID ~ opt("()") ^^ { - case g: GroupedExpression => GroupingId(g.flatChildren: _*) - case e => GroupingId(e) - } - lazy val suffixed: PackratParser[Expression] = suffixTimeInterval | suffixRowInterval | suffixSum | suffixMin | suffixMax | suffixStart | suffixEnd | suffixCount | suffixAvg | suffixCast | suffixAs | suffixTrim | suffixTrimWithoutArgs | suffixIf | suffixAsc | suffixDesc | suffixToDate | suffixToTimestamp | suffixToTime | suffixExtract | suffixFloor | suffixCeil | - suffixGet | suffixFlattening | suffixGroupingId | suffixGrouping | + suffixGet | suffixFlattening | suffixFunctionCall | suffixFunctionCallOneArg // function call must always be at the end // prefix operators @@ -393,26 +384,16 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { lazy val prefixFlattening: PackratParser[Expression] = FLATTEN ~ "(" ~> composite <~ ")" ^^ { e => Flattening(e) } - lazy val prefixGrouping: PackratParser[Expression] = - GROUPING ~ "(" ~> composite <~ ")" ^^ { e => Grouping(e) } - - lazy val prefixGroupingId: PackratParser[Expression] = - GROUPING_ID ~ "(" ~> repsep(expression, ",") <~ ")" ^^ { l => GroupingId(l: _*) } - - lazy val prefixGroupId: PackratParser[Expression] = - GROUP_ID ~ opt("()") ^^ { _ => GroupId() } - lazy val prefixed: PackratParser[Expression] = prefixArray | prefixSum | prefixMin | prefixMax | prefixCount | prefixAvg | prefixStart | prefixEnd | prefixCast | prefixAs | prefixTrim | prefixTrimWithoutArgs | prefixIf | prefixExtract | prefixFloor | prefixCeil | prefixGet | prefixFlattening | - prefixGroupingId | prefixGrouping | prefixGroupId | prefixFunctionCall | prefixFunctionCallOneArg // function call must always be at the end // suffix/prefix composite - lazy val composite: PackratParser[Expression] = suffixed | prefixed | atom | grouped | + lazy val composite: PackratParser[Expression] = suffixed | prefixed | atom | failure("Composite expression expected.") // unary ops @@ -484,7 +465,7 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { case e ~ _ ~ _ ~ names ~ _ => Alias(e, names.head.name, names.tail.map(_.name)) } | logic - lazy val expression: PackratParser[Expression] = alias | grouped | unit | + lazy val expression: PackratParser[Expression] = alias | grouped | empty | failure("Invalid expression.") lazy val expressionList: Parser[List[Expression]] = rep1sep(expression, ",") diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/groupings.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/groupings.scala index 5224646e63220..c399dac70b591 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/groupings.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/groupings.scala @@ -23,6 +23,7 @@ import org.apache.calcite.rex.RexNode import org.apache.calcite.tools.RelBuilder import org.apache.calcite.util.ImmutableBitSet import org.apache.flink.api.common.typeinfo.BasicTypeInfo +import org.apache.flink.table.api.TableException import scala.collection.JavaConversions._ @@ -41,42 +42,37 @@ abstract sealed class GroupFunction extends Expression { val internalFields = getInternalFieldNames(inputFields) - .filter { - case (_, v) => outputFields.contains(v) - } + .filter(t => outputFields.contains(t._1)) - replaceExpression(relBuilder, Some(groupSet), internalFields, a.indicator) + replaceExpression(relBuilder, groupSet, internalFields, a.indicator) .toRexNode(relBuilder) + case _ => - replaceExpression(relBuilder, None).toRexNode(relBuilder) + throw new TableException("GROUPING functions only supported with " + + "GROUP BY GROUPING SETS, CUBE or ROLLUP") } } private[flink] def replaceExpression( relBuilder: RelBuilder, - groupSet: Option[ImmutableBitSet], + groupSet: ImmutableBitSet, internalFields: Map[String, String] = Map(), indicator: Boolean = false): Expression = { - if (groupSet.isDefined) { - val groups = groupSet.get - if (!indicator) { - Cast( - Minus(Power(Literal(2), Literal(getEffectiveArgCount(groups))), Literal(1)), - BasicTypeInfo.LONG_TYPE_INFO - ) - } else { - val operands = getOperands(internalFields) - var shift = operands.size - var expression: Option[Expression] = None - operands.foreach(x => { - shift -= 1 - expression = bitValue(relBuilder, expression, x, shift, internalFields) - }) - Cast(expression.get, BasicTypeInfo.LONG_TYPE_INFO) - } + if (!indicator) { + Cast( + Minus(Power(Literal(2), Literal(getEffectiveArgCount(groupSet))), Literal(1)), + BasicTypeInfo.LONG_TYPE_INFO + ) } else { - this + val operands = getOperands(internalFields) + var shift = operands.size + var expression: Option[Expression] = None + operands.foreach(x => { + shift -= 1 + expression = bitValue(relBuilder, expression, x, shift, internalFields) + }) + Cast(expression.get, BasicTypeInfo.LONG_TYPE_INFO) } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala index c00f8bbd7aa09..8d95d2cfb31ed 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala @@ -154,6 +154,11 @@ object FunctionCatalog { "min" -> classOf[Min], "sum" -> classOf[Sum], + // grouping function + "group_id" -> classOf[GroupId], + "grouping" -> classOf[Grouping], + "grouping_id" -> classOf[GroupingId], + // string functions "charLength" -> classOf[CharLength], "initCap" -> classOf[InitCap], diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/GroupingSetsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/GroupingSetsTest.scala index 88c6fffd28096..f0eff50cba63d 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/GroupingSetsTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/GroupingSetsTest.scala @@ -56,7 +56,7 @@ class GroupingSetsTest { 'b, 'c, 'a.avg as 'a, groupId() as 'g, 'b.grouping() as 'gb, grouping('c) as 'gc, 'b.groupingId() as 'gib, groupingId('c) as 'gic, - ('b, 'c).groupingId() as 'gid + groupingId('b, 'c) as 'gid ) val expected = @@ -106,7 +106,7 @@ class GroupingSetsTest { 'b, 'c, 'a.avg as 'a, groupId() as 'g, 'b.grouping() as 'gb, grouping('c) as 'gc, 'b.groupingId() as 'gib, groupingId('c) as 'gic, - ('b, 'c).groupingId() as 'gid + groupingId('b, 'c) as 'gid ) val t2 = table @@ -131,7 +131,7 @@ class GroupingSetsTest { 'b, 'c, 'a.avg as 'a, groupId() as 'g, 'b.grouping() as 'gb, grouping('c) as 'gc, 'b.groupingId() as 'gib, groupingId('c) as 'gic, - ('b, 'c).groupingId() as 'gid + groupingId('b, 'c) as 'gid ) val t2 = table From 8e67b8190c3333d639bce4335256368d73464078 Mon Sep 17 00:00:00 2001 From: Aleksandr Chermenin Date: Mon, 6 Feb 2017 12:30:33 +0300 Subject: [PATCH 18/18] [FLINK-2980] Some fixes and docs improvements. --- docs/dev/table_api.md | 42 +------- .../flink/table/api/scala/expressionDsl.scala | 12 +-- .../org/apache/flink/table/api/table.scala | 99 ++++--------------- .../table/expressions/ExpressionUtils.scala | 22 +++-- .../api/scala/batch}/GroupingSetsTest.scala | 2 +- 5 files changed, 40 insertions(+), 137 deletions(-) rename flink-libraries/flink-table/src/test/scala/org/apache/flink/{api/scala/batch/table => table/api/scala/batch}/GroupingSetsTest.scala (99%) diff --git a/docs/dev/table_api.md b/docs/dev/table_api.md index b6e9692a9d0e5..19b2583e675b8 100644 --- a/docs/dev/table_api.md +++ b/docs/dev/table_api.md @@ -804,7 +804,7 @@ val result = in.groupingSets(('a, 'b), 'b, ()).select('a, 'b, 'c.sum as 'd); Cube -

Similar to a SQL GROUP BY CUBE clause. A CUBE expression will generate subtotals for all combinations of the dimensions specified.

+

Similar to a SQL GROUP BY CUBE clause. A CUBE expression will generate subtotals for all combinations of the dimensions specified. E.g. .cube('a, 'b) is equivalent to .groupingSets(('a, 'b), ('a), ('b), ()).

{% highlight scala %} val in = ds.toTable(tableEnv, 'a, 'b, 'c); val result = in.cube('a, 'b).select('a, 'b, 'c.sum as 'd); @@ -813,9 +813,9 @@ val result = in.cube('a, 'b).select('a, 'b, 'c.sum as 'd); - Cube + Rollup -

Similar to a SQL GROUP BY ROLLUP clause. A ROLLUP expression produces group subtotals from right to left and a grand total.

+

Similar to a SQL GROUP BY ROLLUP clause. A ROLLUP expression produces group subtotals from right to left and a grand total. E.g. .cube('a, 'b) is equivalent to .groupingSets(('a, 'b), ('a), ()).

{% highlight scala %} val in = ds.toTable(tableEnv, 'a, 'b, 'c); val result = in.rollup('a, 'b).select('a, 'b, 'c.sum as 'd); @@ -2748,42 +2748,6 @@ ARRAY.element() - - - {% highlight scala %} -groupId() -{% endhighlight %} - - -

Returns an integer that uniquely identifies the combination of grouping keys.

- - - - - - {% highlight scala %} -ANY.grouping() -grouping(ANY) -{% endhighlight %} - - -

Returns 1 if expression is rolled up in the current row’s grouping set, 0 otherwise.

- - - - - - {% highlight scala %} -ANY.groupingId() -(ANY [, ANY ]*).groupingId() -groupingId(ANY [, ANY ]*) -{% endhighlight %} - - -

Returns a bit vector of the given grouping expressions.

- - - diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala index a189e4e6fef32..27b13c9cb1b32 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala @@ -725,34 +725,34 @@ object array { } /** - * Grouping function. Similar to a SQL GROUP_ID function. + * Returns an integer that uniquely identifies the combination of grouping keys. */ object groupId { /** - * Return evaluated result of the function. + * Returns an integer that uniquely identifies the combination of grouping keys. */ def apply(): Expression = GroupId() } /** - * Grouping function. Similar to a SQL GROUPING function. + * Returns 1 if expression is rolled up in the current row’s grouping set, 0 otherwise. */ object grouping { /** - * Return evaluated result of the function. + * Returns 1 if expression is rolled up in the current row’s grouping set, 0 otherwise. */ def apply(expression: Expression): Expression = Grouping(expression) } /** - * Grouping function. Similar to a SQL GROUPING function. + * Returns a bit vector of the given grouping expressions. */ object groupingId { /** - * Return evaluated result of the function. + * Returns a bit vector of the given grouping expressions. */ def apply(expression: Expression*): Expression = GroupingId(expression: _*) } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala index 912b739e1756c..65e7a384b32f0 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala @@ -230,7 +230,7 @@ class Table( * }}} */ def groupBy(fields: Expression*): GroupedTable = { - new GroupedTable(this, fields) + new GroupedTable(this, Seq(fields)) } /** @@ -259,12 +259,12 @@ class Table( * tab.groupingSets(('a, 'b), ('a), ()).select('a, 'b, 'c.avg) * }}} */ - def groupingSets(fields: Expression*): GroupingSetsTable = { + def groupingSets(fields: Expression*): GroupedTable = { val groups = fields.map { case g: GroupedExpression => g.flatChildren case x => Seq(x) } - new GroupingSetsTable(this, groups, SqlKind.GROUPING_SETS) + new GroupedTable(this, groups) } /** @@ -278,7 +278,7 @@ class Table( * tab.groupingSets("(a, b), (a), ()").select("a, b, c.avg") * }}} */ - def groupingSets(fields: String): GroupingSetsTable = { + def groupingSets(fields: String): GroupedTable = { val fieldsExpr = ExpressionParser.parseExpressionList(fields) groupingSets(fieldsExpr: _*) } @@ -293,12 +293,12 @@ class Table( * tab.cube('a, 'b).select('a, 'b, 'c.avg) * }}} */ - def cube(fields: Expression*): GroupingSetsTable = { + def cube(fields: Expression*): GroupedTable = { val groups = fields.map { case g: GroupedExpression => g.flatChildren case x => Seq(x) } - new GroupingSetsTable(this, groups, SqlKind.CUBE) + new GroupedTable(this, ExpressionUtils.cube(groups)) } /** @@ -311,7 +311,7 @@ class Table( * tab.cube("a, b").select("a, b, c.avg") * }}} */ - def cube(fields: String): GroupingSetsTable = { + def cube(fields: String): GroupedTable = { val fieldsExpr = ExpressionParser.parseExpressionList(fields) cube(fieldsExpr: _*) } @@ -326,12 +326,12 @@ class Table( * tab.rollup('a, 'b).select('a, 'b, 'c.avg) * }}} */ - def rollup(fields: Expression*): GroupingSetsTable = { + def rollup(fields: Expression*): GroupedTable = { val groups = fields.map { case g: GroupedExpression => g.flatChildren case x => Seq(x) } - new GroupingSetsTable(this, groups, SqlKind.ROLLUP) + new GroupedTable(this, ExpressionUtils.rollup(groups)) } /** @@ -344,7 +344,7 @@ class Table( * tab.rollup("a, b").select("a, b, c.avg") * }}} */ - def rollup(fields: String): GroupingSetsTable = { + def rollup(fields: String): GroupedTable = { val fieldsExpr = ExpressionParser.parseExpressionList(fields) rollup(fieldsExpr: _*) } @@ -914,7 +914,7 @@ class Table( */ class GroupedTable( private[flink] val table: Table, - private[flink] val groupKey: Seq[Expression]) { + private[flink] val groups: Seq[Seq[Expression]]) { /** * Performs a selection operation on a grouped table. Similar to an SQL SELECT statement. @@ -934,11 +934,11 @@ class GroupedTable( val projectsOnAgg = replaceAggregationsAndProperties( fields, table.tableEnv, aggNames, propNames) - val projectFields = extractFieldReferences(fields ++ groupKey) + val projectFields = extractFieldReferences(fields ++ groups.flatten.distinct) new Table(table.tableEnv, Project(projectsOnAgg, - Aggregate(Seq(groupKey), aggNames.map(a => Alias(a._1, a._2)).toSeq, + Aggregate(groups, aggNames.map(a => Alias(a._1, a._2)).toSeq, Project(projectFields, table.logicalPlan).validate(table.tableEnv) ).validate(table.tableEnv) ).validate(table.tableEnv)) @@ -972,13 +972,13 @@ class GroupedTable( * @return A windowed table. */ def window(groupWindow: GroupWindow): GroupWindowedTable = { - new GroupWindowedTable(table, groupKey, groupWindow) + new GroupWindowedTable(table, groups, groupWindow) } } class GroupWindowedTable( private[flink] val table: Table, - private[flink] val groupKey: Seq[Expression], + private[flink] val groups: Seq[Seq[Expression]], private[flink] val window: GroupWindow) { /** @@ -999,16 +999,16 @@ class GroupWindowedTable( val projectFields = (table.tableEnv, window) match { // event time can be arbitrary field in batch environment case (_: BatchTableEnvironment, w: EventTimeWindow) => - extractFieldReferences(fields ++ groupKey ++ Seq(w.timeField)) + extractFieldReferences(fields ++ groups.flatten.distinct ++ Seq(w.timeField)) case (_, _) => - extractFieldReferences(fields ++ groupKey) + extractFieldReferences(fields ++ groups.flatten.distinct) } new Table(table.tableEnv, Project( projectsOnAgg, WindowAggregate( - Seq(groupKey), + groups, window.toLogicalWindow, propNames.map(a => Alias(a._1, a._2)).toSeq, aggNames.map(a => Alias(a._1, a._2)).toSeq, @@ -1032,66 +1032,3 @@ class GroupWindowedTable( select(fieldExprs: _*) } } - -/** - * A table that has been grouped on several sets of grouping keys. - */ -class GroupingSetsTable( - private[flink] val table: Table, - private[flink] val groups: Seq[Seq[Expression]], - private[flink] val sqlKind: SqlKind) { - - /** - * Performs a selection operation on a grouped table. Similar to an SQL SELECT statement. - * The field expressions can contain complex expressions and aggregations. - * - * Example: - * - * {{{ - * tab.groupingSets('key).select('key, 'value.avg + " The average" as 'average) - * }}} - */ - def select(fields: Expression*): Table = { - - val (aggNames, propNames) = extractAggregationsAndProperties(fields, table.tableEnv) - - if (propNames.nonEmpty) { - throw ValidationException("Window properties can only be used on windowed tables.") - } - - val groupingSets = sqlKind match { - case SqlKind.CUBE => ExpressionUtils.cube(groups) - case SqlKind.ROLLUP => ExpressionUtils.rollup(groups) - case _ => groups - } - - val projectsOnAgg = replaceAggregationsAndProperties( - fields, table.tableEnv, aggNames, propNames) - val projectFields = extractFieldReferences(fields ++ groupingSets.flatten.distinct) - - val logical = - Project(projectsOnAgg, - Aggregate(groupingSets, aggNames.map(a => Alias(a._1, a._2)).toSeq, - Project(projectFields, table.logicalPlan).validate(table.tableEnv) - ).validate(table.tableEnv) - ).validate(table.tableEnv) - - new Table(table.tableEnv, logical) - } - - /** - * Performs a selection operation on a grouped table. Similar to an SQL SELECT statement. - * The field expressions can contain complex expressions and aggregations. - * - * Example: - * - * {{{ - * tab.groupBy("key").select("key, value.avg + ' The average' as average") - * }}} - */ - def select(fields: String): Table = { - val fieldExprs = ExpressionParser.parseExpressionList(fields) - select(fieldExprs: _*) - } -} - diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionUtils.scala index 0fcaf97b0fa10..af71e48f25645 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionUtils.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionUtils.scala @@ -157,12 +157,13 @@ object ExpressionUtils { /** Computes the rollup of bit sets. * - *

For example, rollup({0}, {1}) - * returns ({0, 1}, {0}, {}). + * For example, {{{ rollup({0}, {1}) }}} + * returns {{{ ({0, 1}, {0}, {}) }}} * - *

Bit sets are not necessarily singletons: - * rollup({0, 2}, {3, 5}) - * returns ({0, 2, 3, 5}, {0, 2}, {}). */ + * Bit sets are not necessarily singletons: + * {{{ rollup({0, 2}, {3, 5}) }}} + * returns {{{ ({0, 2, 3, 5}, {0, 2}, {}) }}}. + */ private[flink] def rollup(groups: Seq[Seq[Expression]]): Seq[Seq[Expression]] = { val originalBitSet = for (i <- groups.indices) yield { ImmutableBitSet.builder().set(i).build() @@ -173,12 +174,13 @@ object ExpressionUtils { /** Computes the cube of bit sets. * - *

For example, rollup({0}, {1}) - * returns ({0, 1}, {0}, {}). + * For example, {{{ cube({0}, {1}) }}} + * returns {{{ ({0, 1}, {0}, {1}, {}) }}} * - *

Bit sets are not necessarily singletons: - * rollup({0, 2}, {3, 5}) - * returns ({0, 2, 3, 5}, {0, 2}, {}). */ + * Bit sets are not necessarily singletons: + * {{{ rollup({0, 2}, {3, 5}) }}} + * returns {{{ ({0, 2, 3, 5}, {0, 2}, {3, 5}, {}) }}}. + */ private[flink] def cube(groups: Seq[Seq[Expression]]): Seq[Seq[Expression]] = { val originalBitSet = for (i <- groups.indices) yield { ImmutableBitSet.builder().set(i).build() diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/GroupingSetsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/GroupingSetsTest.scala similarity index 99% rename from flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/GroupingSetsTest.scala rename to flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/GroupingSetsTest.scala index f0eff50cba63d..8766ee18df5d4 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/GroupingSetsTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/GroupingSetsTest.scala @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.api.scala.batch.table +package org.apache.flink.table.api.scala.batch import org.apache.flink.api.scala._ import org.apache.flink.api.scala.util.CollectionDataSets