From fd783341b93d7a7ffba11c827a317adb35248ea0 Mon Sep 17 00:00:00 2001 From: twalthr Date: Sat, 13 Feb 2016 12:38:12 +0100 Subject: [PATCH 1/2] [FLINK-3226] Casting support for arithmetic operators --- .../api/table/codegen/OperatorCodeGen.scala | 42 +++++++++++++-- .../api/java/table/test/CastingITCase.java | 54 +++++++++---------- .../api/scala/table/test/CastingITCase.scala | 47 ++++++++-------- 3 files changed, 90 insertions(+), 53 deletions(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/OperatorCodeGen.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/OperatorCodeGen.scala index 8402569b518d3..95a7e9d054986 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/OperatorCodeGen.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/OperatorCodeGen.scala @@ -18,20 +18,56 @@ package org.apache.flink.api.table.codegen import org.apache.flink.api.common.typeinfo.BasicTypeInfo.BOOLEAN_TYPE_INFO -import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.common.typeinfo.{NumericTypeInfo, TypeInformation} import org.apache.flink.api.table.codegen.CodeGenUtils._ object OperatorCodeGen { - def generateArithmeticOperator( + def generateArithmeticOperator( operator: String, nullCheck: Boolean, resultType: TypeInformation[_], left: GeneratedExpression, right: GeneratedExpression) : GeneratedExpression = { - generateOperatorIfNotNull(nullCheck, resultType, left, right) { + // String arithmetic // TODO rework + if (isString(left)) { + generateOperatorIfNotNull(nullCheck, resultType, left, right) { (leftTerm, rightTerm) => s"$leftTerm $operator $rightTerm" + } + } + // Numeric arithmetic + else if (isNumeric(left) && isNumeric(right)) { + val leftType = left.resultType.asInstanceOf[NumericTypeInfo[_]] + val rightType = right.resultType.asInstanceOf[NumericTypeInfo[_]] + + generateOperatorIfNotNull(nullCheck, resultType, left, right) { + (leftTerm, rightTerm) => + // insert auto casting for "narrowing primitive conversions" + if (leftType != rightType) { + // leftType can not be casted to rightType automatically -> narrow + if (!leftType.shouldAutocastTo(rightType)) { + val typeTerm = primitiveTypeTermForTypeInfo(rightType) + s"(($typeTerm) $leftTerm) $operator $rightTerm" + } + // rightType can not be casted to leftType automatically -> narrow + else if (!rightType.shouldAutocastTo(leftType)) { + val typeTerm = primitiveTypeTermForTypeInfo(leftType) + s"$leftTerm $operator (($typeTerm) $rightTerm)" + } + // no narrowing required, widening happens implicitly + else { + s"$leftTerm $operator $rightTerm" + } + } + // no casting / conversion required + else { + s"$leftTerm $operator $rightTerm" + } + } + } + else { + throw new CodeGenException("Unsupported arithmetic operation.") } } diff --git a/flink-libraries/flink-table/src/test/java/org/apache/flink/api/java/table/test/CastingITCase.java b/flink-libraries/flink-table/src/test/java/org/apache/flink/api/java/table/test/CastingITCase.java index 957c09342a987..e1553c961fe40 100644 --- a/flink-libraries/flink-table/src/test/java/org/apache/flink/api/java/table/test/CastingITCase.java +++ b/flink-libraries/flink-table/src/test/java/org/apache/flink/api/java/table/test/CastingITCase.java @@ -18,77 +18,75 @@ package org.apache.flink.api.java.table.test; +import java.util.List; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.api.java.operators.DataSource; +import org.apache.flink.api.java.table.TableEnvironment; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.api.java.tuple.Tuple4; -import org.apache.flink.api.table.Table; +import org.apache.flink.api.java.tuple.Tuple6; +import org.apache.flink.api.java.tuple.Tuple7; import org.apache.flink.api.table.Row; +import org.apache.flink.api.table.Table; import org.apache.flink.api.table.codegen.CodeGenException; -import org.apache.flink.api.java.DataSet; -import org.apache.flink.api.java.ExecutionEnvironment; -import org.apache.flink.api.java.table.TableEnvironment; -import org.apache.flink.api.java.operators.DataSource; -import org.apache.flink.api.java.tuple.Tuple7; -import org.apache.flink.test.util.MultipleProgramsTestBase; -import org.junit.Ignore; +import org.apache.flink.api.table.test.TableProgramsTestBase; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; -import scala.NotImplementedError; - -import java.util.List; - @RunWith(Parameterized.class) -public class CastingITCase extends MultipleProgramsTestBase { +public class CastingITCase extends TableProgramsTestBase { - public CastingITCase(TestExecutionMode mode){ - super(mode); + public CastingITCase(TestExecutionMode mode, TableConfigMode configMode){ + super(mode, configMode); } - @Ignore - @Test(expected = NotImplementedError.class) + @Test public void testNumericAutocastInArithmetic() throws Exception { ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); - TableEnvironment tableEnv = new TableEnvironment(); + TableEnvironment tableEnv = getJavaTableEnvironment(); - DataSource> input = - env.fromElements(new Tuple7<>((byte) 1, (short) 1, 1, 1L, 1.0f, 1.0d, "Hello")); + DataSource> input = + env.fromElements(new Tuple7<>((byte) 1, (short) 1, 1, 1L, 1.0f, 1.0d, 1L)); Table table = tableEnv.fromDataSet(input); Table result = table.select("f0 + 1, f1 +" + - " 1, f2 + 1L, f3 + 1.0f, f4 + 1.0d, f5 + 1"); + " 1, f2 + 1L, f3 + 1.0f, f4 + 1.0d, f5 + 1, f6 + 1.0d"); DataSet ds = tableEnv.toDataSet(result, Row.class); List results = ds.collect(); - String expected = "2,2,2,2.0,2.0,2.0"; + String expected = "2,2,2,2.0,2.0,2.0,2.0"; compareResultAsText(results, expected); } @Test public void testNumericAutocastInComparison() throws Exception { ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); - TableEnvironment tableEnv = new TableEnvironment(); + TableEnvironment tableEnv = getJavaTableEnvironment(); - DataSource> input = + DataSource> input = env.fromElements( - new Tuple7<>((byte) 1, (short) 1, 1, 1L, 1.0f, 1.0d, "Hello"), - new Tuple7<>((byte) 2, (short) 2, 2, 2L, 2.0f, 2.0d, "Hello")); + new Tuple6<>((byte) 1, (short) 1, 1, 1L, 1.0f, 1.0d), + new Tuple6<>((byte) 2, (short) 2, 2, 2L, 2.0f, 2.0d)); Table table = - tableEnv.fromDataSet(input, "a,b,c,d,e,f,g"); + tableEnv.fromDataSet(input, "a,b,c,d,e,f"); Table result = table .filter("a > 1 && b > 1 && c > 1L && d > 1.0f && e > 1.0d && f > 1"); DataSet ds = tableEnv.toDataSet(result, Row.class); List results = ds.collect(); - String expected = "2,2,2,2,2.0,2.0,Hello"; + String expected = "2,2,2,2,2.0,2.0"; compareResultAsText(results, expected); } + // TODO support advanced String operations + @Test(expected = CodeGenException.class) public void testCastFromString() throws Exception { ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/CastingITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/CastingITCase.scala index d6a853d7e4a6f..e2b708248cac1 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/CastingITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/CastingITCase.scala @@ -19,45 +19,33 @@ package org.apache.flink.api.scala.table.test import java.util.Date -import org.junit._ -import org.junit.runner.RunWith -import org.junit.runners.Parameterized + import org.apache.flink.api.common.typeinfo.BasicTypeInfo import org.apache.flink.api.scala._ import org.apache.flink.api.scala.table._ import org.apache.flink.api.table.Row -import org.apache.flink.test.util.{TestBaseUtils, MultipleProgramsTestBase} +import org.apache.flink.api.table.codegen.CodeGenException import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode +import org.apache.flink.test.util.{MultipleProgramsTestBase, TestBaseUtils} +import org.junit._ +import org.junit.runner.RunWith +import org.junit.runners.Parameterized + import scala.collection.JavaConverters._ -import org.apache.flink.api.table.codegen.CodeGenException @RunWith(classOf[Parameterized]) class CastingITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode) { - @Ignore // String autocasting not yet supported @Test - def testAutoCastToString(): Unit = { - - val env = ExecutionEnvironment.getExecutionEnvironment - val t = env.fromElements((1: Byte, 1: Short, 1, 1L, 1.0f, 1.0d, new Date(0))).toTable - .select('_1 + "b", '_2 + "s", '_3 + "i", '_4 + "L", '_5 + "f", '_6 + "d", '_7 + "Date") - - val expected = "1b,1s,1i,1L,1.0f,1.0d,1970-01-01 00:00:00.000Date" - val results = t.toDataSet[Row].collect() - TestBaseUtils.compareResultAsText(results.asJava, expected) - } - - @Ignore // gives different types of exceptions for cluster and collection modes - @Test(expected = classOf[NotImplementedError]) def testNumericAutoCastInArithmetic(): Unit = { // don't test everything, just some common cast directions val env = ExecutionEnvironment.getExecutionEnvironment - val t = env.fromElements((1: Byte, 1: Short, 1, 1L, 1.0f, 1.0d)).toTable - .select('_1 + 1, '_2 + 1, '_3 + 1L, '_4 + 1.0f, '_5 + 1.0d, '_6 + 1) + val t = env.fromElements((1: Byte, 1: Short, 1, 1L, 1.0f, 1.0d, 1L)).toTable + .select('_1 + 1, '_2 + 1, '_3 + 1L, '_4 + 1.0f, '_5 + 1.0d, '_6 + 1, '_7 + 1.0d) - val expected = "2,2,2,2.0,2.0,2.0" + val expected = "2,2,2,2.0,2.0,2.0,2.0" val results = t.toDataSet[Row].collect() TestBaseUtils.compareResultAsText(results.asJava, expected) } @@ -78,6 +66,21 @@ class CastingITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mo TestBaseUtils.compareResultAsText(results.asJava, expected) } + // TODO support advanced String operations + + @Ignore + @Test + def testAutoCastToString(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val t = env.fromElements((1: Byte, 1: Short, 1, 1L, 1.0f, 1.0d, new Date(0))).toTable + .select('_1 + "b", '_2 + "s", '_3 + "i", '_4 + "L", '_5 + "f", '_6 + "d", '_7 + "Date") + + val expected = "1b,1s,1i,1L,1.0f,1.0d,1970-01-01 00:00:00.000Date" + val results = t.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + @Test(expected = classOf[CodeGenException]) def testCastFromString: Unit = { From 2abe12f10768bc4d6397e87f88a5f661a4850824 Mon Sep 17 00:00:00 2001 From: twalthr Date: Mon, 15 Feb 2016 22:55:52 +0100 Subject: [PATCH 2/2] Casting reworked --- .../api/table/codegen/OperatorCodeGen.scala | 32 ++++++++----------- .../api/java/table/test/CastingITCase.java | 9 +++--- .../api/scala/table/test/CastingITCase.scala | 6 ++-- 3 files changed, 22 insertions(+), 25 deletions(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/OperatorCodeGen.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/OperatorCodeGen.scala index 95a7e9d054986..0f8083e311a7c 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/OperatorCodeGen.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/OperatorCodeGen.scala @@ -40,29 +40,25 @@ object OperatorCodeGen { else if (isNumeric(left) && isNumeric(right)) { val leftType = left.resultType.asInstanceOf[NumericTypeInfo[_]] val rightType = right.resultType.asInstanceOf[NumericTypeInfo[_]] + val resultTypeTerm = primitiveTypeTermForTypeInfo(resultType) generateOperatorIfNotNull(nullCheck, resultType, left, right) { (leftTerm, rightTerm) => - // insert auto casting for "narrowing primitive conversions" - if (leftType != rightType) { - // leftType can not be casted to rightType automatically -> narrow - if (!leftType.shouldAutocastTo(rightType)) { - val typeTerm = primitiveTypeTermForTypeInfo(rightType) - s"(($typeTerm) $leftTerm) $operator $rightTerm" - } - // rightType can not be casted to leftType automatically -> narrow - else if (!rightType.shouldAutocastTo(leftType)) { - val typeTerm = primitiveTypeTermForTypeInfo(leftType) - s"$leftTerm $operator (($typeTerm) $rightTerm)" - } - // no narrowing required, widening happens implicitly - else { - s"$leftTerm $operator $rightTerm" - } + // no casting required + if (leftType == resultType && rightType == resultType) { + s"$leftTerm $operator $rightTerm" + } + // left needs casting + else if (leftType != resultType && rightType == resultType) { + s"(($resultTypeTerm) $leftTerm) $operator $rightTerm" } - // no casting / conversion required + // right needs casting + else if (leftType == resultType && rightType != resultType) { + s"$leftTerm $operator (($resultTypeTerm) $rightTerm)" + } + // both sides need casting else { - s"$leftTerm $operator $rightTerm" + s"(($resultTypeTerm) $leftTerm) $operator (($resultTypeTerm) $rightTerm)" } } } diff --git a/flink-libraries/flink-table/src/test/java/org/apache/flink/api/java/table/test/CastingITCase.java b/flink-libraries/flink-table/src/test/java/org/apache/flink/api/java/table/test/CastingITCase.java index e1553c961fe40..e5b5f582361fc 100644 --- a/flink-libraries/flink-table/src/test/java/org/apache/flink/api/java/table/test/CastingITCase.java +++ b/flink-libraries/flink-table/src/test/java/org/apache/flink/api/java/table/test/CastingITCase.java @@ -28,6 +28,7 @@ import org.apache.flink.api.java.tuple.Tuple4; import org.apache.flink.api.java.tuple.Tuple6; import org.apache.flink.api.java.tuple.Tuple7; +import org.apache.flink.api.java.tuple.Tuple8; import org.apache.flink.api.table.Row; import org.apache.flink.api.table.Table; import org.apache.flink.api.table.codegen.CodeGenException; @@ -48,18 +49,18 @@ public void testNumericAutocastInArithmetic() throws Exception { ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); TableEnvironment tableEnv = getJavaTableEnvironment(); - DataSource> input = - env.fromElements(new Tuple7<>((byte) 1, (short) 1, 1, 1L, 1.0f, 1.0d, 1L)); + DataSource> input = + env.fromElements(new Tuple8<>((byte) 1, (short) 1, 1, 1L, 1.0f, 1.0d, 1L, 1001.1)); Table table = tableEnv.fromDataSet(input); Table result = table.select("f0 + 1, f1 +" + - " 1, f2 + 1L, f3 + 1.0f, f4 + 1.0d, f5 + 1, f6 + 1.0d"); + " 1, f2 + 1L, f3 + 1.0f, f4 + 1.0d, f5 + 1, f6 + 1.0d, f7 + f0"); DataSet ds = tableEnv.toDataSet(result, Row.class); List results = ds.collect(); - String expected = "2,2,2,2.0,2.0,2.0,2.0"; + String expected = "2,2,2,2.0,2.0,2.0,2.0,1002.1"; compareResultAsText(results, expected); } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/CastingITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/CastingITCase.scala index e2b708248cac1..6121cb696c375 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/CastingITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/CastingITCase.scala @@ -42,10 +42,10 @@ class CastingITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mo // don't test everything, just some common cast directions val env = ExecutionEnvironment.getExecutionEnvironment - val t = env.fromElements((1: Byte, 1: Short, 1, 1L, 1.0f, 1.0d, 1L)).toTable - .select('_1 + 1, '_2 + 1, '_3 + 1L, '_4 + 1.0f, '_5 + 1.0d, '_6 + 1, '_7 + 1.0d) + val t = env.fromElements((1: Byte, 1: Short, 1, 1L, 1.0f, 1.0d, 1L, 1001.1)).toTable + .select('_1 + 1, '_2 + 1, '_3 + 1L, '_4 + 1.0f, '_5 + 1.0d, '_6 + 1, '_7 + 1.0d, '_8 + '_1) - val expected = "2,2,2,2.0,2.0,2.0,2.0" + val expected = "2,2,2,2.0,2.0,2.0,2.0,1002.1" val results = t.toDataSet[Row].collect() TestBaseUtils.compareResultAsText(results.asJava, expected) }