From 7ea31c98decd5e5a161992da8187e111758553f0 Mon Sep 17 00:00:00 2001
From: twalthr
Date: Mon, 19 Jun 2017 17:06:44 +0200
Subject: [PATCH] [FLINK-6881] [FLINK-6896] [table] Creating a table from a
 POJO and defining a time attribute fails

This closes #4144.
This closes #4111.
---
 .../table/api/StreamTableEnvironment.scala    | 112 ++++++++++--------
 .../flink/table/api/scala/package.scala       |   2 +-
 .../calcite/RelTimeIndicatorConverter.scala   |  27 ++++-
 .../flink/table/codegen/CodeGenerator.scala   |  42 +++----
 .../table/expressions/ExpressionParser.scala  |  18 ++-
 .../flink/table/api/java/utils/Pojos.java     |   2 +-
 .../datastream/TimeAttributesITCase.scala     |  89 +++++++++++++-
 7 files changed, 208 insertions(+), 84 deletions(-)

diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/StreamTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/StreamTableEnvironment.scala
index 178bd9f86ccdb..eb3eb5c461d41 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/StreamTableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/StreamTableEnvironment.scala
@@ -20,11 +20,10 @@ package org.apache.flink.table.api
 
 import _root_.java.lang.{Boolean => JBool}
 import _root_.java.util.concurrent.atomic.AtomicInteger
-import _root_.java.util.{List => JList}
 
 import org.apache.calcite.plan.RelOptUtil
 import org.apache.calcite.plan.hep.HepMatchOrder
-import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeField, RelDataTypeFieldImpl, RelRecordType}
+import org.apache.calcite.rel.`type`.RelDataType
 import org.apache.calcite.rel.{RelNode, RelVisitor}
 import org.apache.calcite.rex.{RexCall, RexInputRef, RexNode}
 import org.apache.calcite.sql.SqlKind
@@ -34,14 +33,14 @@ import org.apache.flink.api.common.functions.MapFunction
 import org.apache.flink.api.common.typeinfo.{AtomicType, TypeInformation}
 import org.apache.flink.api.common.typeutils.CompositeType
 import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
-import org.apache.flink.api.java.typeutils.TupleTypeInfo
+import org.apache.flink.api.java.typeutils.{PojoTypeInfo, TupleTypeInfo}
 import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
 import org.apache.flink.streaming.api.TimeCharacteristic
 import org.apache.flink.streaming.api.datastream.DataStream
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment
-import org.apache.flink.table.calcite.{FlinkTypeFactory, RelTimeIndicatorConverter}
+import org.apache.flink.table.calcite.RelTimeIndicatorConverter
 import org.apache.flink.table.explain.PlanJsonParser
-import org.apache.flink.table.expressions.{Expression, ProctimeAttribute, RowtimeAttribute, UnresolvedFieldReference}
+import org.apache.flink.table.expressions._
 import org.apache.flink.table.plan.nodes.FlinkConventions
 import org.apache.flink.table.plan.nodes.datastream.{DataStreamRel, UpdateAsRetractionTrait, _}
 import org.apache.flink.table.plan.rules.FlinkRuleSets
@@ -438,39 +437,69 @@ abstract class StreamTableEnvironment(
     var rowtime: Option[(Int, String)] = None
     var proctime: Option[(Int, String)] = None
 
-    exprs.zipWithIndex.foreach {
-      case (RowtimeAttribute(reference@UnresolvedFieldReference(name)), idx) =>
-        if (rowtime.isDefined) {
+    def extractRowtime(idx: Int, name: String, origName: Option[String]): Unit = {
+      if (rowtime.isDefined) {
+        throw new TableException(
+          "The rowtime attribute can only be defined once in a table schema.")
+      } else {
+        val mappedIdx = streamType match {
+          case pti: PojoTypeInfo[_] =>
+            pti.getFieldIndex(origName.getOrElse(name))
+          case _ => idx;
+        }
+        // check type of field that is replaced
+        if (mappedIdx < 0) {
           throw new TableException(
-            "The rowtime attribute can only be defined once in a table schema.")
-        } else {
-          // check type of field that is replaced
-          if (idx < fieldTypes.length &&
-            !(TypeCheckUtils.isLong(fieldTypes(idx)) ||
-              TypeCheckUtils.isTimePoint(fieldTypes(idx)))) {
-            throw new TableException(
-              "The rowtime attribute can only be replace a field with a valid time type, such as " +
-                "Timestamp or Long.")
-          }
-          rowtime = Some(idx, name)
+            s"The rowtime attribute can only replace a valid field. " +
+              s"${origName.getOrElse(name)} is not a field of type $streamType.")
         }
-      case (ProctimeAttribute(reference@UnresolvedFieldReference(name)), idx) =>
-        if (proctime.isDefined) {
+        else if (mappedIdx < fieldTypes.length &&
+          !(TypeCheckUtils.isLong(fieldTypes(mappedIdx)) ||
+            TypeCheckUtils.isTimePoint(fieldTypes(mappedIdx)))) {
+          throw new TableException(
+            s"The rowtime attribute can only replace a field with a valid time type, " +
+              s"such as Timestamp or Long. But was: ${fieldTypes(mappedIdx)}")
+        }
+
+        rowtime = Some(idx, name)
+      }
+    }
+
+    def extractProctime(idx: Int, name: String): Unit = {
+      if (proctime.isDefined) {
           throw new TableException(
             "The proctime attribute can only be defined once in a table schema.")
-        } else {
-          // check that proctime is only appended
-          if (idx < fieldTypes.length) {
-            throw new TableException(
-              "The proctime attribute can only be appended to the table schema and not replace " +
-                "an existing field. Please move it to the end of the schema.")
-          }
-          proctime = Some(idx, name)
+      } else {
+        // check that proctime is only appended
+        if (idx < fieldTypes.length) {
+          throw new TableException(
+            "The proctime attribute can only be appended to the table schema and not replace " +
+              "an existing field. Please move it to the end of the schema.")
         }
-      case (u: UnresolvedFieldReference, _) => fieldNames = u.name :: fieldNames
+        proctime = Some(idx, name)
+      }
+    }
 
-      case _ =>
-        throw new TableException("Time attributes can only be defined on field references.")
+    exprs.zipWithIndex.foreach {
+      case (RowtimeAttribute(UnresolvedFieldReference(name)), idx) =>
+        extractRowtime(idx, name, None)
+
+      case (RowtimeAttribute(Alias(UnresolvedFieldReference(origName), name, _)), idx) =>
+        extractRowtime(idx, name, Some(origName))
+
+      case (ProctimeAttribute(UnresolvedFieldReference(name)), idx) =>
+        extractProctime(idx, name)
+
+      case (ProctimeAttribute(Alias(UnresolvedFieldReference(_), name, _)), idx) =>
+        extractProctime(idx, name)
+
+      case (UnresolvedFieldReference(name), _) => fieldNames = name :: fieldNames
+
+      case (Alias(UnresolvedFieldReference(_), name, _), _) => fieldNames = name :: fieldNames
+
+      case (e, _) =>
+        throw new TableException(s"Time attributes can only be defined on field references or " +
+          s"aliases of field references. But was: $e")
     }
 
     if (rowtime.isDefined && fieldNames.contains(rowtime.get._2)) {
@@ -606,21 +635,10 @@ abstract class StreamTableEnvironment(
     val relNode = table.getRelNode
     val dataStreamPlan = optimize(relNode, updatesAsRetraction)
 
-    // zip original field names with optimized field types
-    val fieldTypes = relNode.getRowType.getFieldList.asScala
-      .zip(dataStreamPlan.getRowType.getFieldList.asScala)
-      // get name of original plan and type of optimized plan
-      .map(x => (x._1.getName, x._2.getType))
-      // add field indexes
-      .zipWithIndex
-      // build new field types
-      .map(x => new RelDataTypeFieldImpl(x._1._1, x._2, x._1._2))
-
-    // build a record type from list of field types
-    val rowType = new RelRecordType(
-      fieldTypes.toList.asJava.asInstanceOf[JList[RelDataTypeField]])
-
-    translate(dataStreamPlan, rowType, queryConfig, withChangeFlag)
+    // we convert the logical row type to the output row type
+    val convertedOutputType = RelTimeIndicatorConverter.convertOutputType(relNode)
+
+    translate(dataStreamPlan, convertedOutputType, queryConfig, withChangeFlag)
   }
 
   /**
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/package.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/package.scala
index 9d15c1488cee5..cc1a388ccf29a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/package.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/package.scala
@@ -82,7 +82,7 @@ package object scala extends ImplicitExpressionConversions {
   }
 
   implicit def dataStream2DataStreamConversions[T](set: DataStream[T]): DataStreamConversions[T] = {
-    new DataStreamConversions[T](set, set.dataType.asInstanceOf[CompositeType[T]])
+    new DataStreamConversions[T](set, set.dataType)
   }
 
   implicit def table2RowDataStream(table: Table): DataStream[Row] = {
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala
index 21fa70bd6758a..b28e3f8615d1d 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala
@@ -18,7 +18,7 @@
 
 package org.apache.flink.table.calcite
 
-import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFieldImpl, RelRecordType}
 import org.apache.calcite.rel.core._
 import org.apache.calcite.rel.logical._
 import org.apache.calcite.rel.{RelNode, RelShuttle}
@@ -26,10 +26,10 @@ import org.apache.calcite.rex._
 import org.apache.calcite.sql.fun.SqlStdOperatorTable
 import org.apache.flink.api.common.typeinfo.SqlTimeTypeInfo
 import org.apache.flink.table.api.{TableException, ValidationException}
-import org.apache.flink.table.functions.TimeMaterializationSqlFunction
-import org.apache.flink.table.plan.schema.TimeIndicatorRelDataType
 import org.apache.flink.table.calcite.FlinkTypeFactory.isTimeIndicatorType
+import org.apache.flink.table.functions.TimeMaterializationSqlFunction
 import org.apache.flink.table.plan.logical.rel.LogicalWindowAggregate
+import org.apache.flink.table.plan.schema.TimeIndicatorRelDataType
 
 import scala.collection.JavaConversions._
 import scala.collection.mutable
@@ -391,4 +391,25 @@ object RelTimeIndicatorConverter {
       convertedRoot
     }
   }
+
+  def convertOutputType(rootRel: RelNode): RelDataType = {
+
+    val timestamp = rootRel
+      .getCluster
+      .getRexBuilder
+      .getTypeFactory
+      .asInstanceOf[FlinkTypeFactory]
+      .createTypeFromTypeInfo(SqlTimeTypeInfo.TIMESTAMP)
+
+    // convert all time indicators types to timestamps
+    val fields = rootRel.getRowType.getFieldList.map { field =>
+      if (isTimeIndicatorType(field.getType)) {
+        new RelDataTypeFieldImpl(field.getName, field.getIndex, timestamp)
+      } else {
+        field
+      }
+    }
+
+    new RelRecordType(fields)
+  }
 }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
index 4e8dfb6afbe43..75dc1d0648fbc 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
@@ -871,12 +871,16 @@ class CodeGenerator(
       returnType: TypeInformation[_ <: Any],
       resultFieldNames: Seq[String])
     : GeneratedExpression = {
-    val input1AccessExprs = for (i <- 0 until input1.getArity if input1Mapping.contains(i))
-      yield generateInputAccess(input1, input1Term, i, input1Mapping)
+    val input1AccessExprs = input1Mapping.map { idx =>
+      generateInputAccess(input1, input1Term, idx, input1Mapping)
+    }
 
     val input2AccessExprs = input2 match {
-      case Some(ti) => for (i <- 0 until ti.getArity if input2Mapping.contains(i))
-        yield generateInputAccess(ti, input2Term, i, input2Mapping)
+      case Some(ti) =>
+        input2Mapping.map { idx =>
+          generateInputAccess(ti, input2Term, idx, input2Mapping)
+        }.toSeq
+
       case None => Seq() // add nothing
     }
 
@@ -887,15 +891,18 @@ class CodeGenerator(
    * Generates an expression from the left input and the right table function.
    */
   def generateCorrelateAccessExprs: (Seq[GeneratedExpression], Seq[GeneratedExpression]) = {
-    val input1AccessExprs = for (i <- 0 until input1.getArity)
-      yield generateInputAccess(input1, input1Term, i, input1Mapping)
+    val input1AccessExprs = input1Mapping.map { idx =>
+      generateInputAccess(input1, input1Term, idx, input1Mapping)
+    }
 
     val input2AccessExprs = input2 match {
-      case Some(ti) => for (i <- 0 until ti.getArity if input2Mapping.contains(i))
+      case Some(ti) =>
         // use generateFieldAccess instead of generateInputAccess to avoid the generated table
         // function's field access code is put on the top of function body rather than
         // the while loop
-        yield generateFieldAccess(ti, input2Term, i, input2Mapping)
+        input2Mapping.map { idx =>
+          generateFieldAccess(ti, input2Term, idx, input2Mapping)
+        }.toSeq
       case None => throw new CodeGenException("Type information of input2 must not be null.")
     }
     (input1AccessExprs, input2AccessExprs)
@@ -1625,14 +1632,7 @@ class CodeGenerator(
     val nullTerm = newName("isNull")
 
     val fieldType = inputType match {
-      case ct: CompositeType[_] =>
-        val fieldIndex = if (ct.isInstanceOf[PojoTypeInfo[_]]) {
-          fieldMapping(index)
-        }
-        else {
-          index
-        }
-        ct.getTypeAt(fieldIndex)
+      case ct: CompositeType[_] => ct.getTypeAt(index)
       case at: AtomicType[_] => at
       case _ => throw new CodeGenException("Unsupported type for input field access.")
     }
@@ -1666,14 +1666,8 @@ class CodeGenerator(
     : GeneratedExpression = {
     inputType match {
       case ct: CompositeType[_] =>
-        val fieldIndex = if (ct.isInstanceOf[PojoTypeInfo[_]]) {
-          fieldMapping(index)
-        }
-        else {
-          index
-        }
-        val accessor = fieldAccessorFor(ct, fieldIndex)
-        val fieldType: TypeInformation[Any] = ct.getTypeAt(fieldIndex)
+        val accessor = fieldAccessorFor(ct, index)
+        val fieldType: TypeInformation[Any] = ct.getTypeAt(index)
         val fieldTypeTerm = boxedTypeTermForTypeInfo(fieldType)
 
         accessor match {
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala
index e1ffb33eba61b..f67fbacb76e9c 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala
@@ -469,13 +469,15 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
 
   lazy val timeIndicator: PackratParser[Expression] = proctime | rowtime
 
-  lazy val proctime: PackratParser[Expression] = fieldReference ~ "." ~ PROCTIME ^^ {
-    case f ~ _ ~ _ => ProctimeAttribute(f)
-  }
+  lazy val proctime: PackratParser[Expression] =
+    (aliasMapping | "(" ~> aliasMapping <~ ")" | fieldReference) ~ "." ~ PROCTIME ^^ {
+      case f ~ _ ~ _ => ProctimeAttribute(f)
+    }
 
-  lazy val rowtime: PackratParser[Expression] = fieldReference ~ "." ~ ROWTIME ^^ {
-    case f ~ _ ~ _ => RowtimeAttribute(f)
-  }
+  lazy val rowtime: PackratParser[Expression] =
+    (aliasMapping | "(" ~> aliasMapping <~ ")" | fieldReference) ~ "." ~ ROWTIME ^^ {
+      case f ~ _ ~ _ => RowtimeAttribute(f)
+    }
 
   // alias
 
@@ -485,6 +487,10 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
     case e ~ _ ~ _ ~ names ~ _ => Alias(e, names.head.name, names.tail.map(_.name))
   } | logic
 
+  lazy val aliasMapping: PackratParser[Expression] = fieldReference ~ AS ~ fieldReference ^^ {
+    case e ~ _ ~ name => Alias(e, name.name)
+  }
+
   lazy val expression: PackratParser[Expression] = timeIndicator | overConstant | alias |
     failure("Invalid expression.")
 
diff --git a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/Pojos.java b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/Pojos.java
index 30488350ac89f..69b789097db85 100644
--- a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/Pojos.java
+++ b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/Pojos.java
@@ -22,7 +22,7 @@
 import java.sql.Timestamp;
 
 /**
- * POJOs for table api testing.
+ * POJOs for Table API testing.
  */
 public class Pojos {
 
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/TimeAttributesITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/TimeAttributesITCase.scala
index 73cb701cfe106..c434f47fa3601 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/TimeAttributesITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/TimeAttributesITCase.scala
@@ -31,8 +31,8 @@ import org.apache.flink.table.api.scala._
 import org.apache.flink.table.api.scala.stream.utils.StreamITCase
 import org.apache.flink.table.api.{TableEnvironment, TableException, Types, ValidationException}
 import org.apache.flink.table.calcite.RelTimeIndicatorConverterTest.TableFunc
-import org.apache.flink.table.expressions.TimeIntervalUnit
-import org.apache.flink.table.runtime.datastream.TimeAttributesITCase.TimestampWithEqualWatermark
+import org.apache.flink.table.expressions.{ExpressionParser, TimeIntervalUnit}
+import org.apache.flink.table.runtime.datastream.TimeAttributesITCase.{TestPojo, TimestampWithEqualWatermark, TimestampWithEqualWatermarkPojo}
 import org.apache.flink.types.Row
 import org.junit.Assert._
 import org.junit.Test
@@ -337,6 +337,67 @@ class TimeAttributesITCase extends StreamingMultipleProgramsTestBase {
     assertEquals(expected.sorted, StreamITCase.testResults.sorted)
   }
 
+  @Test
+  def testPojoSupport(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    StreamITCase.testResults = mutable.MutableList()
+
+    val p1 = new TestPojo
+    p1.a = 12
+    p1.b = 42L
+    p1.c = "Test me."
+
+    val p2 = new TestPojo
+    p2.a = 13
+    p2.b = 43L
+    p2.c = "And me."
+
+    val stream = env
+      .fromElements(p1, p2)
+      .assignTimestampsAndWatermarks(new TimestampWithEqualWatermarkPojo)
+    // use aliases, swap all attributes, and skip b2
+    val table = stream.toTable(tEnv, ('b as 'b).rowtime, 'c as 'c, 'a as 'a)
+    // no aliases, no swapping
+    val table2 = stream.toTable(tEnv, 'a, 'b.rowtime, 'c)
+    // use proctime, no skipping
+    val table3 = stream.toTable(tEnv, 'a, 'b.rowtime, 'c, 'b2, 'proctime.proctime)
+
+    // Java expressions
+
+    // use aliases, swap all attributes, and skip b2
+    val table4 = stream.toTable(
+      tEnv,
+      ExpressionParser.parseExpressionList("(b as b).rowtime, c as c, a as a"): _*)
+    // no aliases, no swapping
+    val table5 = stream.toTable(
+      tEnv,
+      ExpressionParser.parseExpressionList("a, b.rowtime, c"): _*)
+
+    val t = table.select('b, 'c , 'a)
+      .unionAll(table2.select('b, 'c, 'a))
+      .unionAll(table3.select('b, 'c, 'a))
+      .unionAll(table4.select('b, 'c, 'a))
+      .unionAll(table5.select('b, 'c, 'a))
+
+    val results = t.toAppendStream[Row]
+    results.addSink(new StreamITCase.StringSink[Row])
+    env.execute()
+
+    val expected = Seq(
+      "1970-01-01 00:00:00.042,Test me.,12",
+      "1970-01-01 00:00:00.042,Test me.,12",
+      "1970-01-01 00:00:00.042,Test me.,12",
+      "1970-01-01 00:00:00.042,Test me.,12",
+      "1970-01-01 00:00:00.042,Test me.,12",
+      "1970-01-01 00:00:00.043,And me.,13",
+      "1970-01-01 00:00:00.043,And me.,13",
+      "1970-01-01 00:00:00.043,And me.,13",
+      "1970-01-01 00:00:00.043,And me.,13",
+      "1970-01-01 00:00:00.043,And me.,13")
+    assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+  }
 }
 
 object TimeAttributesITCase {
@@ -356,4 +417,28 @@ object TimeAttributesITCase {
       element._1
     }
   }
+
+  class TimestampWithEqualWatermarkPojo
+    extends AssignerWithPunctuatedWatermarks[TestPojo] {
+
+    override def checkAndGetNextWatermark(
+        lastElement: TestPojo,
+        extractedTimestamp: Long)
+      : Watermark = {
+      new Watermark(extractedTimestamp)
+    }
+
+    override def extractTimestamp(
+        element: TestPojo,
+        previousElementTimestamp: Long): Long = {
+      element.b
+    }
+  }
+
+  class TestPojo() {
+    var a: Int = _
+    var b: Long = _
+    var b2: String = "skip me"
+    var c: String = _
+  }
 }
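
Reviewer note: below is a minimal, self-contained sketch of the behavior this patch enables, assuming the patched Table API. `MyPojo`, `PerElementWatermarks`, and all field names are illustrative only; the calls mirror those exercised by `TimeAttributesITCase.testPojoSupport`.

    import org.apache.flink.streaming.api.TimeCharacteristic
    import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks
    import org.apache.flink.streaming.api.scala._
    import org.apache.flink.streaming.api.watermark.Watermark
    import org.apache.flink.table.api.TableEnvironment
    import org.apache.flink.table.api.scala._
    import org.apache.flink.types.Row

    // Illustrative POJO: default constructor and public mutable fields.
    class MyPojo() {
      var ts: Long = _
      var payload: String = _
    }

    // Emits a watermark for every element, based on the `ts` field.
    class PerElementWatermarks extends AssignerWithPunctuatedWatermarks[MyPojo] {
      override def extractTimestamp(element: MyPojo, previousElementTimestamp: Long): Long =
        element.ts

      override def checkAndGetNextWatermark(lastElement: MyPojo, extractedTimestamp: Long): Watermark =
        new Watermark(extractedTimestamp)
    }

    object PojoTimeAttributeExample {
      def main(args: Array[String]): Unit = {
        val env = StreamExecutionEnvironment.getExecutionEnvironment
        env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
        val tEnv = TableEnvironment.getTableEnvironment(env)

        val pojo = new MyPojo
        pojo.ts = 42L
        pojo.payload = "hello"

        val stream = env
          .fromElements(pojo)
          .assignTimestampsAndWatermarks(new PerElementWatermarks)

        // POJO fields are now resolved by name rather than by position, so
        // they can be reordered and aliased, and the Long field `ts` can be
        // replaced by a rowtime attribute.
        val table = stream.toTable(tEnv, ('ts as 'eventTime).rowtime, 'payload as 'data)

        table.toAppendStream[Row].print()
        env.execute()
      }
    }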
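The core of the `StreamTableEnvironment` change is that a replaced time field is looked up by name via `PojoTypeInfo.getFieldIndex` instead of by its position in the expression list. A small sketch of why the old positional mapping broke for POJOs, reusing the illustrative `MyPojo` from above and assuming it qualifies as a Flink POJO: `PojoTypeInfo` orders fields alphabetically, so declaration order and index order differ.

    import org.apache.flink.api.java.typeutils.PojoTypeInfo
    import org.apache.flink.api.scala._

    object FieldIndexExample {
      def main(args: Array[String]): Unit = {
        val pojoType = createTypeInformation[MyPojo].asInstanceOf[PojoTypeInfo[MyPojo]]

        // Fields are sorted alphabetically: `payload` -> 0, `ts` -> 1,
        // even though `ts` is declared first in the class.
        println(pojoType.getFieldIndex("payload")) // 0
        println(pojoType.getFieldIndex("ts"))      // 1
      }
    }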
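The new `aliasMapping` production is also reachable from the Java-style string expressions. Both forms below should yield the same `RowtimeAttribute(Alias(...))` tree; the first is taken verbatim from the test, the second is an assumption based on `aliasMapping` being the first alternative in the new `rowtime` rule.

    import org.apache.flink.table.expressions.ExpressionParser

    object ParserExample {
      def main(args: Array[String]): Unit = {
        // Parenthesized alias, as used in TimeAttributesITCase:
        println(ExpressionParser.parseExpressionList("(b as b).rowtime, c as c, a as a"))

        // Unparenthesized alias; should parse because `aliasMapping` is tried
        // before the parenthesized variant and the plain field reference.
        println(ExpressionParser.parseExpressionList("b as b.rowtime, c, a"))
      }
    }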
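Finally, the constraints that `validateAndExtractTimeAttributes` enforces, sketched as statements inside `main` of the first example (so `stream` and `tEnv` are as defined there); the failing cases are assumptions derived directly from the checks added in this patch.

    // OK: `ts` is a Long field, so it may be replaced by a rowtime attribute.
    val t1 = stream.toTable(tEnv, 'ts.rowtime, 'payload)

    // OK: a proctime attribute must be appended after all physical fields.
    val t2 = stream.toTable(tEnv, 'payload, 'ts, 'procTime.proctime)

    // Fails with TableException: proctime may not replace an existing field.
    // val bad1 = stream.toTable(tEnv, 'payload.proctime, 'ts)

    // Fails with TableException: `payload` is a String and therefore cannot
    // carry the rowtime indicator (only Long or Timestamp fields can).
    // val bad2 = stream.toTable(tEnv, 'payload.rowtime, 'ts)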