-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
support decimal data type for the optimizer rule of PreCastLitInComparisonExpressions #3245
Changes from 1 commit
51029d5
6a82273
3d2caf5
6170817
3c64aef
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,7 +18,9 @@ | |
//! Pre-cast literal binary comparison rule can be only used to the binary comparison expr. | ||
//! It can reduce adding the `Expr::Cast` to the expr instead of adding the `Expr::Cast` to literal expr. | ||
use crate::{OptimizerConfig, OptimizerRule}; | ||
use arrow::datatypes::DataType; | ||
use arrow::datatypes::{ | ||
DataType, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, | ||
}; | ||
use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue}; | ||
use datafusion_expr::utils::from_plan; | ||
use datafusion_expr::{binary_expr, lit, Expr, ExprSchemable, LogicalPlan, Operator}; | ||
|
@@ -97,8 +99,8 @@ fn visit_expr(expr: Expr, schema: &DFSchemaRef) -> Result<Expr> { | |
if left_type.is_err() || right_type.is_err() { | ||
return Ok(expr.clone()); | ||
} | ||
let left_type = left_type.unwrap(); | ||
let right_type = right_type.unwrap(); | ||
let left_type = left_type?; | ||
let right_type = right_type?; | ||
if !left_type.eq(&right_type) | ||
&& is_support_data_type(&left_type) | ||
&& is_support_data_type(&right_type) | ||
|
@@ -108,32 +110,27 @@ fn visit_expr(expr: Expr, schema: &DFSchemaRef) -> Result<Expr> { | |
(Expr::Literal(_), Expr::Literal(_)) => { | ||
// do nothing | ||
} | ||
(Expr::Literal(left_lit_value), _) | ||
if can_integer_literal_cast_to_type( | ||
left_lit_value, | ||
&right_type, | ||
)? => | ||
{ | ||
// cast the left literal to the right type | ||
return Ok(binary_expr( | ||
cast_to_other_scalar_expr(left_lit_value, &right_type)?, | ||
*op, | ||
right, | ||
)); | ||
(Expr::Literal(left_lit_value), _) => { | ||
let (can_cast, casted_scalar_value) = | ||
try_cast_literal_to_type(left_lit_value, &right_type)?; | ||
if can_cast { | ||
return Ok(binary_expr( | ||
lit(casted_scalar_value.unwrap()), | ||
*op, | ||
right, | ||
)); | ||
} | ||
} | ||
(_, Expr::Literal(right_lit_value)) | ||
if can_integer_literal_cast_to_type( | ||
right_lit_value, | ||
&left_type, | ||
) | ||
.unwrap() => | ||
{ | ||
// cast the right literal to the left type | ||
return Ok(binary_expr( | ||
left, | ||
*op, | ||
cast_to_other_scalar_expr(right_lit_value, &left_type)?, | ||
)); | ||
(_, Expr::Literal(right_lit_value)) => { | ||
let (can_cast, casted_scalar_value) = | ||
try_cast_literal_to_type(right_lit_value, &left_type)?; | ||
if can_cast { | ||
return Ok(binary_expr( | ||
left, | ||
*op, | ||
lit(casted_scalar_value.unwrap()), | ||
)); | ||
} | ||
} | ||
(_, _) => { | ||
// do nothing | ||
|
@@ -150,43 +147,6 @@ fn visit_expr(expr: Expr, schema: &DFSchemaRef) -> Result<Expr> { | |
} | ||
} | ||
|
||
fn cast_to_other_scalar_expr( | ||
origin_value: &ScalarValue, | ||
target_type: &DataType, | ||
) -> Result<Expr> { | ||
// null case | ||
if origin_value.is_null() { | ||
// if the origin value is null, just convert to another type of null value | ||
// The target type must be satisfied `is_support_data_type` method, we can unwrap safely | ||
return Ok(lit(ScalarValue::try_from(target_type).unwrap())); | ||
} | ||
// no null case | ||
let value: i64 = match origin_value { | ||
ScalarValue::Int8(Some(v)) => *v as i64, | ||
ScalarValue::Int16(Some(v)) => *v as i64, | ||
ScalarValue::Int32(Some(v)) => *v as i64, | ||
ScalarValue::Int64(Some(v)) => *v as i64, | ||
other_value => { | ||
return Err(DataFusionError::Internal(format!( | ||
"Invalid type and value {}", | ||
other_value | ||
))) | ||
} | ||
}; | ||
Ok(lit(match target_type { | ||
DataType::Int8 => ScalarValue::Int8(Some(value as i8)), | ||
DataType::Int16 => ScalarValue::Int16(Some(value as i16)), | ||
DataType::Int32 => ScalarValue::Int32(Some(value as i32)), | ||
DataType::Int64 => ScalarValue::Int64(Some(value)), | ||
other_type => { | ||
return Err(DataFusionError::Internal(format!( | ||
"Invalid target data type {:?}", | ||
other_type | ||
))) | ||
} | ||
})) | ||
} | ||
|
||
fn is_comparison_op(op: &Operator) -> bool { | ||
matches!( | ||
op, | ||
|
@@ -200,47 +160,110 @@ fn is_comparison_op(op: &Operator) -> bool { | |
} | ||
|
||
fn is_support_data_type(data_type: &DataType) -> bool { | ||
// TODO support decimal with other data type | ||
matches!( | ||
data_type, | ||
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 | ||
DataType::Int8 | ||
| DataType::Int16 | ||
| DataType::Int32 | ||
| DataType::Int64 | ||
| DataType::Decimal128(_, _) | ||
) | ||
} | ||
|
||
fn can_integer_literal_cast_to_type( | ||
integer_lit_value: &ScalarValue, | ||
fn try_cast_literal_to_type( | ||
lit_value: &ScalarValue, | ||
target_type: &DataType, | ||
) -> Result<bool> { | ||
if integer_lit_value.is_null() { | ||
) -> Result<(bool, Option<ScalarValue>)> { | ||
if lit_value.is_null() { | ||
// null value can be cast to any type of null value | ||
return Ok(true); | ||
return Ok((true, Some(ScalarValue::try_from(target_type)?))); | ||
} | ||
let mul = match target_type { | ||
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => 1_i128, | ||
DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32), | ||
other_type => { | ||
return Err(DataFusionError::Internal(format!( | ||
"Error target data type {:?}", | ||
other_type | ||
))); | ||
} | ||
}; | ||
let (target_min, target_max) = match target_type { | ||
DataType::Int8 => (i8::MIN as i128, i8::MAX as i128), | ||
DataType::Int16 => (i16::MIN as i128, i16::MAX as i128), | ||
DataType::Int32 => (i32::MIN as i128, i32::MAX as i128), | ||
DataType::Int64 => (i64::MIN as i128, i64::MAX as i128), | ||
DataType::Decimal128(precision, _) => ( | ||
MIN_DECIMAL_FOR_EACH_PRECISION[*precision - 1], | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you add a comment here to explain what is going on? |
||
MAX_DECIMAL_FOR_EACH_PRECISION[*precision - 1], | ||
), | ||
other_type => { | ||
return Err(DataFusionError::Internal(format!( | ||
"Error target data type {:?}", | ||
other_type | ||
))) | ||
))); | ||
} | ||
}; | ||
let lit_value = match integer_lit_value { | ||
ScalarValue::Int8(Some(v)) => *v as i128, | ||
ScalarValue::Int16(Some(v)) => *v as i128, | ||
ScalarValue::Int32(Some(v)) => *v as i128, | ||
ScalarValue::Int64(Some(v)) => *v as i128, | ||
let lit_value_target_type = match lit_value { | ||
ScalarValue::Int8(Some(v)) => (*v as i128).checked_mul(mul), | ||
ScalarValue::Int16(Some(v)) => (*v as i128).checked_mul(mul), | ||
ScalarValue::Int32(Some(v)) => (*v as i128).checked_mul(mul), | ||
ScalarValue::Int64(Some(v)) => (*v as i128).checked_mul(mul), | ||
ScalarValue::Decimal128(Some(v), _, scale) => { | ||
let lit_scale_mul = 10_i128.pow(*scale as u32); | ||
if mul >= lit_scale_mul { | ||
// Example: | ||
// lit is decimal(123,3,2) | ||
// target type is decimal(5,3) | ||
// the lit can be converted to the decimal(1230,5,3) | ||
(*v).checked_mul(mul / lit_scale_mul) | ||
} else if (*v) % (lit_scale_mul / mul) == 0 { | ||
// Example: | ||
// lit is decimal(123000,10,3) | ||
// target type is int32: the lit can be converted to INT32(123) | ||
// target type is decimal(10,2): the lit can be converted to decimal(12300,10,2) | ||
Some(*v / (lit_scale_mul / mul)) | ||
} else { | ||
// can't convert the lit decimal to the target data type | ||
None | ||
} | ||
} | ||
other_value => { | ||
return Err(DataFusionError::Internal(format!( | ||
"Invalid literal value {:?}", | ||
other_value | ||
))) | ||
))); | ||
} | ||
}; | ||
|
||
Ok(lit_value >= target_min && lit_value <= target_max) | ||
match lit_value_target_type { | ||
None => Ok((false, None)), | ||
Some(value) => { | ||
match value >= target_min && value <= target_max { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think it is better to use an `if`/`else` here. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done |
||
// the value casted from lit to the target type is in the range of target type. | ||
// return the target type of scalar value | ||
true => { | ||
let result_scalar = match target_type { | ||
DataType::Int8 => ScalarValue::Int8(Some(value as i8)), | ||
DataType::Int16 => ScalarValue::Int16(Some(value as i16)), | ||
DataType::Int32 => ScalarValue::Int32(Some(value as i32)), | ||
DataType::Int64 => ScalarValue::Int64(Some(value as i64)), | ||
DataType::Decimal128(p, s) => { | ||
ScalarValue::Decimal128(Some(value), *p, *s) | ||
} | ||
other_type => { | ||
return Err(DataFusionError::Internal(format!( | ||
"Error target data type {:?}", | ||
other_type | ||
))); | ||
} | ||
}; | ||
Ok((true, Some(result_scalar))) | ||
} | ||
false => Ok((false, None)), | ||
} | ||
} | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
|
@@ -292,6 +315,67 @@ mod tests { | |
assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected); | ||
} | ||
|
||
#[test] | ||
fn test_not_cast_with_decimal_lit_comparison() { | ||
let schema = expr_test_schema(); | ||
// integer to decimal: value is out of the bounds of the decimal | ||
// c3 = INT64(100000000000000000) | ||
let expr_eq = col("c3").eq(lit(ScalarValue::Int64(Some(100000000000000000)))); | ||
let expected = col("c3").eq(lit(ScalarValue::Int64(Some(100000000000000000)))); | ||
assert_eq!(optimize_test(expr_eq, &schema), expected); | ||
// c4 = INT64(1000) will overflow the i128 | ||
let expr_eq = col("c4").eq(lit(ScalarValue::Int64(Some(1000)))); | ||
let expected = col("c4").eq(lit(ScalarValue::Int64(Some(1000)))); | ||
assert_eq!(optimize_test(expr_eq, &schema), expected); | ||
|
||
// decimal to decimal: value will lose the scale when convert to the target data type | ||
// c3 = DECIMAL(12340,20,4) | ||
let expr_eq = col("c3").eq(lit(ScalarValue::Decimal128(Some(12340), 20, 4))); | ||
let expected = col("c3").eq(lit(ScalarValue::Decimal128(Some(12340), 20, 4))); | ||
assert_eq!(optimize_test(expr_eq, &schema), expected); | ||
|
||
// decimal to integer | ||
// c1 = DECIMAL(123, 10, 1): value will lose the scale when convert to the target data type | ||
let expr_eq = col("c1").eq(lit(ScalarValue::Decimal128(Some(123), 10, 1))); | ||
let expected = col("c1").eq(lit(ScalarValue::Decimal128(Some(123), 10, 1))); | ||
assert_eq!(optimize_test(expr_eq, &schema), expected); | ||
// c1 = DECIMAL(1230, 10, 2): value will lose the scale when convert to the target data type | ||
let expr_eq = col("c1").eq(lit(ScalarValue::Decimal128(Some(1230), 10, 2))); | ||
let expected = col("c1").eq(lit(ScalarValue::Decimal128(Some(1230), 10, 2))); | ||
assert_eq!(optimize_test(expr_eq, &schema), expected); | ||
} | ||
|
||
#[test] | ||
fn test_pre_cast_with_decimal_lit_comparison() { | ||
let schema = expr_test_schema(); | ||
// integer to decimal | ||
// c3 < INT64(16) -> c3 < (CAST(INT64(16) AS DECIMAL(18,2)); | ||
let expr_lt = col("c3").lt(lit(ScalarValue::Int64(Some(16)))); | ||
let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(1600), 18, 2))); | ||
assert_eq!(optimize_test(expr_lt, &schema), expected); | ||
|
||
// c3 < INT64(NULL) | ||
let c1_lt_lit_null = col("c3").lt(lit(ScalarValue::Int64(None))); | ||
let expected = col("c3").lt(lit(ScalarValue::Decimal128(None, 18, 2))); | ||
assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected); | ||
|
||
// decimal to decimal | ||
// c3 < Decimal(123,10,0) -> c3 < CAST(DECIMAL(123,10,0) AS DECIMAL(18,2)) -> c3 < DECIMAL(12300,18,2) | ||
let expr_lt = col("c3").lt(lit(ScalarValue::Decimal128(Some(123), 10, 0))); | ||
let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(12300), 18, 2))); | ||
assert_eq!(optimize_test(expr_lt, &schema), expected); | ||
// c3 < Decimal(1230,10,3) -> c3 < CAST(DECIMAL(1230,10,3) AS DECIMAL(18,2)) -> c3 < DECIMAL(123,18,2) | ||
let expr_lt = col("c3").lt(lit(ScalarValue::Decimal128(Some(1230), 10, 3))); | ||
let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(123), 18, 2))); | ||
assert_eq!(optimize_test(expr_lt, &schema), expected); | ||
|
||
// decimal to integer | ||
// c1 < Decimal(12300, 10, 2) -> c1 < CAST(DECIMAL(12300,10,2) AS INT32) -> c1 < INT32(123) | ||
let expr_lt = col("c1").lt(lit(ScalarValue::Decimal128(Some(12300), 10, 2))); | ||
let expected = col("c1").lt(lit(ScalarValue::Int32(Some(123)))); | ||
assert_eq!(optimize_test(expr_lt, &schema), expected); | ||
} | ||
|
||
fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr { | ||
visit_expr(expr, schema).unwrap() | ||
} | ||
|
@@ -302,6 +386,8 @@ mod tests { | |
vec![ | ||
DFField::new(None, "c1", DataType::Int32, false), | ||
DFField::new(None, "c2", DataType::Int64, false), | ||
DFField::new(None, "c3", DataType::Decimal128(18, 2), false), | ||
DFField::new(None, "c4", DataType::Decimal128(38, 37), false), | ||
], | ||
HashMap::new(), | ||
) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like this function always returns either `(true, Some(_))` or `(false, None)`, so maybe it should just return the `Option` without the bool?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good suggestion.
Done for your comments.
PTAL @andygrove