From 0da17b1d03c8fd482f3a39873729d9f9011724e4 Mon Sep 17 00:00:00 2001
From: twalthr
Date: Mon, 15 Feb 2016 17:00:41 +0100
Subject: [PATCH] [FLINK-3226] Translation of explicit casting

---
 .../api/table/codegen/CodeGenUtils.scala      |  2 +-
 .../api/table/codegen/CodeGenerator.scala     |  6 +-
 .../api/table/codegen/OperatorCodeGen.scala   | 67 ++++++++++++++++-
 .../flink/api/table/plan/TypeConverter.scala  | 16 ++--
 .../api/java/table/test/CastingITCase.java    | 35 ++++++++-
 .../api/scala/table/test/CastingITCase.scala  | 75 +++++++++++++++----
 6 files changed, 178 insertions(+), 23 deletions(-)

diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenUtils.scala
index 5bd14679a96bc..110f5890f10ed 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenUtils.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenUtils.scala
@@ -91,7 +91,7 @@ object CodeGenUtils {
     case FLOAT_TYPE_INFO => "-1.0f"
     case DOUBLE_TYPE_INFO => "-1.0d"
     case BOOLEAN_TYPE_INFO => "false"
-    case STRING_TYPE_INFO => "\"\""
+    case STRING_TYPE_INFO => "\"\""
     case CHAR_TYPE_INFO => "'\\0'"
     case _ => "null"
   }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala
index a4ae4b1407a29..c121afa2a2010 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala
@@ -593,9 +593,13 @@ class CodeGenerator(
 
       case NOT =>
         val operand = operands.head
-        requireBoolean(operand)
+        requireBoolean(operand)
         generateNot(nullCheck, operand)
 
+      case CAST =>
+        val operand = operands.head
+        generateCast(nullCheck, operand, resultType)
+
       case call@_ =>
         throw new CodeGenException(s"Unsupported call: $call")
     }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/OperatorCodeGen.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/OperatorCodeGen.scala
index 8402569b518d3..a7b4dea68606b 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/OperatorCodeGen.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/OperatorCodeGen.scala
@@ -17,8 +17,8 @@
  */
 package org.apache.flink.api.table.codegen
 
-import org.apache.flink.api.common.typeinfo.BasicTypeInfo.BOOLEAN_TYPE_INFO
-import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
+import org.apache.flink.api.common.typeinfo.{NumericTypeInfo, BasicTypeInfo, TypeInformation}
 import org.apache.flink.api.table.codegen.CodeGenUtils._
 
 object OperatorCodeGen {
@@ -289,6 +289,69 @@ object OperatorCodeGen {
     }
   }
 
+  def generateCast(
+      nullCheck: Boolean,
+      operand: GeneratedExpression,
+      targetType: TypeInformation[_])
+    : GeneratedExpression = {
+    targetType match {
+      // identity casting
+      case operand.resultType =>
+        generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
+          (operandTerm) => s"$operandTerm"
+        }
+
+      // * -> String
+      case STRING_TYPE_INFO =>
+        generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
+          (operandTerm) => s""" "" + $operandTerm"""
+        }
+
+      // * -> Date
+      case DATE_TYPE_INFO =>
+        throw new CodeGenException("Date type not supported yet.")
+
+      // * -> Void
+      case VOID_TYPE_INFO =>
+        throw new CodeGenException("Void type not supported.")
+
+      // * -> Character
+      case CHAR_TYPE_INFO =>
+        throw new CodeGenException("Character type not supported.")
+
+      // NUMERIC TYPE -> Boolean
+      case BOOLEAN_TYPE_INFO if isNumeric(operand) =>
+        generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
+          (operandTerm) => s"$operandTerm != 0"
+        }
+
+      // String -> BASIC TYPE (not String, Date, Void, Character)
+      case ti: BasicTypeInfo[_] if isString(operand) =>
+        val wrapperClass = targetType.getTypeClass.getCanonicalName
+        generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
+          (operandTerm) => s"$wrapperClass.valueOf($operandTerm)"
+        }
+
+      // NUMERIC TYPE -> NUMERIC TYPE
+      case nti: NumericTypeInfo[_] if isNumeric(operand) =>
+        val targetTypeTerm = primitiveTypeTermForTypeInfo(nti)
+        generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
+          (operandTerm) => s"($targetTypeTerm) $operandTerm"
+        }
+
+      // Boolean -> NUMERIC TYPE
+      case nti: NumericTypeInfo[_] if isBoolean(operand) =>
+        val targetTypeTerm = primitiveTypeTermForTypeInfo(nti)
+        generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
+          (operandTerm) => s"($targetTypeTerm) ($operandTerm ? 1 : 0)"
+        }
+
+      case _ =>
+        throw new CodeGenException(s"Unsupported cast from '${operand.resultType}' " +
+          s"to '$targetType'.")
+    }
+  }
+
   // ----------------------------------------------------------------------------------------------
 
   private def generateUnaryOperatorIfNotNull(
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/TypeConverter.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/TypeConverter.scala
index 1fc482ae873cf..7e94869e821bf 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/TypeConverter.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/TypeConverter.scala
@@ -19,21 +19,21 @@
 package org.apache.flink.api.table.plan
 
 import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.core.JoinRelType
 import org.apache.calcite.rel.core.JoinRelType._
+import org.apache.calcite.sql.`type`.SqlTypeName
 import org.apache.calcite.sql.`type`.SqlTypeName._
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
 import org.apache.flink.api.common.typeinfo.{AtomicType, TypeInformation}
 import org.apache.flink.api.common.typeutils.CompositeType
+import org.apache.flink.api.java.operators.join.JoinType
 import org.apache.flink.api.java.tuple.Tuple
 import org.apache.flink.api.java.typeutils.TupleTypeInfo
 import org.apache.flink.api.java.typeutils.ValueTypeInfo._
 import org.apache.flink.api.table.typeinfo.RowTypeInfo
 import org.apache.flink.api.table.{Row, TableException}
+
 import scala.collection.JavaConversions._
-import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
-import org.apache.flink.api.java.operators.join.JoinType
-import org.apache.calcite.rel.core.JoinRelType
-import org.apache.calcite.sql.`type`.SqlTypeName
 
 object TypeConverter {
 
@@ -55,11 +55,17 @@ object TypeConverter {
     case STRING_TYPE_INFO => VARCHAR
     case STRING_VALUE_TYPE_INFO => VARCHAR
     case DATE_TYPE_INFO => DATE
+
+    case CHAR_TYPE_INFO | CHAR_VALUE_TYPE_INFO =>
+      throw new TableException("Character type is not supported.")
+
 //    case t: TupleTypeInfo[_] => ROW
 //    case c: CaseClassTypeInfo[_] => ROW
 //    case p: PojoTypeInfo[_] => STRUCTURED
 //    case g: GenericTypeInfo[_] => OTHER
-    case _ => ??? // TODO more types
+
+    case t@_ =>
+      throw new TableException(s"Type is not supported: $t")
   }
 
   def sqlTypeToTypeInfo(sqlType: SqlTypeName): TypeInformation[_] = sqlType match {
diff --git a/flink-libraries/flink-table/src/test/java/org/apache/flink/api/java/table/test/CastingITCase.java b/flink-libraries/flink-table/src/test/java/org/apache/flink/api/java/table/test/CastingITCase.java
index 957c09342a987..601f5f00037b8 100644
--- a/flink-libraries/flink-table/src/test/java/org/apache/flink/api/java/table/test/CastingITCase.java
+++ b/flink-libraries/flink-table/src/test/java/org/apache/flink/api/java/table/test/CastingITCase.java
@@ -89,7 +89,40 @@ public void testNumericAutocastInComparison() throws Exception {
 		compareResultAsText(results, expected);
 	}
 
-	@Test(expected = CodeGenException.class)
+	@Test
+	public void testCasting() throws Exception {
+		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+		TableEnvironment tableEnv = new TableEnvironment();
+
+		DataSource<Tuple4<Integer, Double, Long, Boolean>> input =
+			env.fromElements(new Tuple4<>(1, 0.0, 1L, true));
+
+		Table table =
+			tableEnv.fromDataSet(input);
+
+		Table result = table.select(
+			// * -> String
+			"f0.cast(STRING), f1.cast(STRING), f2.cast(STRING), f3.cast(STRING)," +
+			// NUMERIC TYPE -> Boolean
+			"f0.cast(BOOL), f1.cast(BOOL), f2.cast(BOOL)," +
+			// NUMERIC TYPE -> NUMERIC TYPE
+			"f0.cast(DOUBLE), f1.cast(INT), f2.cast(SHORT)," +
+			// Boolean -> NUMERIC TYPE
+			"f3.cast(DOUBLE)," +
+			// identity casting
+			"f0.cast(INT), f1.cast(DOUBLE), f2.cast(LONG), f3.cast(BOOL)");
+
+		DataSet<Row> ds = tableEnv.toDataSet(result, Row.class);
+		List<Row> results = ds.collect();
+		String expected = "1,0.0,1,true," +
+			"true,false,true," +
+			"1.0,0,1," +
+			"1.0," +
+			"1,0.0,1,true\n";
+		compareResultAsText(results, expected);
+	}
+
+	@Test
 	public void testCastFromString() throws Exception {
 		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
 		TableEnvironment tableEnv = new TableEnvironment();
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/CastingITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/CastingITCase.scala
index d6a853d7e4a6f..9064d36099215 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/CastingITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/CastingITCase.scala
@@ -78,35 +78,84 @@ class CastingITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mo
     TestBaseUtils.compareResultAsText(results.asJava, expected)
   }
 
-  @Test(expected = classOf[CodeGenException])
-  def testCastFromString: Unit = {
+  @Test
+  def testCasting(): Unit = {
+
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val t = env.fromElements((1, 0.0, 1L, true))
+      .toTable
+      .select(
+        // * -> String
+        '_1.cast(BasicTypeInfo.STRING_TYPE_INFO),
+        '_2.cast(BasicTypeInfo.STRING_TYPE_INFO),
+        '_3.cast(BasicTypeInfo.STRING_TYPE_INFO),
+        '_4.cast(BasicTypeInfo.STRING_TYPE_INFO),
+        // NUMERIC TYPE -> Boolean
+        '_1.cast(BasicTypeInfo.BOOLEAN_TYPE_INFO),
+        '_2.cast(BasicTypeInfo.BOOLEAN_TYPE_INFO),
+        '_3.cast(BasicTypeInfo.BOOLEAN_TYPE_INFO),
+        // NUMERIC TYPE -> NUMERIC TYPE
+        '_1.cast(BasicTypeInfo.DOUBLE_TYPE_INFO),
+        '_2.cast(BasicTypeInfo.INT_TYPE_INFO),
+        '_3.cast(BasicTypeInfo.SHORT_TYPE_INFO),
+        // Boolean -> NUMERIC TYPE
+        '_4.cast(BasicTypeInfo.DOUBLE_TYPE_INFO),
+        // identity casting
+        '_1.cast(BasicTypeInfo.INT_TYPE_INFO),
+        '_2.cast(BasicTypeInfo.DOUBLE_TYPE_INFO),
+        '_3.cast(BasicTypeInfo.LONG_TYPE_INFO),
+        '_4.cast(BasicTypeInfo.BOOLEAN_TYPE_INFO))
+
+    val expected = "1,0.0,1,true," +
+      "true,false,true," +
+      "1.0,0,1," +
+      "1.0," +
+      "1,0.0,1,true\n"
+    val results = t.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  @Test
+  def testCastFromString(): Unit = {
 
     val env = ExecutionEnvironment.getExecutionEnvironment
-    val t = env.fromElements(("1", "true", "2.0",
-      "2011-05-03", "15:51:36", "2011-05-03 15:51:36.000", "1446473775"))
+    val t = env.fromElements(("1", "true", "2.0"))
       .toTable
       .select(
+        // String -> BASIC TYPE (not String, Date, Void, Character)
         '_1.cast(BasicTypeInfo.BYTE_TYPE_INFO),
         '_1.cast(BasicTypeInfo.SHORT_TYPE_INFO),
         '_1.cast(BasicTypeInfo.INT_TYPE_INFO),
         '_1.cast(BasicTypeInfo.LONG_TYPE_INFO),
         '_3.cast(BasicTypeInfo.DOUBLE_TYPE_INFO),
         '_3.cast(BasicTypeInfo.FLOAT_TYPE_INFO),
-        '_2.cast(BasicTypeInfo.BOOLEAN_TYPE_INFO),
-        '_4.cast(BasicTypeInfo.DATE_TYPE_INFO).cast(BasicTypeInfo.STRING_TYPE_INFO),
-        '_5.cast(BasicTypeInfo.DATE_TYPE_INFO).cast(BasicTypeInfo.STRING_TYPE_INFO),
-        '_6.cast(BasicTypeInfo.DATE_TYPE_INFO).cast(BasicTypeInfo.STRING_TYPE_INFO),
-        '_7.cast(BasicTypeInfo.DATE_TYPE_INFO).cast(BasicTypeInfo.STRING_TYPE_INFO))
+        '_2.cast(BasicTypeInfo.BOOLEAN_TYPE_INFO))
+
+    val expected = "1,1,1,1,2.0,2.0,true\n"
+    val results = t.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  @Test(expected = classOf[CodeGenException])
+  def testCastDateFromString(): Unit = {
+
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val t = env.fromElements(("2011-05-03", "15:51:36", "2011-05-03 15:51:36.000", "1446473775"))
+      .toTable
+      .select(
+        '_1.cast(BasicTypeInfo.DATE_TYPE_INFO).cast(BasicTypeInfo.STRING_TYPE_INFO),
+        '_2.cast(BasicTypeInfo.DATE_TYPE_INFO).cast(BasicTypeInfo.STRING_TYPE_INFO),
+        '_3.cast(BasicTypeInfo.DATE_TYPE_INFO).cast(BasicTypeInfo.STRING_TYPE_INFO),
+        '_4.cast(BasicTypeInfo.DATE_TYPE_INFO).cast(BasicTypeInfo.STRING_TYPE_INFO))
 
-    val expected = "1,1,1,1,2.0,2.0,true," +
-      "2011-05-03 00:00:00.000,1970-01-01 15:51:36.000,2011-05-03 15:51:36.000," +
+    val expected = "2011-05-03 00:00:00.000,1970-01-01 15:51:36.000,2011-05-03 15:51:36.000," +
       "1970-01-17 17:47:53.775\n"
     val results = t.toDataSet[Row].collect()
     TestBaseUtils.compareResultAsText(results.asJava, expected)
   }
 
   @Test(expected = classOf[CodeGenException])
-  def testCastDateToStringAndLong {
+  def testCastDateToStringAndLong(): Unit = {
     val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
     val ds = env.fromElements(("2011-05-03 15:51:36.000", "1304437896000"))
     val t = ds.toTable
@@ -119,7 +168,7 @@ class CastingITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mo
     val expected = "2011-05-03 15:51:36.000,1304437896000," +
       "2011-05-03 15:51:36.000,1304437896000\n"
 
-    val result = t.toDataSet[Row].collect
+    val result = t.toDataSet[Row].collect()
     TestBaseUtils.compareResultAsText(result.asJava, expected)
   }
 }
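
Usage sketch (not part of the patch; assumes the imports and the implicit toTable conversion used by the Scala CastingITCase above): with CAST wired into the code generator via generateCast(), explicit casts in the Table API translate into generated Java expressions, e.g. a numeric-to-String cast becomes string concatenation and a Boolean-to-numeric cast becomes a ternary:

    // minimal sketch following the patch's own test setup
    val env = ExecutionEnvironment.getExecutionEnvironment
    val casted = env.fromElements((1, true))
      .toTable
      .select(
        '_1.cast(BasicTypeInfo.STRING_TYPE_INFO),  // Int -> String, generated as: "" + operand
        '_2.cast(BasicTypeInfo.DOUBLE_TYPE_INFO))  // Boolean -> Double, generated as: (double) (operand ? 1 : 0)
    val rows = casted.toDataSet[Row].collect()     // expected single row: 1,1.0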