Skip to content

Commit

Permalink
[CALCITE-6389] RexBuilder.removeCastFromLiteral does not preserve sem…
Browse files Browse the repository at this point in the history
…antics for some types of literal

Signed-off-by: Mihai Budiu <mbudiu@feldera.com>
  • Loading branch information
mihaibudiu committed May 10, 2024
1 parent ea441e7 commit c228804
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -268,22 +268,29 @@ Expression translate(RexNode expr, RexImpTable.NullAs nullAs,
/**
* Used for safe operators that return null if an exception is thrown.
*/
private static Expression expressionHandlingSafe(Expression body, boolean safe) {
return safe ? safeExpression(body) : body;
private Expression expressionHandlingSafe(
Expression body, boolean safe, RelDataType targetType) {
return safe ? safeExpression(body, targetType) : body;
}

private static Expression safeExpression(Expression body) {
private Expression safeExpression(Expression body, RelDataType targetType) {
final ParameterExpression e_ =
Expressions.parameter(Exception.class, new BlockBuilder().newName("e"));

return Expressions.call(
Expressions.lambda(
Expressions.block(
Expressions.tryCatch(
Expressions.return_(null, body),
Expressions.catch_(e_,
Expressions.return_(null, constant(null)))))),
BuiltInMethod.FUNCTION0_APPLY.method);
// The type received for the targetType is never nullable.
// But safe casts may return null
RelDataType nullableTargetType = typeFactory.createTypeWithNullability(targetType, true);
Expression result =
Expressions.call(
Expressions.lambda(
Expressions.block(
Expressions.tryCatch(
Expressions.return_(null, body),
Expressions.catch_(e_,
Expressions.return_(null, constant(null)))))),
BuiltInMethod.FUNCTION0_APPLY.method);
// FUNCTION0 always returns Object, so we need a cast to the target type
return EnumUtils.convert(result, typeFactory.getJavaClass(nullableTargetType));
}

Expression translateCast(
Expand All @@ -294,7 +301,7 @@ Expression translateCast(
ConstantExpression format) {
Expression convert = getConvertExpression(sourceType, targetType, operand, format);
Expression convert2 = checkExpressionPadTruncate(convert, sourceType, targetType);
Expression convert3 = expressionHandlingSafe(convert2, safe);
Expression convert3 = expressionHandlingSafe(convert2, safe, targetType);
return scaleValue(sourceType, targetType, convert3);
}

Expand Down
37 changes: 23 additions & 14 deletions core/src/main/java/org/apache/calcite/rex/RexBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,9 @@ boolean canRemoveCastFromLiteral(RelDataType toType, @Nullable Comparable value,
return false;
}
if (toType.getSqlTypeName() != fromTypeName
&& SqlTypeFamily.DATETIME.getTypeNames().contains(fromTypeName)) {
&& (SqlTypeFamily.DATETIME.getTypeNames().contains(fromTypeName)
|| SqlTypeFamily.INTERVAL_DAY_TIME.getTypeNames().contains(fromTypeName)
|| SqlTypeFamily.INTERVAL_YEAR_MONTH.getTypeNames().contains(fromTypeName))) {
return false;
}
if (value instanceof NlsString) {
Expand All @@ -720,9 +722,10 @@ boolean canRemoveCastFromLiteral(RelDataType toType, @Nullable Comparable value,
}
}

if (toType.getSqlTypeName() == SqlTypeName.DECIMAL) {
if (toType.getSqlTypeName() == SqlTypeName.DECIMAL
&& fromTypeName.getFamily() == SqlTypeFamily.NUMERIC) {
final BigDecimal decimalValue = (BigDecimal) value;
return SqlTypeUtil.isValidDecimalValue(decimalValue, toType);
return SqlTypeUtil.canBeRepresentedExactly(decimalValue, toType);
}

if (SqlTypeName.INT_TYPES.contains(sqlType)) {
Expand All @@ -731,17 +734,23 @@ boolean canRemoveCastFromLiteral(RelDataType toType, @Nullable Comparable value,
if (s != 0) {
return false;
}
long l = decimalValue.longValue();
switch (sqlType) {
case TINYINT:
return l >= Byte.MIN_VALUE && l <= Byte.MAX_VALUE;
case SMALLINT:
return l >= Short.MIN_VALUE && l <= Short.MAX_VALUE;
case INTEGER:
return l >= Integer.MIN_VALUE && l <= Integer.MAX_VALUE;
case BIGINT:
default:
return true;
try {
// will trigger ArithmeticException when the value
// cannot be represented exactly as a long
long l = decimalValue.longValueExact();
switch (sqlType) {
case TINYINT:
return l >= Byte.MIN_VALUE && l <= Byte.MAX_VALUE;
case SMALLINT:
return l >= Short.MIN_VALUE && l <= Short.MAX_VALUE;
case INTEGER:
return l >= Integer.MIN_VALUE && l <= Integer.MAX_VALUE;
case BIGINT:
default:
return true;
}
} catch (ArithmeticException ex) {
return false;
}
}

Expand Down
30 changes: 30 additions & 0 deletions core/src/main/java/org/apache/calcite/sql/type/SqlTypeUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import org.checkerframework.checker.nullness.qual.Nullable;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.nio.charset.Charset;
import java.util.AbstractList;
import java.util.ArrayList;
Expand Down Expand Up @@ -1822,6 +1823,35 @@ public static RelDataType extractLastNFields(RelDataTypeFactory typeFactory,
type.getFieldList().subList(fieldsCnt - numToKeep, fieldsCnt));
}

/**
* Returns whether the decimal value can be represented without information loss
* using the specified type.
* For example, 1111.11
* - cannot be represented exactly using DECIMAL(3, 1) since it overflows.
* - cannot be represented exactly using DECIMAL(6, 3) since it overflows.
* - cannot be represented exactly using DECIMAL(6, 1) since it requires rounding.
* - can be represented exactly using DECIMAL(6, 2)
*
* @param value A decimal value
* @param toType A DECIMAL type.
* @return whether the value is valid for the type
*/
public static boolean canBeRepresentedExactly(@Nullable BigDecimal value, RelDataType toType) {
assert toType.getSqlTypeName() == SqlTypeName.DECIMAL;
if (value == null) {
return true;
}
value = value.stripTrailingZeros();
if (value.scale() < 0) {
// Negative scale, convert to 0 scale.
// Rounding mode is irrelevant, since value is integer
value = value.setScale(0, RoundingMode.DOWN);
}
final int intDigits = value.precision() - value.scale();
final int maxIntDigits = toType.getPrecision() - toType.getScale();
return (intDigits <= maxIntDigits) && (value.scale() <= toType.getScale());
}

/**
* Returns whether the decimal value is valid for the type. For example, 1111.11 is not
* valid for DECIMAL(3, 1) since it overflows.
Expand Down
53 changes: 53 additions & 0 deletions core/src/test/java/org/apache/calcite/rex/RexBuilderTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.hasToString;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
Expand Down Expand Up @@ -213,6 +214,58 @@ private static class MySqlTypeFactoryImpl extends SqlTypeFactoryImpl {
hasToString("1969-07-21 02:56:15.102"));
}

/** Test cases for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6389">[CALCITE-6389]
* RexBuilder.removeCastFromLiteral does not preserve semantics for some types of literal</a>. */
@Test void testRemoveCast() {
final RelDataTypeFactory typeFactory = new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT);
RexBuilder builder = new RexBuilder(typeFactory);

// Can remove cast of an integer to an integer
BigDecimal value = new BigDecimal(10);
RelDataType toType = builder.typeFactory.createSqlType(SqlTypeName.INTEGER);
assertTrue(builder.canRemoveCastFromLiteral(toType, value, SqlTypeName.INTEGER));

// Can remove cast from integer to decimal
toType = builder.typeFactory.createSqlType(SqlTypeName.DECIMAL);
assertTrue(builder.canRemoveCastFromLiteral(toType, value, SqlTypeName.INTEGER));

// 250 is too large for a TINYINT
value = new BigDecimal(250);
toType = builder.typeFactory.createSqlType(SqlTypeName.TINYINT);
assertFalse(builder.canRemoveCastFromLiteral(toType, value, SqlTypeName.INTEGER));

// 50 isn't too large for a TINYINT
value = new BigDecimal(50);
toType = builder.typeFactory.createSqlType(SqlTypeName.TINYINT);
assertTrue(builder.canRemoveCastFromLiteral(toType, value, SqlTypeName.INTEGER));

// 120.25 cannot be represented with precision 2 and scale 2 without loss
value = new BigDecimal("120.25");
toType = builder.typeFactory.createSqlType(SqlTypeName.DECIMAL, 2, 2);
assertFalse(builder.canRemoveCastFromLiteral(toType, value, SqlTypeName.DECIMAL));

// 120.25 cannot be represented with precision 5 and scale 1 without rounding
value = new BigDecimal("120.25");
toType = builder.typeFactory.createSqlType(SqlTypeName.DECIMAL, 5, 1);
assertFalse(builder.canRemoveCastFromLiteral(toType, value, SqlTypeName.DECIMAL));

// longmax + 1 cannot be represented as a long
value = new BigDecimal(Long.MAX_VALUE).add(BigDecimal.ONE);
toType = builder.typeFactory.createSqlType(SqlTypeName.BIGINT);
assertFalse(builder.canRemoveCastFromLiteral(toType, value, SqlTypeName.DECIMAL));

// Cast to decimal of an INTERVAL '5' seconds cannot be removed
value = new BigDecimal("5");
toType = builder.typeFactory.createSqlType(SqlTypeName.DECIMAL, 5, 1);
assertFalse(builder.canRemoveCastFromLiteral(toType, value, SqlTypeName.INTERVAL_SECOND));

// Cast to decimal of an INTERVAL '5' minutes cannot be removed
value = new BigDecimal("5");
toType = builder.typeFactory.createSqlType(SqlTypeName.DECIMAL, 5, 1);
assertFalse(builder.canRemoveCastFromLiteral(toType, value, SqlTypeName.INTERVAL_MINUTE));
}

@Test void testTimestampString() {
final TimestampString ts = new TimestampString(1969, 7, 21, 2, 56, 15);
assertThat(ts, hasToString("1969-07-21 02:56:15"));
Expand Down

0 comments on commit c228804

Please sign in to comment.