From ba57e784d2e251b2b1b275724fe039f82789ff4e Mon Sep 17 00:00:00 2001 From: twalthr Date: Wed, 13 Apr 2016 16:17:07 +0200 Subject: [PATCH] [FLINK-3749] [table] Improve decimal handling --- .../api/table/codegen/CodeGenerator.scala | 16 ++-- .../api/table/typeutils/TypeConverter.scala | 2 +- .../scala/sql/test/ExpressionsITCase.scala | 89 +++++++++++++++++++ 3 files changed, 100 insertions(+), 7 deletions(-) create mode 100644 flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/sql/test/ExpressionsITCase.scala diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala index f213d4cdffea4..0d4527efe71eb 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala @@ -551,23 +551,27 @@ class CodeGenerator( case BIGINT => val decimal = BigDecimal(value.asInstanceOf[java.math.BigDecimal]) if (decimal.isValidLong) { - generateNonNullLiteral(resultType, decimal.longValue().toString) + generateNonNullLiteral(resultType, decimal.longValue().toString + "L") } else { throw new CodeGenException("Decimal can not be converted to long.") } case FLOAT => val decimal = BigDecimal(value.asInstanceOf[java.math.BigDecimal]) - if (decimal.isValidFloat) { - generateNonNullLiteral(resultType, decimal.floatValue().toString + "f") + // check if we loose/change digits when converting to float + val converted = BigDecimal(decimal.floatValue().toString) + if (converted == decimal) { + generateNonNullLiteral(resultType, converted.toString + "f") } else { throw new CodeGenException("Decimal can not be converted to float.") } - case DOUBLE => + case DOUBLE | DECIMAL => val decimal = BigDecimal(value.asInstanceOf[java.math.BigDecimal]) - if (decimal.isValidDouble) { - generateNonNullLiteral(resultType, decimal.doubleValue().toString) + // check if we loose/change digits when converting to double + val converted = BigDecimal(decimal.doubleValue().toString) + if (converted == decimal) { + generateNonNullLiteral(resultType, converted.toString() + "d") } else { throw new CodeGenException("Decimal can not be converted to double.") diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/TypeConverter.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/TypeConverter.scala index dc3abb7ceebd5..7090c6a4cf155 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/TypeConverter.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/TypeConverter.scala @@ -71,7 +71,7 @@ object TypeConverter { case INTEGER => INT_TYPE_INFO case BIGINT => LONG_TYPE_INFO case FLOAT => FLOAT_TYPE_INFO - case DOUBLE => DOUBLE_TYPE_INFO + case DOUBLE | DECIMAL => DOUBLE_TYPE_INFO case VARCHAR | CHAR => STRING_TYPE_INFO case DATE => DATE_TYPE_INFO diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/sql/test/ExpressionsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/sql/test/ExpressionsITCase.scala new file mode 100644 index 0000000000000..2ff7b09013815 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/sql/test/ExpressionsITCase.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.scala.sql.test + +import org.apache.flink.api.scala.{ExecutionEnvironment, _} +import org.apache.flink.api.scala.table._ +import org.apache.flink.api.table.Row +import org.apache.flink.api.table.codegen.CodeGenException +import org.apache.flink.api.table.plan.TranslationContext +import org.apache.flink.api.table.test.utils.TableProgramsTestBase +import org.apache.flink.api.table.test.utils.TableProgramsTestBase.TableConfigMode +import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode +import org.apache.flink.test.util.TestBaseUtils +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized + +import scala.collection.JavaConverters._ + +@RunWith(classOf[Parameterized]) +class ExpressionsITCase( + mode: TestExecutionMode, + configMode: TableConfigMode) + extends TableProgramsTestBase(mode, configMode) { + + @Test + def testExactDecimal(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = getScalaTableEnvironment + TranslationContext.reset() + + val sqlQuery = s"SELECT 11.2, 0.7623533651719233, 7623533651719233.0, " + + s"${Double.MaxValue}, ${Double.MinValue}, " + + s"CAST(${Float.MaxValue} AS FLOAT), CAST(${Float.MinValue} AS FLOAT), " + + s"CAST(${Byte.MaxValue} AS TINYINT), CAST(${Byte.MinValue} AS TINYINT), " + + s"CAST(${Short.MaxValue} AS SMALLINT), CAST(${Short.MinValue} AS SMALLINT), " + + s"CAST(${Int.MaxValue} AS INTEGER), CAST(${Int.MinValue} AS INTEGER), " + + s"CAST(${Long.MaxValue} AS BIGINT), CAST(${Long.MinValue} AS BIGINT) FROM MyTable" + + val ds = env.fromElements((1, 0)) + tEnv.registerDataSet("MyTable", ds, 'a, 'b) + + val result = tEnv.sql(sqlQuery) + + val expected = "11.2,0.7623533651719233,7.623533651719233E15," + + "1.7976931348623157E308,-1.7976931348623157E308," + + "3.4028235E38,-3.4028235E38," + + "127,-128," + + "32767,-32768," + + "2147483647,-2147483648," + + "9223372036854775807,-9223372036854775808" + val results = result.toDataSet[Row](getConfig).collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test(expected = classOf[CodeGenException]) + def testUnsupportedDecimal(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = getScalaTableEnvironment + TranslationContext.reset() + + val sqlQuery = s"SELECT 0.76235336517192335 FROM MyTable" + + val ds = env.fromElements((1, 0)) + tEnv.registerDataSet("MyTable", ds, 'a, 'b) + + val result = tEnv.sql(sqlQuery) + + result.toDataSet[Row](getConfig).collect() + } + +}