From 1898cf8bc21944b0f6eb80ab1e961d98b2c8f8d3 Mon Sep 17 00:00:00 2001
From: Shuyi Chen
Date: Thu, 25 Jan 2018 15:36:36 -0800
Subject: [PATCH 1/2] Support access to subfields of an Array element if the
 element is a composite type (e.g. case class, POJO, tuple, or row).

---
 .../flink/table/codegen/CodeGenerator.scala   |   6 +
 .../table/codegen/calls/ScalarOperators.scala |  76 ++++++++-
 .../expressions/CompositeAccessTest.scala     |  28 ++++
 .../utils/CompositeTypeTestBase.scala         |  12 +-
 .../table/runtime/batch/sql/CalcITCase.scala  |  29 +++-
 .../table/runtime/stream/sql/SqlITCase.scala  | 158 +++++++++++++++++-
 6 files changed, 298 insertions(+), 11 deletions(-)

diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
index aca0ba33aae87..eb0f2abe1f289 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
@@ -1001,6 +1001,12 @@ abstract class CodeGenerator(
       requireArray(array)
       generateArrayElement(this, array)
 
+    case DOT =>
+      // Due to https://issues.apache.org/jira/browse/CALCITE-2162, expressions such as
+      // "array[1].a.b" do not work yet.
+      require(operands.size == 2)
+      generateDot(this, call, operands.head, operands(1))
+
     case ScalarSqlFunctions.CONCAT =>
       generateConcat(this.nullCheck, operands)
 
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala
index 742ee7de3b28d..9f25aa2dc9dec 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala
@@ -19,14 +19,17 @@ package org.apache.flink.table.codegen.calls
 
 import org.apache.calcite.avatica.util.DateTimeUtils.MILLIS_PER_DAY
 import org.apache.calcite.avatica.util.{DateTimeUtils, TimeUnitRange}
-import org.apache.calcite.util.BuiltInMethod
+import org.apache.calcite.rex.{RexCall, RexLiteral}
+import org.apache.calcite.util.{BuiltInMethod, NlsString}
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
 import org.apache.flink.api.common.typeinfo._
-import org.apache.flink.api.common.typeutils.CompositeType
-import org.apache.flink.api.java.typeutils.{MapTypeInfo, ObjectArrayTypeInfo, RowTypeInfo}
+import org.apache.flink.api.java.typeutils._
+import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
+import org.apache.flink.table.calcite.FlinkTypeFactory
 import org.apache.flink.table.codegen.CodeGenUtils._
 import org.apache.flink.table.codegen.calls.CallGenerator.generateCallIfArgsNotNull
 import org.apache.flink.table.codegen.{CodeGenException, CodeGenerator, GeneratedExpression}
+import org.apache.flink.table.plan.schema.CompositeRelDataType
 import org.apache.flink.table.typeutils.{TimeIndicatorTypeInfo, TimeIntervalTypeInfo, TypeCoercion}
 import org.apache.flink.table.typeutils.TypeCheckUtils._
 
@@ -984,6 +987,73 @@ object ScalarOperators {
     }
   }
 
+  def generateDot(
+      codeGenerator: CodeGenerator,
+      dot: RexCall,
+      record: GeneratedExpression,
+      subField: GeneratedExpression)
+    : GeneratedExpression = {
+    val nullTerm = newName("isNull")
+    val resultTerm = newName("result")
+    val resultType = FlinkTypeFactory.toTypeInfo(dot.getType)
+ val resultTypeTerm = boxedTypeTermForTypeInfo(resultType) + dot.operands.get(0).getType match { + case crdt: CompositeRelDataType => { + val fieldName = dot.operands.get(1).asInstanceOf[RexLiteral] + .getValue.asInstanceOf[NlsString].getValue + if (crdt.compositeType.isInstanceOf[TupleTypeInfo[_]]) { + return GeneratedExpression(resultTerm, nullTerm, + s""" + |${record.code} + |${subField.code} + |${resultTypeTerm} $resultTerm = null; + |if (${record.resultTerm} != null) { + | $resultTerm = (${resultTypeTerm}) ${record.resultTerm}.productElement( + | ${fieldName.substring(1).toInt} - 1); + |} + |boolean $nullTerm = ${resultTerm} == null; + |""".stripMargin, resultType) + } else if (crdt.compositeType.isInstanceOf[CaseClassTypeInfo[_]]) { + return GeneratedExpression(resultTerm, nullTerm, + s""" + |${record.code} + |${resultTypeTerm} $resultTerm = null; + |if (${record.resultTerm} != null) { + | $resultTerm = (${resultTypeTerm}) ${record.resultTerm}.${fieldName}(); + |} + |boolean $nullTerm = ${resultTerm} == null; + |""".stripMargin, resultType) + } else if (crdt.compositeType.isInstanceOf[PojoTypeInfo[_]]) { + return GeneratedExpression(resultTerm, nullTerm, + s""" + |${record.code} + |${resultTypeTerm} $resultTerm = null; + |if (${record.resultTerm} != null) { + | $resultTerm = + | (${resultTypeTerm}) ${record.resultTerm}.${fieldName}; + |} + |boolean $nullTerm = ${resultTerm} == null; + |""".stripMargin, resultType) + } else if (crdt.compositeType.isInstanceOf[RowTypeInfo]) { + val fieldIndex = dot.operands.get(0).getType.asInstanceOf[CompositeRelDataType] + .compositeType.getFieldIndex(fieldName) + return GeneratedExpression(resultTerm, nullTerm, + s""" + |${record.code} + |${resultTypeTerm} $resultTerm = null; + |if (${record.resultTerm} != null) { + | $resultTerm = (${resultTypeTerm}) ${record.resultTerm}.getField(${fieldIndex}); + |} + |boolean $nullTerm = ${resultTerm} == null; + |""".stripMargin, resultType) + } + } + case _ => + } + + throw new CodeGenException("Unsupported type: %s".format(dot.operands.get(0).getType)) + } + def generateArrayElement( codeGenerator: CodeGenerator, array: GeneratedExpression) diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/CompositeAccessTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/CompositeAccessTest.scala index bd1c19990f5d8..1a5f41cb2fc4e 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/CompositeAccessTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/CompositeAccessTest.scala @@ -116,6 +116,34 @@ class CompositeAccessTest extends CompositeTypeTestBase { "42") testSqlApi("f0.*", "42") + testAllApis( + 'f8.at(1).get("_1"), + "f8.at(1).get('_1')", + "f8[1]._1", + "true" + ) + + testAllApis( + 'f9.at(1).get("_1"), + "f9.at(1).get('_1')", + "f9[1]._1", + "null" + ) + + testAllApis( + 'f10.at(1).get("stringField"), + "f10.at(1).get('stringField')", + "f10[1].stringField", + "Bob" + ) + + testAllApis( + 'f11.at(1).get("myString"), + "f11.at(1).get('myString')", + "f11[1].myString", + "Hello" + ) + testTableApi(12.flatten(), "12.flatten()", "12") testTableApi('f5.flatten(), "f5.flatten()", "13") diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/CompositeTypeTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/CompositeTypeTestBase.scala index 8f7360ebee56a..2354f3929e6ff 100644 --- 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/CompositeTypeTestBase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/CompositeTypeTestBase.scala @@ -28,7 +28,7 @@ import org.apache.flink.types.Row class CompositeTypeTestBase extends ExpressionTestBase { def testData: Row = { - val testData = new Row(8) + val testData = new Row(12) testData.setField(0, MyCaseClass(42, "Bob", booleanField = true)) testData.setField(1, MyCaseClass2(MyCaseClass(25, "Timo", booleanField = false))) testData.setField(2, ("a", "b")) @@ -37,6 +37,10 @@ class CompositeTypeTestBase extends ExpressionTestBase { testData.setField(5, 13) testData.setField(6, MyCaseClass2(null)) testData.setField(7, Tuple1(true)) + testData.setField(8, Array(Tuple1(true))) + testData.setField(9, Array(Tuple1(null))) + testData.setField(10, Array(MyCaseClass(42, "Bob", booleanField = true))) + testData.setField(11, Array(new MyPojo())) testData } @@ -49,7 +53,11 @@ class CompositeTypeTestBase extends ExpressionTestBase { TypeExtractor.createTypeInfo(classOf[MyPojo]), Types.INT, createTypeInformation[MyCaseClass2], - createTypeInformation[Tuple1[Boolean]] + createTypeInformation[Tuple1[Boolean]], + createTypeInformation[Array[Tuple1[Boolean]]], + createTypeInformation[Array[Tuple1[Boolean]]], + createTypeInformation[Array[MyCaseClass]], + createTypeInformation[Array[MyPojo]] ).asInstanceOf[TypeInformation[Any]] } } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/CalcITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/CalcITCase.scala index 6aed9a83c1bc2..55b63561c09bf 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/CalcITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/CalcITCase.scala @@ -29,7 +29,7 @@ import org.apache.flink.table.expressions.utils.SplitUDF import org.apache.flink.table.functions.ScalarFunction import org.apache.flink.table.runtime.batch.table.OldHashCode import org.apache.flink.table.runtime.utils.TableProgramsTestBase.TableConfigMode -import org.apache.flink.table.runtime.utils.{TableProgramsCollectionTestBase, TableProgramsTestBase} +import org.apache.flink.table.runtime.utils.{JavaPojos, TableProgramsCollectionTestBase, TableProgramsTestBase} import org.apache.flink.test.util.TestBaseUtils import org.apache.flink.types.Row import org.junit._ @@ -389,6 +389,33 @@ class CalcITCase( val expected = List("a,a,d,d,e,e", "x,x,z,z,z,z").mkString("\n") TestBaseUtils.compareResultAsText(results.asJava, expected) } + + @Test + def testArrayElementAtFromTableForPojo(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + + val p1 = new JavaPojos.Pojo1(); + p1.msg = "msg1"; + val p2 = new JavaPojos.Pojo1(); + p2.msg = "msg2"; + val data = List( + (1, Array(p1)), + (2, Array(p2)) + ) + val stream = env.fromCollection(data) + tEnv.registerDataSet("T", stream, 'a, 'b) + + val sqlQuery = "SELECT a, b[1].msg FROM T" + + val results = tEnv.sqlQuery(sqlQuery).toDataSet[Row].collect() + + val expected = List( + "1,msg1", + "2,msg2").mkString("\n") + TestBaseUtils.compareResultAsText(results.asJava, expected) + } } object MyHashCode extends ScalarFunction { diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala 
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala index 0633b712837e0..8e48e0c0ccc9f 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala @@ -18,25 +18,29 @@ package org.apache.flink.table.runtime.stream.sql -import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} -import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation, Types} +import org.apache.flink.api.java.typeutils.{ObjectArrayTypeInfo, RowTypeInfo} import org.apache.flink.api.scala._ +import org.apache.flink.streaming.api import org.apache.flink.streaming.api.TimeCharacteristic +import org.apache.flink.streaming.api.datastream.DataStream import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment import org.apache.flink.streaming.api.watermark.Watermark -import org.apache.flink.table.api.{TableEnvironment, Types} +import org.apache.flink.table.api.{TableEnvironment, TableSchema} import org.apache.flink.table.api.scala._ import org.apache.flink.table.expressions.utils.SplitUDF import org.apache.flink.table.expressions.utils.Func15 -import org.apache.flink.table.runtime.stream.sql.SqlITCase.TimestampAndWatermarkWithOffset +import org.apache.flink.table.runtime.stream.sql.SqlITCase.{TestCaseClass, TimestampAndWatermarkWithOffset} import org.apache.flink.table.runtime.utils.TimeTestUtil.EventTimeSourceFunction -import org.apache.flink.table.runtime.utils.{StreamITCase, StreamTestData, StreamingWithStateTestBase} +import org.apache.flink.table.runtime.utils._ +import org.apache.flink.table.sources.StreamTableSource import org.apache.flink.types.Row import org.apache.flink.table.utils.MemoryTableSinkUtil import org.junit.Assert._ import org.junit._ +import scala.collection.JavaConverters._ import scala.collection.mutable class SqlITCase extends StreamingWithStateTestBase { @@ -469,6 +473,148 @@ class SqlITCase extends StreamingWithStateTestBase { assertEquals(expected.sorted, StreamITCase.testResults.sorted) } + @Test + def testArrayElementAtFromTableForTuple(): Unit = { + + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val data = List( + (1, Array((12, 45), (2, 5))), + (2, Array(null, (1, 49))), + (3, Array((18, 42), (127, 454))) + ) + val stream = env.fromCollection(data) + tEnv.registerDataStream("T", stream, 'a, 'b) + + val sqlQuery = "SELECT a, b[1]._1 FROM T" + + val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row] + result.addSink(new StreamITCase.StringSink[Row]) + env.execute() + + val expected = List( + "1,12", + "2,null", + "3,18") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testArrayElementAtFromTableForCaseClass(): Unit = { + + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val data = List( + (1, Array(TestCaseClass(12, 45), TestCaseClass(2, 5))), + (2, Array(TestCaseClass(41, 5), TestCaseClass(1, 49))), + (3, Array(TestCaseClass(18, 42), TestCaseClass(127, 454))) + ) + val stream = env.fromCollection(data) + tEnv.registerDataStream("T", stream, 'a, 'b) + + val sqlQuery = "SELECT a, b[1].f1 
FROM T" + + val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row] + result.addSink(new StreamITCase.StringSink[Row]) + env.execute() + + val expected = List( + "1,45", + "2,5", + "3,42") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testArrayElementAtFromTableForPojo(): Unit = { + + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val p1 = new JavaPojos.Pojo1(); + p1.msg = "msg1"; + val p2 = new JavaPojos.Pojo1(); + p2.msg = "msg2"; + val data = List( + (1, Array(p1)), + (2, Array(p2)) + ) + val stream = env.fromCollection(data) + tEnv.registerDataStream("T", stream, 'a, 'b) + + val sqlQuery = "SELECT a, b[1].msg FROM T" + + val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row] + result.addSink(new StreamITCase.StringSink[Row]) + env.execute() + + val expected = List( + "1,msg1", + "2,msg2") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testArrayElementAtFromTableForRow(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + tEnv.registerTableSource("mytable", new StreamTableSource[Row] { + private val fieldNames: Array[String] = Array("name", "record") + private val fieldTypes: Array[TypeInformation[_]] = + Array( + Types.STRING, + ObjectArrayTypeInfo.getInfoFor(Types.ROW_NAMED( + Array[String]("longField", "strField", "floatField", "arrayField"), + Types.LONG, + Types.STRING, + Types.FLOAT, + ObjectArrayTypeInfo.getInfoFor( + Types.ROW_NAMED(Array[String]("nestedLong"), Types.LONG))))) + .asInstanceOf[Array[TypeInformation[_]]] + + override def getDataStream(execEnv: api.environment.StreamExecutionEnvironment) + : DataStream[Row] = { + val nestRow1 = new Row(1) + nestRow1.setField(0, 1213L) + val mockRow1 = new Row(4) + mockRow1.setField(0, 273132121L) + mockRow1.setField(1, "str1") + mockRow1.setField(2, 123.4f) + mockRow1.setField(3, Array(nestRow1)) + val mockRow2 = new Row(4) + mockRow2.setField(0, 27318121L) + mockRow2.setField(1, "str2") + mockRow2.setField(2, 987.2f) + mockRow2.setField(3, Array(nestRow1)) + val data = List( + Row.of("Mary", Array(mockRow1, mockRow2)), + Row.of("Mary", Array(mockRow2, mockRow1))).asJava + execEnv.fromCollection(data, getReturnType) + } + + override def getReturnType: TypeInformation[Row] = new RowTypeInfo(fieldTypes, fieldNames) + override def getTableSchema: TableSchema = new TableSchema(fieldNames, fieldTypes) + }) + StreamITCase.clear + + val sqlQuery = "SELECT name, record[1].floatField, record[1].strField, " + + "record[2].longField, record[1].arrayField[1].nestedLong FROM mytable" + + val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row] + result.addSink(new StreamITCase.StringSink[Row]) + env.execute() + + val expected = List( + "Mary,123.4,str1,27318121,1213", + "Mary,987.2,str2,273132121,1213") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + @Test def testHopStartEndWithHaving(): Unit = { val env = StreamExecutionEnvironment.getExecutionEnvironment @@ -616,6 +762,8 @@ class SqlITCase extends StreamingWithStateTestBase { object SqlITCase { + case class TestCaseClass(f0: Integer, f1: Integer) extends Serializable + class TimestampAndWatermarkWithOffset[T <: Product]( offset: Long) extends AssignerWithPunctuatedWatermarks[T] { From 6ebfde759bf502b3efd30d7eb8ad703a37602db5 Mon Sep 17 00:00:00 2001 From: Shuyi Chen Date: Fri, 2 Feb 2018 14:24:32 -0800 Subject: [PATCH 
2/2] Refactor to use generateFieldAccess(); move all tests to
 CompositeAccessTest

---
 .../flink/table/codegen/CodeGenerator.scala   |  21 ++-
 .../table/codegen/calls/ScalarOperators.scala |  76 +--------
 .../expressions/CompositeAccessTest.scala     |  22 ++-
 .../utils/CompositeTypeTestBase.scala         |  16 +-
 .../table/runtime/batch/sql/CalcITCase.scala  |  29 +---
 .../table/runtime/stream/sql/SqlITCase.scala  | 158 +-----------------
 6 files changed, 54 insertions(+), 268 deletions(-)

diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
index eb0f2abe1f289..78086a0b4d8a5 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
@@ -26,6 +26,7 @@ import org.apache.calcite.sql.SqlOperator
 import org.apache.calcite.sql.`type`.SqlTypeName._
 import org.apache.calcite.sql.`type`.{ReturnTypes, SqlTypeName}
 import org.apache.calcite.sql.fun.SqlStdOperatorTable.{ROW, _}
+import org.apache.calcite.util.NlsString
 import org.apache.commons.lang3.StringEscapeUtils
 import org.apache.flink.api.common.functions._
 import org.apache.flink.api.common.typeinfo._
@@ -42,6 +43,7 @@ import org.apache.flink.table.codegen.calls.{CurrentTimePointCallGen, FunctionGe
 import org.apache.flink.table.functions.sql.{ProctimeSqlFunction, ScalarSqlFunctions, StreamRecordTimestampSqlFunction}
 import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils
 import org.apache.flink.table.functions.{FunctionContext, UserDefinedFunction}
+import org.apache.flink.table.plan.schema.CompositeRelDataType
 import org.apache.flink.table.typeutils.TimeIndicatorTypeInfo
 import org.apache.flink.table.typeutils.TypeCheckUtils._
 import org.joda.time.format.DateTimeFormatter
@@ -578,9 +580,9 @@ abstract class CodeGenerator(
   override def visitTableInputRef(rexTableInputRef: RexTableInputRef): GeneratedExpression =
     visitInputRef(rexTableInputRef)
 
-  override def visitFieldAccess(rexFieldAccess: RexFieldAccess): GeneratedExpression = {
-    val refExpr = rexFieldAccess.getReferenceExpr.accept(this)
-    val index = rexFieldAccess.getField.getIndex
+  private def generateFieldAccessExpr(
+      refExpr: GeneratedExpression,
+      index: Int): GeneratedExpression = {
     val fieldAccessExpr = generateFieldAccess(
       refExpr.resultType,
       refExpr.resultTerm,
@@ -616,6 +618,12 @@ abstract class CodeGenerator(
     GeneratedExpression(resultTerm, nullTerm, resultCode, fieldAccessExpr.resultType)
   }
 
+  override def visitFieldAccess(rexFieldAccess: RexFieldAccess): GeneratedExpression = {
+    val refExpr = rexFieldAccess.getReferenceExpr.accept(this)
+    val index = rexFieldAccess.getField.getIndex
+    generateFieldAccessExpr(refExpr, index)
+  }
+
   override def visitLiteral(literal: RexLiteral): GeneratedExpression = {
     val resultType = FlinkTypeFactory.toTypeInfo(literal.getType)
     val value = literal.getValue3
@@ -1005,7 +1013,12 @@ abstract class CodeGenerator(
       // Due to https://issues.apache.org/jira/browse/CALCITE-2162, expressions such as
       // "array[1].a.b" do not work yet.
require(operands.size == 2) - generateDot(this, call, operands.head, operands(1)) + val fieldName = + call.operands.get(1).asInstanceOf[RexLiteral].getValue.asInstanceOf[NlsString].getValue + generateFieldAccessExpr( + operands.head, + call.operands.get(0).getType.asInstanceOf[CompositeRelDataType] + .compositeType.getFieldIndex(fieldName)) case ScalarSqlFunctions.CONCAT => generateConcat(this.nullCheck, operands) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala index 9f25aa2dc9dec..742ee7de3b28d 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala @@ -19,17 +19,14 @@ package org.apache.flink.table.codegen.calls import org.apache.calcite.avatica.util.DateTimeUtils.MILLIS_PER_DAY import org.apache.calcite.avatica.util.{DateTimeUtils, TimeUnitRange} -import org.apache.calcite.rex.{RexCall, RexLiteral} -import org.apache.calcite.util.{BuiltInMethod, NlsString} +import org.apache.calcite.util.BuiltInMethod import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ import org.apache.flink.api.common.typeinfo._ -import org.apache.flink.api.java.typeutils._ -import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo -import org.apache.flink.table.calcite.FlinkTypeFactory +import org.apache.flink.api.common.typeutils.CompositeType +import org.apache.flink.api.java.typeutils.{MapTypeInfo, ObjectArrayTypeInfo, RowTypeInfo} import org.apache.flink.table.codegen.CodeGenUtils._ import org.apache.flink.table.codegen.calls.CallGenerator.generateCallIfArgsNotNull import org.apache.flink.table.codegen.{CodeGenException, CodeGenerator, GeneratedExpression} -import org.apache.flink.table.plan.schema.CompositeRelDataType import org.apache.flink.table.typeutils.{TimeIndicatorTypeInfo, TimeIntervalTypeInfo, TypeCoercion} import org.apache.flink.table.typeutils.TypeCheckUtils._ @@ -987,73 +984,6 @@ object ScalarOperators { } } - def generateDot( - codeGenerator: CodeGenerator, - dot: RexCall, - record: GeneratedExpression, - subField: GeneratedExpression) - : GeneratedExpression = { - val nullTerm = newName("isNull") - val resultTerm = newName("result") - val resultType = FlinkTypeFactory.toTypeInfo(dot.getType) - val resultTypeTerm = boxedTypeTermForTypeInfo(resultType) - dot.operands.get(0).getType match { - case crdt: CompositeRelDataType => { - val fieldName = dot.operands.get(1).asInstanceOf[RexLiteral] - .getValue.asInstanceOf[NlsString].getValue - if (crdt.compositeType.isInstanceOf[TupleTypeInfo[_]]) { - return GeneratedExpression(resultTerm, nullTerm, - s""" - |${record.code} - |${subField.code} - |${resultTypeTerm} $resultTerm = null; - |if (${record.resultTerm} != null) { - | $resultTerm = (${resultTypeTerm}) ${record.resultTerm}.productElement( - | ${fieldName.substring(1).toInt} - 1); - |} - |boolean $nullTerm = ${resultTerm} == null; - |""".stripMargin, resultType) - } else if (crdt.compositeType.isInstanceOf[CaseClassTypeInfo[_]]) { - return GeneratedExpression(resultTerm, nullTerm, - s""" - |${record.code} - |${resultTypeTerm} $resultTerm = null; - |if (${record.resultTerm} != null) { - | $resultTerm = (${resultTypeTerm}) ${record.resultTerm}.${fieldName}(); - |} - |boolean $nullTerm = ${resultTerm} == null; - |""".stripMargin, resultType) - } else if 
(crdt.compositeType.isInstanceOf[PojoTypeInfo[_]]) { - return GeneratedExpression(resultTerm, nullTerm, - s""" - |${record.code} - |${resultTypeTerm} $resultTerm = null; - |if (${record.resultTerm} != null) { - | $resultTerm = - | (${resultTypeTerm}) ${record.resultTerm}.${fieldName}; - |} - |boolean $nullTerm = ${resultTerm} == null; - |""".stripMargin, resultType) - } else if (crdt.compositeType.isInstanceOf[RowTypeInfo]) { - val fieldIndex = dot.operands.get(0).getType.asInstanceOf[CompositeRelDataType] - .compositeType.getFieldIndex(fieldName) - return GeneratedExpression(resultTerm, nullTerm, - s""" - |${record.code} - |${resultTypeTerm} $resultTerm = null; - |if (${record.resultTerm} != null) { - | $resultTerm = (${resultTypeTerm}) ${record.resultTerm}.getField(${fieldIndex}); - |} - |boolean $nullTerm = ${resultTerm} == null; - |""".stripMargin, resultType) - } - } - case _ => - } - - throw new CodeGenException("Unsupported type: %s".format(dot.operands.get(0).getType)) - } - def generateArrayElement( codeGenerator: CodeGenerator, array: GeneratedExpression) diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/CompositeAccessTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/CompositeAccessTest.scala index 1a5f41cb2fc4e..5c4f4aa4ab545 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/CompositeAccessTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/CompositeAccessTest.scala @@ -120,13 +120,20 @@ class CompositeAccessTest extends CompositeTypeTestBase { 'f8.at(1).get("_1"), "f8.at(1).get('_1')", "f8[1]._1", - "true" + "null" + ) + + testAllApis( + 'f8.at(1).get("_2"), + "f8.at(1).get('_2')", + "f8[1]._2", + "23" ) testAllApis( - 'f9.at(1).get("_1"), - "f9.at(1).get('_1')", - "f9[1]._1", + 'f9.at(2).get("_1"), + "f9.at(2).get('_1')", + "f9[2]._1", "null" ) @@ -144,6 +151,13 @@ class CompositeAccessTest extends CompositeTypeTestBase { "Hello" ) + testAllApis( + 'f12.at(1).get("arrayField").at(1).get("stringField"), + "f12.at(1).get('arrayField').at(1).get('stringField')", + "f12[1].arrayField[1].stringField", + "Alice" + ) + testTableApi(12.flatten(), "12.flatten()", "12") testTableApi('f5.flatten(), "f5.flatten()", "13") diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/CompositeTypeTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/CompositeTypeTestBase.scala index 2354f3929e6ff..71dad4f4a17fe 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/CompositeTypeTestBase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/CompositeTypeTestBase.scala @@ -22,13 +22,13 @@ import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.typeutils.{RowTypeInfo, TupleTypeInfo, TypeExtractor} import org.apache.flink.api.scala.createTypeInformation import org.apache.flink.table.api.Types -import org.apache.flink.table.expressions.utils.CompositeTypeTestBase.{MyCaseClass, MyCaseClass2, MyPojo} +import org.apache.flink.table.expressions.utils.CompositeTypeTestBase.{MyCaseClass, MyCaseClass2, MyCaseClass3, MyPojo} import org.apache.flink.types.Row class CompositeTypeTestBase extends ExpressionTestBase { def testData: Row = { - val testData = new Row(12) + val testData = new Row(13) testData.setField(0, MyCaseClass(42, "Bob", 
booleanField = true)) testData.setField(1, MyCaseClass2(MyCaseClass(25, "Timo", booleanField = false))) testData.setField(2, ("a", "b")) @@ -37,10 +37,11 @@ class CompositeTypeTestBase extends ExpressionTestBase { testData.setField(5, 13) testData.setField(6, MyCaseClass2(null)) testData.setField(7, Tuple1(true)) - testData.setField(8, Array(Tuple1(true))) - testData.setField(9, Array(Tuple1(null))) + testData.setField(8, Array(Tuple2(null, 23), Tuple2(false, 12))) + testData.setField(9, Array(Tuple1(true), null)) testData.setField(10, Array(MyCaseClass(42, "Bob", booleanField = true))) testData.setField(11, Array(new MyPojo())) + testData.setField(12, Array(MyCaseClass3(Array(MyCaseClass(42, "Alice", booleanField = true))))) testData } @@ -54,10 +55,11 @@ class CompositeTypeTestBase extends ExpressionTestBase { Types.INT, createTypeInformation[MyCaseClass2], createTypeInformation[Tuple1[Boolean]], - createTypeInformation[Array[Tuple1[Boolean]]], + createTypeInformation[Array[Tuple2[Boolean, Int]]], createTypeInformation[Array[Tuple1[Boolean]]], createTypeInformation[Array[MyCaseClass]], - createTypeInformation[Array[MyPojo]] + createTypeInformation[Array[MyPojo]], + createTypeInformation[Array[MyCaseClass3]] ).asInstanceOf[TypeInformation[Any]] } } @@ -67,6 +69,8 @@ object CompositeTypeTestBase { case class MyCaseClass2(objectField: MyCaseClass) + case class MyCaseClass3(arrayField: Array[MyCaseClass]) + class MyPojo { private var myInt: Int = 0 private var myString: String = "Hello" diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/CalcITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/CalcITCase.scala index 55b63561c09bf..6aed9a83c1bc2 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/CalcITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/CalcITCase.scala @@ -29,7 +29,7 @@ import org.apache.flink.table.expressions.utils.SplitUDF import org.apache.flink.table.functions.ScalarFunction import org.apache.flink.table.runtime.batch.table.OldHashCode import org.apache.flink.table.runtime.utils.TableProgramsTestBase.TableConfigMode -import org.apache.flink.table.runtime.utils.{JavaPojos, TableProgramsCollectionTestBase, TableProgramsTestBase} +import org.apache.flink.table.runtime.utils.{TableProgramsCollectionTestBase, TableProgramsTestBase} import org.apache.flink.test.util.TestBaseUtils import org.apache.flink.types.Row import org.junit._ @@ -389,33 +389,6 @@ class CalcITCase( val expected = List("a,a,d,d,e,e", "x,x,z,z,z,z").mkString("\n") TestBaseUtils.compareResultAsText(results.asJava, expected) } - - @Test - def testArrayElementAtFromTableForPojo(): Unit = { - - val env = ExecutionEnvironment.getExecutionEnvironment - val tEnv = TableEnvironment.getTableEnvironment(env) - - val p1 = new JavaPojos.Pojo1(); - p1.msg = "msg1"; - val p2 = new JavaPojos.Pojo1(); - p2.msg = "msg2"; - val data = List( - (1, Array(p1)), - (2, Array(p2)) - ) - val stream = env.fromCollection(data) - tEnv.registerDataSet("T", stream, 'a, 'b) - - val sqlQuery = "SELECT a, b[1].msg FROM T" - - val results = tEnv.sqlQuery(sqlQuery).toDataSet[Row].collect() - - val expected = List( - "1,msg1", - "2,msg2").mkString("\n") - TestBaseUtils.compareResultAsText(results.asJava, expected) - } } object MyHashCode extends ScalarFunction { diff --git 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala index 8e48e0c0ccc9f..0633b712837e0 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala @@ -18,29 +18,25 @@ package org.apache.flink.table.runtime.stream.sql -import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation, Types} -import org.apache.flink.api.java.typeutils.{ObjectArrayTypeInfo, RowTypeInfo} +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} +import org.apache.flink.api.java.typeutils.RowTypeInfo import org.apache.flink.api.scala._ -import org.apache.flink.streaming.api import org.apache.flink.streaming.api.TimeCharacteristic -import org.apache.flink.streaming.api.datastream.DataStream import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment import org.apache.flink.streaming.api.watermark.Watermark -import org.apache.flink.table.api.{TableEnvironment, TableSchema} +import org.apache.flink.table.api.{TableEnvironment, Types} import org.apache.flink.table.api.scala._ import org.apache.flink.table.expressions.utils.SplitUDF import org.apache.flink.table.expressions.utils.Func15 -import org.apache.flink.table.runtime.stream.sql.SqlITCase.{TestCaseClass, TimestampAndWatermarkWithOffset} +import org.apache.flink.table.runtime.stream.sql.SqlITCase.TimestampAndWatermarkWithOffset import org.apache.flink.table.runtime.utils.TimeTestUtil.EventTimeSourceFunction -import org.apache.flink.table.runtime.utils._ -import org.apache.flink.table.sources.StreamTableSource +import org.apache.flink.table.runtime.utils.{StreamITCase, StreamTestData, StreamingWithStateTestBase} import org.apache.flink.types.Row import org.apache.flink.table.utils.MemoryTableSinkUtil import org.junit.Assert._ import org.junit._ -import scala.collection.JavaConverters._ import scala.collection.mutable class SqlITCase extends StreamingWithStateTestBase { @@ -473,148 +469,6 @@ class SqlITCase extends StreamingWithStateTestBase { assertEquals(expected.sorted, StreamITCase.testResults.sorted) } - @Test - def testArrayElementAtFromTableForTuple(): Unit = { - - val env = StreamExecutionEnvironment.getExecutionEnvironment - val tEnv = TableEnvironment.getTableEnvironment(env) - StreamITCase.clear - - val data = List( - (1, Array((12, 45), (2, 5))), - (2, Array(null, (1, 49))), - (3, Array((18, 42), (127, 454))) - ) - val stream = env.fromCollection(data) - tEnv.registerDataStream("T", stream, 'a, 'b) - - val sqlQuery = "SELECT a, b[1]._1 FROM T" - - val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row] - result.addSink(new StreamITCase.StringSink[Row]) - env.execute() - - val expected = List( - "1,12", - "2,null", - "3,18") - assertEquals(expected.sorted, StreamITCase.testResults.sorted) - } - - @Test - def testArrayElementAtFromTableForCaseClass(): Unit = { - - val env = StreamExecutionEnvironment.getExecutionEnvironment - val tEnv = TableEnvironment.getTableEnvironment(env) - StreamITCase.clear - - val data = List( - (1, Array(TestCaseClass(12, 45), TestCaseClass(2, 5))), - (2, Array(TestCaseClass(41, 5), TestCaseClass(1, 49))), - (3, Array(TestCaseClass(18, 42), TestCaseClass(127, 454))) - ) - val stream = 
env.fromCollection(data) - tEnv.registerDataStream("T", stream, 'a, 'b) - - val sqlQuery = "SELECT a, b[1].f1 FROM T" - - val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row] - result.addSink(new StreamITCase.StringSink[Row]) - env.execute() - - val expected = List( - "1,45", - "2,5", - "3,42") - assertEquals(expected.sorted, StreamITCase.testResults.sorted) - } - - @Test - def testArrayElementAtFromTableForPojo(): Unit = { - - val env = StreamExecutionEnvironment.getExecutionEnvironment - val tEnv = TableEnvironment.getTableEnvironment(env) - StreamITCase.clear - - val p1 = new JavaPojos.Pojo1(); - p1.msg = "msg1"; - val p2 = new JavaPojos.Pojo1(); - p2.msg = "msg2"; - val data = List( - (1, Array(p1)), - (2, Array(p2)) - ) - val stream = env.fromCollection(data) - tEnv.registerDataStream("T", stream, 'a, 'b) - - val sqlQuery = "SELECT a, b[1].msg FROM T" - - val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row] - result.addSink(new StreamITCase.StringSink[Row]) - env.execute() - - val expected = List( - "1,msg1", - "2,msg2") - assertEquals(expected.sorted, StreamITCase.testResults.sorted) - } - - @Test - def testArrayElementAtFromTableForRow(): Unit = { - val env = StreamExecutionEnvironment.getExecutionEnvironment - val tEnv = TableEnvironment.getTableEnvironment(env) - tEnv.registerTableSource("mytable", new StreamTableSource[Row] { - private val fieldNames: Array[String] = Array("name", "record") - private val fieldTypes: Array[TypeInformation[_]] = - Array( - Types.STRING, - ObjectArrayTypeInfo.getInfoFor(Types.ROW_NAMED( - Array[String]("longField", "strField", "floatField", "arrayField"), - Types.LONG, - Types.STRING, - Types.FLOAT, - ObjectArrayTypeInfo.getInfoFor( - Types.ROW_NAMED(Array[String]("nestedLong"), Types.LONG))))) - .asInstanceOf[Array[TypeInformation[_]]] - - override def getDataStream(execEnv: api.environment.StreamExecutionEnvironment) - : DataStream[Row] = { - val nestRow1 = new Row(1) - nestRow1.setField(0, 1213L) - val mockRow1 = new Row(4) - mockRow1.setField(0, 273132121L) - mockRow1.setField(1, "str1") - mockRow1.setField(2, 123.4f) - mockRow1.setField(3, Array(nestRow1)) - val mockRow2 = new Row(4) - mockRow2.setField(0, 27318121L) - mockRow2.setField(1, "str2") - mockRow2.setField(2, 987.2f) - mockRow2.setField(3, Array(nestRow1)) - val data = List( - Row.of("Mary", Array(mockRow1, mockRow2)), - Row.of("Mary", Array(mockRow2, mockRow1))).asJava - execEnv.fromCollection(data, getReturnType) - } - - override def getReturnType: TypeInformation[Row] = new RowTypeInfo(fieldTypes, fieldNames) - override def getTableSchema: TableSchema = new TableSchema(fieldNames, fieldTypes) - }) - StreamITCase.clear - - val sqlQuery = "SELECT name, record[1].floatField, record[1].strField, " + - "record[2].longField, record[1].arrayField[1].nestedLong FROM mytable" - - val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row] - result.addSink(new StreamITCase.StringSink[Row]) - env.execute() - - val expected = List( - "Mary,123.4,str1,27318121,1213", - "Mary,987.2,str2,273132121,1213") - assertEquals(expected.sorted, StreamITCase.testResults.sorted) - } - @Test def testHopStartEndWithHaving(): Unit = { val env = StreamExecutionEnvironment.getExecutionEnvironment @@ -762,8 +616,6 @@ class SqlITCase extends StreamingWithStateTestBase { object SqlITCase { - case class TestCaseClass(f0: Integer, f1: Integer) extends Serializable - class TimestampAndWatermarkWithOffset[T <: Product]( offset: Long) extends AssignerWithPunctuatedWatermarks[T] {
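
For completeness, here is a minimal end-to-end sketch of what this series enables,
mirroring the tuple test added (and later moved to CompositeAccessTest) above. It is
illustrative only and not part of either patch; the table name "T" and the fields
'a and 'b are invented for the example.

    import org.apache.flink.api.scala._
    import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
    import org.apache.flink.table.api.TableEnvironment
    import org.apache.flink.table.api.scala._
    import org.apache.flink.types.Row

    object ArraySubfieldAccessExample {
      def main(args: Array[String]): Unit = {
        val env = StreamExecutionEnvironment.getExecutionEnvironment
        val tEnv = TableEnvironment.getTableEnvironment(env)

        // 'b is an array column whose elements are tuples, i.e. a composite type.
        val data = List(
          (1, Array((12, 45), (2, 5))),
          (2, Array((18, 42), (127, 454))))
        tEnv.registerDataStream("T", env.fromCollection(data), 'a, 'b)

        // With these patches, b[1]._1 accesses a subfield of the first array element.
        tEnv.sqlQuery("SELECT a, b[1]._1 FROM T").toAppendStream[Row].print()
        env.execute()
      }
    }

The Table API equivalent is 'b.at(1).get("_1"), as exercised in CompositeAccessTest.
Note that chained accesses such as "array[1].a.b" remain unsupported until
CALCITE-2162 is resolved.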