diff --git a/core/src/main/java/org/apache/calcite/rex/RexBuilder.java b/core/src/main/java/org/apache/calcite/rex/RexBuilder.java index e58fe32c57a4..a6d3ed8664fa 100644 --- a/core/src/main/java/org/apache/calcite/rex/RexBuilder.java +++ b/core/src/main/java/org/apache/calcite/rex/RexBuilder.java @@ -648,6 +648,12 @@ boolean canRemoveCastFromLiteral(RelDataType toType, Comparable value, throw new AssertionError(toType); } } + + if (toType.getSqlTypeName() == SqlTypeName.DECIMAL) { + final BigDecimal decimalValue = (BigDecimal) value; + return SqlTypeUtil.isValidDecimalValue(decimalValue, toType); + } + return true; } @@ -952,6 +958,11 @@ protected RexLiteral makeLiteral( o = ((TimestampString) o).round(p); break; } + if (type.getSqlTypeName() == SqlTypeName.DECIMAL && !SqlTypeUtil + .isValidDecimalValue((BigDecimal) o, type)) { + throw new IllegalArgumentException( + "Cannot convert " + o + " to " + type + " due to overflow"); + } return new RexLiteral(o, type, typeName); } diff --git a/core/src/main/java/org/apache/calcite/sql/type/SqlTypeUtil.java b/core/src/main/java/org/apache/calcite/sql/type/SqlTypeUtil.java index 009ef8ff3dff..f9f4084d5db2 100644 --- a/core/src/main/java/org/apache/calcite/sql/type/SqlTypeUtil.java +++ b/core/src/main/java/org/apache/calcite/sql/type/SqlTypeUtil.java @@ -46,6 +46,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Sets; +import java.math.BigDecimal; import java.nio.charset.Charset; import java.util.AbstractList; import java.util.ArrayList; @@ -1751,4 +1752,26 @@ public static RelDataType extractLastNFields(RelDataTypeFactory typeFactory, return typeFactory.createStructType( type.getFieldList().subList(fieldsCnt - numToKeep, fieldsCnt)); } + + /** + * Returns whether the decimal value is valid for the type. For example, 1111.11 is not + * valid for DECIMAL(3, 1) since it overflows. + * + * @param value Value of literal + * @param toType Type of the literal + * @return whether the value is valid for the type + */ + public static boolean isValidDecimalValue(BigDecimal value, RelDataType toType) { + if (value == null) { + return true; + } + switch (toType.getSqlTypeName()) { + case DECIMAL: + final int intDigits = value.precision() - value.scale(); + final int maxIntDigits = toType.getPrecision() - toType.getScale(); + return intDigits <= maxIntDigits; + default: + return true; + } + } } diff --git a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java index f1190ede2613..b5ca456c3f94 100644 --- a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java +++ b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java @@ -5136,6 +5136,21 @@ private void checkLiteral2(String expression, String expected) { sql(expected).exec(); } + @Test void testCastDecimalOverflow() { + final String query = + "SELECT CAST('11111111111111111111111111111111.111111' AS DECIMAL(38,6)) AS \"num\" from \"product\""; + final String expected = + "SELECT CAST('11111111111111111111111111111111.111111' AS DECIMAL(19, 6)) AS \"num\"\n" + + "FROM \"foodmart\".\"product\""; + sql(query).ok(expected); + + final String query2 = + "SELECT CAST(1111111 AS DECIMAL(5,2)) AS \"num\" from \"product\""; + final String expected2 = + "SELECT CAST(1111111 AS DECIMAL(5, 2)) AS \"num\"\nFROM \"foodmart\".\"product\""; + sql(query2).ok(expected2); + } + @Test void testCastInStringIntegerComparison() { final String query = "select \"employee_id\" " + "from \"foodmart\".\"employee\" " diff --git a/core/src/test/java/org/apache/calcite/rex/RexBuilderTest.java b/core/src/test/java/org/apache/calcite/rex/RexBuilderTest.java index 8c8275535d46..895f290882da 100644 --- a/core/src/test/java/org/apache/calcite/rex/RexBuilderTest.java +++ b/core/src/test/java/org/apache/calcite/rex/RexBuilderTest.java @@ -21,6 +21,7 @@ import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rel.type.RelDataTypeSystem; +import org.apache.calcite.rel.type.RelDataTypeSystemImpl; import org.apache.calcite.sql.SqlCollation; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.type.BasicSqlType; @@ -590,8 +591,11 @@ private void checkDate(RexNode node) { /** Tests {@link RexBuilder#makeExactLiteral(java.math.BigDecimal)}. */ @Test void testBigDecimalLiteral() { - final RelDataTypeFactory typeFactory = - new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT); + final RelDataTypeFactory typeFactory = new SqlTypeFactoryImpl(new RelDataTypeSystemImpl() { + @Override public int getMaxPrecision(SqlTypeName typeName) { + return 38; + } + }); final RexBuilder builder = new RexBuilder(typeFactory); checkBigDecimalLiteral(builder, "25"); checkBigDecimalLiteral(builder, "9.9"); diff --git a/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java b/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java index b0d23969ef31..fc6be77be06a 100644 --- a/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java +++ b/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java @@ -2633,7 +2633,7 @@ public final Sql sql(String sql) { } @Test void testReduceConstExpr() { - final String sql = "select sum(case when 'y' = 'n' then ename else 1 end) from emp"; + final String sql = "select sum(case when 'y' = 'n' then ename else 0.1 end) from emp"; sql(sql).ok(); } diff --git a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml index 53ef390bd672..7d171fa9dd16 100644 --- a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml +++ b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml @@ -3724,7 +3724,7 @@ LogicalProject(DEPTNO=[$7]) diff --git a/core/src/test/resources/sql/misc.iq b/core/src/test/resources/sql/misc.iq index d30fd564614e..f87de8a03063 100644 --- a/core/src/test/resources/sql/misc.iq +++ b/core/src/test/resources/sql/misc.iq @@ -1689,7 +1689,7 @@ EnumerableCalc(expr#0=[{inputs}], expr#1=[123:BIGINT], EXPR$0=[$t1]) !plan # Cast an integer literal to a decimal; note: the plan does not contain CAST -values cast('123.45' as decimal(4, 2)); +values cast('123.45' as decimal(5, 2)); +--------+ | EXPR$0 | +--------+ @@ -1698,12 +1698,12 @@ values cast('123.45' as decimal(4, 2)); (1 row) !ok -EnumerableCalc(expr#0=[{inputs}], expr#1=[123.45:DECIMAL(4, 2)], EXPR$0=[$t1]) +EnumerableCalc(expr#0=[{inputs}], expr#1=[123.45:DECIMAL(5, 2)], EXPR$0=[$t1]) EnumerableValues(tuples=[[{ 0 }]]) !plan # Cast a character literal to a decimal; note: the plan does not contain CAST -values cast('123.45' as decimal(4, 2)); +values cast('123.45' as decimal(5, 2)); +--------+ | EXPR$0 | +--------+ @@ -1712,7 +1712,7 @@ values cast('123.45' as decimal(4, 2)); (1 row) !ok -EnumerableCalc(expr#0=[{inputs}], expr#1=[123.45:DECIMAL(4, 2)], EXPR$0=[$t1]) +EnumerableCalc(expr#0=[{inputs}], expr#1=[123.45:DECIMAL(5, 2)], EXPR$0=[$t1]) EnumerableValues(tuples=[[{ 0 }]]) !plan