Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CALCITE-6265] Type coercion is failing for numeric values in prepared statements #3687

Merged
merged 1 commit into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.calcite.linq4j.tree.ParameterExpression;
import org.apache.calcite.linq4j.tree.Primitive;
import org.apache.calcite.linq4j.tree.Statement;
import org.apache.calcite.linq4j.tree.Types;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
Expand Down Expand Up @@ -1374,11 +1375,26 @@ private Result toInnerStorageType(Result result, Type storageType) {
}
final Type storageType = currentStorageType != null
? currentStorageType : typeFactory.getJavaClass(dynamicParam.getType());
final Expression valueExpression =

final boolean isNumeric = SqlTypeFamily.NUMERIC.contains(dynamicParam.getType());

// For numeric types, use java.lang.Number to prevent cast exception
// when the parameter type differs from the target type
Expression argumentExpression =
EnumUtils.convert(
Expressions.call(root, BuiltInMethod.DATA_CONTEXT_GET.method,
Expressions.constant("?" + dynamicParam.getIndex())),
storageType);
isNumeric ? java.lang.Number.class : storageType);

// Short-circuit if the expression evaluates to null. The cast
// may throw a NullPointerException as it calls methods on the
// object such as longValue().
Expression valueExpression =
Expressions.condition(
Expressions.equal(argumentExpression, Expressions.constant(null)),
Expressions.constant(null),
Types.castIfNecessary(storageType, argumentExpression));

final ParameterExpression valueVariable =
Expressions.parameter(valueExpression.getType(),
list.newName("value_dynamic_param"));
Expand Down
93 changes: 93 additions & 0 deletions core/src/test/java/org/apache/calcite/test/JdbcTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.parser.impl.SqlParserImpl;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql2rel.SqlToRelConverter.Config;
import org.apache.calcite.test.schemata.catchall.CatchallSchema;
import org.apache.calcite.test.schemata.foodmart.FoodmartSchema;
Expand Down Expand Up @@ -8423,6 +8424,98 @@ private void checkGetTimestamp(Connection con) throws SQLException {
});
}

@Test void bindByteParameter() {
for (SqlTypeName tpe : SqlTypeName.INT_TYPES) {
final String sql =
"with cte as (select cast(100 as " + tpe.getName() + ") as empid)"
+ "select * from cte where empid = ?";
CalciteAssert.hr()
.query(sql)
.consumesPreparedStatement(p -> {
p.setByte(1, (byte) 100);
})
.returnsUnordered("EMPID=100");
}
}

@Test void bindShortParameter() {
for (SqlTypeName tpe : SqlTypeName.INT_TYPES) {
final String sql =
"with cte as (select cast(100 as " + tpe.getName() + ") as empid)"
+ "select * from cte where empid = ?";

CalciteAssert.hr()
.query(sql)
.consumesPreparedStatement(p -> {
p.setShort(1, (short) 100);
})
.returnsUnordered("EMPID=100");
}
}

@Test void bindOverflowingTinyIntParameter() {
final String sql =
"with cte as (select cast(300 as smallint) as empid)"
+ "select * from cte where empid = cast(? as tinyint)";

java.sql.SQLException t =
assertThrows(
java.sql.SQLException.class,
() -> CalciteAssert.hr()
.query(sql)
.consumesPreparedStatement(p -> {
p.setShort(1, (short) 300);
})
.returns(""));

assertThat(
"message matches",
t.getMessage().contains("value is outside the range of java.lang.Byte"));
}

@Test void bindIntParameter() {
for (SqlTypeName tpe : SqlTypeName.INT_TYPES) {
final String sql =
"with cte as (select cast(100 as " + tpe.getName() + ") as empid)"
+ "select * from cte where empid = ?";

CalciteAssert.hr()
.query(sql)
.consumesPreparedStatement(p -> {
p.setInt(1, 100);
})
.returnsUnordered("EMPID=100");
}
}

@Test void bindLongParameter() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also add a test where the value overflows the type, e.g., TINYINT with a value of 300?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TINYINT(300) is not valid SQL, but while working on a test for it, I noticed that NUMERIC(<digits>) conversions are currently not supported either. An exception is thrown because Primitive.ofBox(returnType)) returns null for BigDecimal. This is fixed in the last commit.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CAST(300 as TINYINT)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for clarifying. CAST(300 as TINYINT) is correctly fine, but CAST(? as TINYINT) does not trigger an error when the value is out of bounds. I added a test and fixed it.

for (SqlTypeName tpe : SqlTypeName.INT_TYPES) {
final String sql =
"with cte as (select cast(100 as " + tpe.getName() + ") as empid)"
+ "select * from cte where empid = ?";

CalciteAssert.hr()
.query(sql)
.consumesPreparedStatement(p -> {
p.setLong(1, 100);
})
.returnsUnordered("EMPID=100");
}
}

@Test void bindNumericParameter() {
final String sql =
"with cte as (select cast(100 as numeric(5)) as empid)"
+ "select * from cte where empid = ?";

CalciteAssert.hr()
.query(sql)
.consumesPreparedStatement(p -> {
p.setLong(1, 100);
})
.returnsUnordered("EMPID=100");
}

private static String sums(int n, boolean c) {
final StringBuilder b = new StringBuilder();
for (int i = 0; i < n; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -622,9 +622,49 @@ public static UnaryExpression convert_(Expression expression, Type type,
* operation that throws an exception if the target type is
* overflowed.
*/
public static UnaryExpression convertChecked(Expression expression,
public static Expression convertChecked(Expression expression,
Type type) {
throw Extensions.todo();
if (type == Byte.class
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's funny, I implemented this myself in a separate PR: #3589
There I am relying on an existing implementation in Primitive.numberValue.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice. Your approach seems to cover more scenarios.

I will keep convertChecked as-is for now. The inlined validations should be sufficient within the current PR's scope, but feel free to replace it once #3589 has been merged.

|| type == Short.class
|| type == Integer.class
|| type == Long.class) {
Class<?> typeClass = (Class<?>) type;

Object minValue;
Object maxValue;

try {
minValue = typeClass.getField("MIN_VALUE").get(null);
maxValue = typeClass.getField("MAX_VALUE").get(null);
} catch (IllegalAccessException | NoSuchFieldException e) {
throw new RuntimeException(e);
}

ThrowStatement throwStmt =
Expressions.throw_(
Expressions.new_(
IllegalArgumentException.class,
Expressions.constant("value is outside the range of " + typeClass.getName())));

// Covers all lower precision types
Expression longValue = Expressions.call(expression, "longValue");

Expression minCheck = Expressions.lessThan(longValue, Expressions.constant(minValue));
Expression maxCheck = Expressions.greaterThan(longValue, Expressions.constant(maxValue));

Primitive primitive = requireNonNull(Primitive.ofBox(type));
String primitiveName = requireNonNull(primitive.primitiveName);
Expression convertExpr = Expressions.call(expression, primitiveName + "Value");

return Expressions.convert_(
Expressions.makeTernary(
ExpressionType.Conditional,
Expressions.or(minCheck, maxCheck),
Expressions.fromStatement(throwStmt),
convertExpr), type);
}

throw new IllegalArgumentException("Type " + type.getTypeName() + " is not supported yet");
}

/**
Expand Down Expand Up @@ -2822,6 +2862,18 @@ public static SymbolDocumentInfo symbolDocument(String filename,
throw Extensions.todo();
}

/**
* Create an expression from a statement.
*/
public static Expression fromStatement(Statement statement) {
FunctionExpression<Function<?>> lambda =
Expressions.lambda(
Blocks.toFunctionBlock(statement),
Collections.emptyList());

return Expressions.call(lambda, "apply");
}

/**
* Creates a statement that represents the throwing of an exception.
*/
Expand Down
25 changes: 20 additions & 5 deletions linq4j/src/main/java/org/apache/calcite/linq4j/tree/Types.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.lang.reflect.TypeVariable;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
Expand Down Expand Up @@ -428,11 +429,25 @@ public static Expression castIfNecessary(Type returnType,
&& Number.class.isAssignableFrom((Class) returnType)
&& type instanceof Class
&& Number.class.isAssignableFrom((Class) type)) {
// E.g.
// Integer foo(BigDecimal o) {
// return o.intValue();
// }
return Expressions.unbox(expression, requireNonNull(Primitive.ofBox(returnType)));

if (returnType == BigDecimal.class) {
return Expressions.call(
BigDecimal.class,
"valueOf",
Expressions.call(expression, "longValue"));
} else if (
returnType == Byte.class
|| returnType == Short.class
|| returnType == Integer.class
|| returnType == Long.class) {
return Expressions.convertChecked(expression, returnType);
} else {
// E.g.
// Integer foo(BigDecimal o) {
// return o.intValue();
// }
return Expressions.unbox(expression, requireNonNull(Primitive.ofBox(returnType)));
}
}
if (Primitive.is(returnType) && !Primitive.is(type)) {
// E.g.
Expand Down