diff --git a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs index 68c738ca8739..6e89afd600be 100644 --- a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs +++ b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs @@ -18,7 +18,9 @@ //! Pre-cast literal binary comparison rule can be only used to the binary comparison expr. //! It can reduce adding the `Expr::Cast` to the expr instead of adding the `Expr::Cast` to literal expr. use crate::{OptimizerConfig, OptimizerRule}; -use arrow::datatypes::DataType; +use arrow::datatypes::{ + DataType, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, +}; use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue}; use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}; use datafusion_expr::utils::from_plan; @@ -99,7 +101,6 @@ impl ExprRewriter for PreCastLitExprRewriter { } fn mutate(&mut self, expr: Expr) -> Result { - // traverse the expr by dfs match &expr { Expr::BinaryExpr { left, op, right } => { let left = left.as_ref().clone(); @@ -121,32 +122,19 @@ impl ExprRewriter for PreCastLitExprRewriter { (Expr::Literal(_), Expr::Literal(_)) => { // do nothing } - (Expr::Literal(left_lit_value), _) - if can_integer_literal_cast_to_type( - left_lit_value, - &right_type, - )? => - { - // cast the left literal to the right type - return Ok(binary_expr( - cast_to_other_scalar_expr(left_lit_value, &right_type)?, - *op, - right, - )); + (Expr::Literal(left_lit_value), _) => { + let casted_scalar_value = + try_cast_literal_to_type(left_lit_value, &right_type)?; + if let Some(value) = casted_scalar_value { + return Ok(binary_expr(lit(value), *op, right)); + } } - (_, Expr::Literal(right_lit_value)) - if can_integer_literal_cast_to_type( - right_lit_value, - &left_type, - ) - .unwrap() => - { - // cast the right literal to the left type - return Ok(binary_expr( - left, - *op, - cast_to_other_scalar_expr(right_lit_value, &left_type)?, - )); + (_, Expr::Literal(right_lit_value)) => { + let casted_scalar_value = + try_cast_literal_to_type(right_lit_value, &left_type)?; + if let Some(value) = casted_scalar_value { + return Ok(binary_expr(left, *op, lit(value))); + } } (_, _) => { // do nothing @@ -164,43 +152,6 @@ impl ExprRewriter for PreCastLitExprRewriter { } } -fn cast_to_other_scalar_expr( - origin_value: &ScalarValue, - target_type: &DataType, -) -> Result { - // null case - if origin_value.is_null() { - // if the origin value is null, just convert to another type of null value - // The target type must be satisfied `is_support_data_type` method, we can unwrap safely - return Ok(lit(ScalarValue::try_from(target_type).unwrap())); - } - // no null case - let value: i64 = match origin_value { - ScalarValue::Int8(Some(v)) => *v as i64, - ScalarValue::Int16(Some(v)) => *v as i64, - ScalarValue::Int32(Some(v)) => *v as i64, - ScalarValue::Int64(Some(v)) => *v as i64, - other_value => { - return Err(DataFusionError::Internal(format!( - "Invalid type and value {}", - other_value - ))) - } - }; - Ok(lit(match target_type { - DataType::Int8 => ScalarValue::Int8(Some(value as i8)), - DataType::Int16 => ScalarValue::Int16(Some(value as i16)), - DataType::Int32 => ScalarValue::Int32(Some(value as i32)), - DataType::Int64 => ScalarValue::Int64(Some(value)), - other_type => { - return Err(DataFusionError::Internal(format!( - "Invalid target data type {:?}", - other_type - ))) - } - })) -} - fn is_comparison_op(op: &Operator) -> bool { matches!( op, @@ -214,47 +165,112 @@ fn is_comparison_op(op: &Operator) -> bool { } fn is_support_data_type(data_type: &DataType) -> bool { - // TODO support decimal with other data type matches!( data_type, - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Decimal128(_, _) ) } -fn can_integer_literal_cast_to_type( - integer_lit_value: &ScalarValue, +fn try_cast_literal_to_type( + lit_value: &ScalarValue, target_type: &DataType, -) -> Result { - if integer_lit_value.is_null() { +) -> Result> { + if lit_value.is_null() { // null value can be cast to any type of null value - return Ok(true); + return Ok(Some(ScalarValue::try_from(target_type)?)); } + let mul = match target_type { + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => 1_i128, + DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32), + other_type => { + return Err(DataFusionError::Internal(format!( + "Error target data type {:?}", + other_type + ))); + } + }; let (target_min, target_max) = match target_type { DataType::Int8 => (i8::MIN as i128, i8::MAX as i128), DataType::Int16 => (i16::MIN as i128, i16::MAX as i128), DataType::Int32 => (i32::MIN as i128, i32::MAX as i128), DataType::Int64 => (i64::MIN as i128, i64::MAX as i128), + DataType::Decimal128(precision, _) => ( + // Different precision for decimal128 can store different range of value. + // For example, the precision is 3, the max of value is `999` and the min + // value is `-999` + MIN_DECIMAL_FOR_EACH_PRECISION[*precision - 1], + MAX_DECIMAL_FOR_EACH_PRECISION[*precision - 1], + ), other_type => { return Err(DataFusionError::Internal(format!( "Error target data type {:?}", other_type - ))) + ))); } }; - let lit_value = match integer_lit_value { - ScalarValue::Int8(Some(v)) => *v as i128, - ScalarValue::Int16(Some(v)) => *v as i128, - ScalarValue::Int32(Some(v)) => *v as i128, - ScalarValue::Int64(Some(v)) => *v as i128, + let lit_value_target_type = match lit_value { + ScalarValue::Int8(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::Int16(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::Int32(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::Int64(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::Decimal128(Some(v), _, scale) => { + let lit_scale_mul = 10_i128.pow(*scale as u32); + if mul >= lit_scale_mul { + // Example: + // lit is decimal(123,3,2) + // target type is decimal(5,3) + // the lit can be converted to the decimal(1230,5,3) + (*v).checked_mul(mul / lit_scale_mul) + } else if (*v) % (lit_scale_mul / mul) == 0 { + // Example: + // lit is decimal(123000,10,3) + // target type is int32: the lit can be converted to INT32(123) + // target type is decimal(10,2): the lit can be converted to decimal(12300,10,2) + Some(*v / (lit_scale_mul / mul)) + } else { + // can't convert the lit decimal to the target data type + None + } + } other_value => { return Err(DataFusionError::Internal(format!( "Invalid literal value {:?}", other_value - ))) + ))); } }; - Ok(lit_value >= target_min && lit_value <= target_max) + match lit_value_target_type { + None => Ok(None), + Some(value) => { + if value >= target_min && value <= target_max { + // the value casted from lit to the target type is in the range of target type. + // return the target type of scalar value + let result_scalar = match target_type { + DataType::Int8 => ScalarValue::Int8(Some(value as i8)), + DataType::Int16 => ScalarValue::Int16(Some(value as i16)), + DataType::Int32 => ScalarValue::Int32(Some(value as i32)), + DataType::Int64 => ScalarValue::Int64(Some(value as i64)), + DataType::Decimal128(p, s) => { + ScalarValue::Decimal128(Some(value), *p, *s) + } + other_type => { + return Err(DataFusionError::Internal(format!( + "Error target data type {:?}", + other_type + ))); + } + }; + Ok(Some(result_scalar)) + } else { + Ok(None) + } + } + } } #[cfg(test)] @@ -307,6 +323,67 @@ mod tests { assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected); } + #[test] + fn test_not_cast_with_decimal_lit_comparison() { + let schema = expr_test_schema(); + // integer to decimal: value is out of the bounds of the decimal + // c3 = INT64(100000000000000000) + let expr_eq = col("c3").eq(lit(ScalarValue::Int64(Some(100000000000000000)))); + let expected = col("c3").eq(lit(ScalarValue::Int64(Some(100000000000000000)))); + assert_eq!(optimize_test(expr_eq, &schema), expected); + // c4 = INT64(1000) will overflow the i128 + let expr_eq = col("c4").eq(lit(ScalarValue::Int64(Some(1000)))); + let expected = col("c4").eq(lit(ScalarValue::Int64(Some(1000)))); + assert_eq!(optimize_test(expr_eq, &schema), expected); + + // decimal to decimal: value will lose the scale when convert to the target data type + // c3 = DECIMAL(12340,20,4) + let expr_eq = col("c3").eq(lit(ScalarValue::Decimal128(Some(12340), 20, 4))); + let expected = col("c3").eq(lit(ScalarValue::Decimal128(Some(12340), 20, 4))); + assert_eq!(optimize_test(expr_eq, &schema), expected); + + // decimal to integer + // c1 = DECIMAL(123, 10, 1): value will lose the scale when convert to the target data type + let expr_eq = col("c1").eq(lit(ScalarValue::Decimal128(Some(123), 10, 1))); + let expected = col("c1").eq(lit(ScalarValue::Decimal128(Some(123), 10, 1))); + assert_eq!(optimize_test(expr_eq, &schema), expected); + // c1 = DECIMAL(1230, 10, 2): value will lose the scale when convert to the target data type + let expr_eq = col("c1").eq(lit(ScalarValue::Decimal128(Some(1230), 10, 2))); + let expected = col("c1").eq(lit(ScalarValue::Decimal128(Some(1230), 10, 2))); + assert_eq!(optimize_test(expr_eq, &schema), expected); + } + + #[test] + fn test_pre_cast_with_decimal_lit_comparison() { + let schema = expr_test_schema(); + // integer to decimal + // c3 < INT64(16) -> c3 < (CAST(INT64(16) AS DECIMAL(18,2)); + let expr_lt = col("c3").lt(lit(ScalarValue::Int64(Some(16)))); + let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(1600), 18, 2))); + assert_eq!(optimize_test(expr_lt, &schema), expected); + + // c3 < INT64(NULL) + let c1_lt_lit_null = col("c3").lt(lit(ScalarValue::Int64(None))); + let expected = col("c3").lt(lit(ScalarValue::Decimal128(None, 18, 2))); + assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected); + + // decimal to decimal + // c3 < Decimal(123,10,0) -> c3 < CAST(DECIMAL(123,10,0) AS DECIMAL(18,2)) -> c3 < DECIMAL(12300,18,2) + let expr_lt = col("c3").lt(lit(ScalarValue::Decimal128(Some(123), 10, 0))); + let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(12300), 18, 2))); + assert_eq!(optimize_test(expr_lt, &schema), expected); + // c3 < Decimal(1230,10,3) -> c3 < CAST(DECIMAL(1230,10,3) AS DECIMAL(18,2)) -> c3 < DECIMAL(123,18,2) + let expr_lt = col("c3").lt(lit(ScalarValue::Decimal128(Some(1230), 10, 3))); + let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(123), 18, 2))); + assert_eq!(optimize_test(expr_lt, &schema), expected); + + // decimal to integer + // c1 < Decimal(12300, 10, 2) -> c1 < CAST(DECIMAL(12300,10,2) AS INT32) -> c1 < INT32(123) + let expr_lt = col("c1").lt(lit(ScalarValue::Decimal128(Some(12300), 10, 2))); + let expected = col("c1").lt(lit(ScalarValue::Int32(Some(123)))); + assert_eq!(optimize_test(expr_lt, &schema), expected); + } + #[test] fn aliased() { let schema = expr_test_schema(); @@ -344,6 +421,8 @@ mod tests { vec![ DFField::new(None, "c1", DataType::Int32, false), DFField::new(None, "c2", DataType::Int64, false), + DFField::new(None, "c3", DataType::Decimal128(18, 2), false), + DFField::new(None, "c4", DataType::Decimal128(38, 37), false), ], HashMap::new(), )