Skip to content

Commit

Permalink
[CALCITE-6265] Type coercion is failing for numeric values in prepare…
Browse files Browse the repository at this point in the history
…d statements

Given a column of type `INT`. When providing a `short` value as a
placeholder in a prepared statement, a `ClassCastException` is thrown.

Test case:
```
final String sql =
    "select \"empid\" from \"hr\".\"emps\" where \"empid\" in (?, ?)";
CalciteAssert.hr()
    .query(sql)
    .consumesPreparedStatement(p -> {
        p.setShort(1, (short) 100);
        p.setShort(2, (short) 110);
    })
    .returnsUnordered("empid=100", "empid=110");
```

Stack trace:
```
java.lang.ClassCastException: class java.lang.Short cannot be cast to class java.lang.Integer (java.lang.Short and java.lang.Integer are in module java.base of loader 'bootstrap')
     at Baz$1$1.moveNext(Unknown Source)
     at org.apache.calcite.linq4j.Linq4j$EnumeratorIterator.<init>(Linq4j.java:679)
```
  • Loading branch information
tindzk authored and mihaibudiu committed Apr 8, 2024
1 parent 1f5ba75 commit 4e6a320
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.calcite.linq4j.tree.ParameterExpression;
import org.apache.calcite.linq4j.tree.Primitive;
import org.apache.calcite.linq4j.tree.Statement;
import org.apache.calcite.linq4j.tree.Types;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
Expand Down Expand Up @@ -1374,11 +1375,26 @@ private Result toInnerStorageType(Result result, Type storageType) {
}
final Type storageType = currentStorageType != null
? currentStorageType : typeFactory.getJavaClass(dynamicParam.getType());
final Expression valueExpression =

final boolean isNumeric = SqlTypeFamily.NUMERIC.contains(dynamicParam.getType());

// For numeric types, use java.lang.Number to prevent cast exception
// when the parameter type differs from the target type
Expression argumentExpression =
EnumUtils.convert(
Expressions.call(root, BuiltInMethod.DATA_CONTEXT_GET.method,
Expressions.constant("?" + dynamicParam.getIndex())),
storageType);
isNumeric ? java.lang.Number.class : storageType);

// Short-circuit if the expression evaluates to null. The cast
// may throw a NullPointerException as it calls methods on the
// object such as longValue().
Expression valueExpression =
Expressions.condition(
Expressions.equal(argumentExpression, Expressions.constant(null)),
Expressions.constant(null),
Types.castIfNecessary(storageType, argumentExpression));

final ParameterExpression valueVariable =
Expressions.parameter(valueExpression.getType(),
list.newName("value_dynamic_param"));
Expand Down
93 changes: 93 additions & 0 deletions core/src/test/java/org/apache/calcite/test/JdbcTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.parser.impl.SqlParserImpl;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql2rel.SqlToRelConverter.Config;
import org.apache.calcite.test.schemata.catchall.CatchallSchema;
import org.apache.calcite.test.schemata.foodmart.FoodmartSchema;
Expand Down Expand Up @@ -8423,6 +8424,98 @@ private void checkGetTimestamp(Connection con) throws SQLException {
});
}

@Test void bindByteParameter() {
for (SqlTypeName tpe : SqlTypeName.INT_TYPES) {
final String sql =
"with cte as (select cast(100 as " + tpe.getName() + ") as empid)"
+ "select * from cte where empid = ?";
CalciteAssert.hr()
.query(sql)
.consumesPreparedStatement(p -> {
p.setByte(1, (byte) 100);
})
.returnsUnordered("EMPID=100");
}
}

@Test void bindShortParameter() {
for (SqlTypeName tpe : SqlTypeName.INT_TYPES) {
final String sql =
"with cte as (select cast(100 as " + tpe.getName() + ") as empid)"
+ "select * from cte where empid = ?";

CalciteAssert.hr()
.query(sql)
.consumesPreparedStatement(p -> {
p.setShort(1, (short) 100);
})
.returnsUnordered("EMPID=100");
}
}

@Test void bindOverflowingTinyIntParameter() {
final String sql =
"with cte as (select cast(300 as smallint) as empid)"
+ "select * from cte where empid = cast(? as tinyint)";

java.sql.SQLException t =
assertThrows(
java.sql.SQLException.class,
() -> CalciteAssert.hr()
.query(sql)
.consumesPreparedStatement(p -> {
p.setShort(1, (short) 300);
})
.returns(""));

assertThat(
"message matches",
t.getMessage().contains("value is outside the range of java.lang.Byte"));
}

@Test void bindIntParameter() {
for (SqlTypeName tpe : SqlTypeName.INT_TYPES) {
final String sql =
"with cte as (select cast(100 as " + tpe.getName() + ") as empid)"
+ "select * from cte where empid = ?";

CalciteAssert.hr()
.query(sql)
.consumesPreparedStatement(p -> {
p.setInt(1, 100);
})
.returnsUnordered("EMPID=100");
}
}

@Test void bindLongParameter() {
for (SqlTypeName tpe : SqlTypeName.INT_TYPES) {
final String sql =
"with cte as (select cast(100 as " + tpe.getName() + ") as empid)"
+ "select * from cte where empid = ?";

CalciteAssert.hr()
.query(sql)
.consumesPreparedStatement(p -> {
p.setLong(1, 100);
})
.returnsUnordered("EMPID=100");
}
}

@Test void bindNumericParameter() {
final String sql =
"with cte as (select cast(100 as numeric(5)) as empid)"
+ "select * from cte where empid = ?";

CalciteAssert.hr()
.query(sql)
.consumesPreparedStatement(p -> {
p.setLong(1, 100);
})
.returnsUnordered("EMPID=100");
}

private static String sums(int n, boolean c) {
final StringBuilder b = new StringBuilder();
for (int i = 0; i < n; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -622,9 +622,49 @@ public static UnaryExpression convert_(Expression expression, Type type,
* operation that throws an exception if the target type is
* overflowed.
*/
public static UnaryExpression convertChecked(Expression expression,
public static Expression convertChecked(Expression expression,
Type type) {
throw Extensions.todo();
if (type == Byte.class
|| type == Short.class
|| type == Integer.class
|| type == Long.class) {
Class<?> typeClass = (Class<?>) type;

Object minValue;
Object maxValue;

try {
minValue = typeClass.getField("MIN_VALUE").get(null);
maxValue = typeClass.getField("MAX_VALUE").get(null);
} catch (IllegalAccessException | NoSuchFieldException e) {
throw new RuntimeException(e);
}

ThrowStatement throwStmt =
Expressions.throw_(
Expressions.new_(
IllegalArgumentException.class,
Expressions.constant("value is outside the range of " + typeClass.getName())));

// Covers all lower precision types
Expression longValue = Expressions.call(expression, "longValue");

Expression minCheck = Expressions.lessThan(longValue, Expressions.constant(minValue));
Expression maxCheck = Expressions.greaterThan(longValue, Expressions.constant(maxValue));

Primitive primitive = requireNonNull(Primitive.ofBox(type));
String primitiveName = requireNonNull(primitive.primitiveName);
Expression convertExpr = Expressions.call(expression, primitiveName + "Value");

return Expressions.convert_(
Expressions.makeTernary(
ExpressionType.Conditional,
Expressions.or(minCheck, maxCheck),
Expressions.fromStatement(throwStmt),
convertExpr), type);
}

throw new IllegalArgumentException("Type " + type.getTypeName() + " is not supported yet");
}

/**
Expand Down Expand Up @@ -2822,6 +2862,18 @@ public static SymbolDocumentInfo symbolDocument(String filename,
throw Extensions.todo();
}

/**
* Create an expression from a statement.
*/
public static Expression fromStatement(Statement statement) {
FunctionExpression<Function<?>> lambda =
Expressions.lambda(
Blocks.toFunctionBlock(statement),
Collections.emptyList());

return Expressions.call(lambda, "apply");
}

/**
* Creates a statement that represents the throwing of an exception.
*/
Expand Down
25 changes: 20 additions & 5 deletions linq4j/src/main/java/org/apache/calcite/linq4j/tree/Types.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.lang.reflect.TypeVariable;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
Expand Down Expand Up @@ -428,11 +429,25 @@ public static Expression castIfNecessary(Type returnType,
&& Number.class.isAssignableFrom((Class) returnType)
&& type instanceof Class
&& Number.class.isAssignableFrom((Class) type)) {
// E.g.
// Integer foo(BigDecimal o) {
// return o.intValue();
// }
return Expressions.unbox(expression, requireNonNull(Primitive.ofBox(returnType)));

if (returnType == BigDecimal.class) {
return Expressions.call(
BigDecimal.class,
"valueOf",
Expressions.call(expression, "longValue"));
} else if (
returnType == Byte.class
|| returnType == Short.class
|| returnType == Integer.class
|| returnType == Long.class) {
return Expressions.convertChecked(expression, returnType);
} else {
// E.g.
// Integer foo(BigDecimal o) {
// return o.intValue();
// }
return Expressions.unbox(expression, requireNonNull(Primitive.ofBox(returnType)));
}
}
if (Primitive.is(returnType) && !Primitive.is(type)) {
// E.g.
Expand Down

0 comments on commit 4e6a320

Please sign in to comment.