From a815c623173d79e7d4ecff88ef784cc6467c548d Mon Sep 17 00:00:00 2001 From: liukun4515 Date: Wed, 24 Aug 2022 17:04:24 +0800 Subject: [PATCH] support decimal for the PreCastLitInComparisonExpressions rule --- .../src/pre_cast_lit_in_comparison.rs | 248 ++++++++++++------ 1 file changed, 168 insertions(+), 80 deletions(-) diff --git a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs index 0c16f7921c32..9cc77579cf99 100644 --- a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs +++ b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs @@ -18,7 +18,9 @@ //! Pre-cast literal binary comparison rule can be only used to the binary comparison expr. //! It can reduce adding the `Expr::Cast` to the expr instead of adding the `Expr::Cast` to literal expr. use crate::{OptimizerConfig, OptimizerRule}; -use arrow::datatypes::DataType; +use arrow::datatypes::{ + DataType, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, +}; use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue}; use datafusion_expr::utils::from_plan; use datafusion_expr::{binary_expr, lit, Expr, ExprSchemable, LogicalPlan, Operator}; @@ -97,8 +99,8 @@ fn visit_expr(expr: Expr, schema: &DFSchemaRef) -> Result { if left_type.is_err() || right_type.is_err() { return Ok(expr.clone()); } - let left_type = left_type.unwrap(); - let right_type = right_type.unwrap(); + let left_type = left_type?; + let right_type = right_type?; if !left_type.eq(&right_type) && is_support_data_type(&left_type) && is_support_data_type(&right_type) @@ -108,32 +110,27 @@ fn visit_expr(expr: Expr, schema: &DFSchemaRef) -> Result { (Expr::Literal(_), Expr::Literal(_)) => { // do nothing } - (Expr::Literal(left_lit_value), _) - if can_integer_literal_cast_to_type( - left_lit_value, - &right_type, - )? => - { - // cast the left literal to the right type - return Ok(binary_expr( - cast_to_other_scalar_expr(left_lit_value, &right_type)?, - *op, - right, - )); + (Expr::Literal(left_lit_value), _) => { + let (can_cast, casted_scalar_value) = + try_cast_literal_to_type(left_lit_value, &right_type)?; + if can_cast { + return Ok(binary_expr( + lit(casted_scalar_value.unwrap()), + *op, + right, + )); + } } - (_, Expr::Literal(right_lit_value)) - if can_integer_literal_cast_to_type( - right_lit_value, - &left_type, - ) - .unwrap() => - { - // cast the right literal to the left type - return Ok(binary_expr( - left, - *op, - cast_to_other_scalar_expr(right_lit_value, &left_type)?, - )); + (_, Expr::Literal(right_lit_value)) => { + let (can_cast, casted_scalar_value) = + try_cast_literal_to_type(right_lit_value, &left_type)?; + if can_cast { + return Ok(binary_expr( + left, + *op, + lit(casted_scalar_value.unwrap()), + )); + } } (_, _) => { // do nothing @@ -150,43 +147,6 @@ fn visit_expr(expr: Expr, schema: &DFSchemaRef) -> Result { } } -fn cast_to_other_scalar_expr( - origin_value: &ScalarValue, - target_type: &DataType, -) -> Result { - // null case - if origin_value.is_null() { - // if the origin value is null, just convert to another type of null value - // The target type must be satisfied `is_support_data_type` method, we can unwrap safely - return Ok(lit(ScalarValue::try_from(target_type).unwrap())); - } - // no null case - let value: i64 = match origin_value { - ScalarValue::Int8(Some(v)) => *v as i64, - ScalarValue::Int16(Some(v)) => *v as i64, - ScalarValue::Int32(Some(v)) => *v as i64, - ScalarValue::Int64(Some(v)) => *v as i64, - other_value => { - return Err(DataFusionError::Internal(format!( - "Invalid type and value {}", - other_value - ))) - } - }; - Ok(lit(match target_type { - DataType::Int8 => ScalarValue::Int8(Some(value as i8)), - DataType::Int16 => ScalarValue::Int16(Some(value as i16)), - DataType::Int32 => ScalarValue::Int32(Some(value as i32)), - DataType::Int64 => ScalarValue::Int64(Some(value)), - other_type => { - return Err(DataFusionError::Internal(format!( - "Invalid target data type {:?}", - other_type - ))) - } - })) -} - fn is_comparison_op(op: &Operator) -> bool { matches!( op, @@ -200,47 +160,112 @@ fn is_comparison_op(op: &Operator) -> bool { } fn is_support_data_type(data_type: &DataType) -> bool { - // TODO support decimal with other data type matches!( data_type, - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Decimal128(_, _) ) } -fn can_integer_literal_cast_to_type( - integer_lit_value: &ScalarValue, +fn try_cast_literal_to_type( + lit_value: &ScalarValue, target_type: &DataType, -) -> Result { - if integer_lit_value.is_null() { +) -> Result<(bool, Option)> { + if lit_value.is_null() { // null value can be cast to any type of null value - return Ok(true); + return Ok((true, Some(ScalarValue::try_from(target_type)?))); } + let mul = match target_type { + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => 1 as i128, + DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32), + other_type => { + return Err(DataFusionError::Internal(format!( + "Error target data type {:?}", + other_type + ))); + } + }; let (target_min, target_max) = match target_type { DataType::Int8 => (i8::MIN as i128, i8::MAX as i128), DataType::Int16 => (i16::MIN as i128, i16::MAX as i128), DataType::Int32 => (i32::MIN as i128, i32::MAX as i128), DataType::Int64 => (i64::MIN as i128, i64::MAX as i128), + DataType::Decimal128(precision, _) => ( + MIN_DECIMAL_FOR_EACH_PRECISION[*precision - 1], + MAX_DECIMAL_FOR_EACH_PRECISION[*precision - 1], + ), other_type => { return Err(DataFusionError::Internal(format!( "Error target data type {:?}", other_type - ))) + ))); } }; - let lit_value = match integer_lit_value { - ScalarValue::Int8(Some(v)) => *v as i128, - ScalarValue::Int16(Some(v)) => *v as i128, - ScalarValue::Int32(Some(v)) => *v as i128, - ScalarValue::Int64(Some(v)) => *v as i128, + let lit_value_target_type = match lit_value { + ScalarValue::Int8(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::Int16(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::Int32(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::Int64(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::Decimal128(Some(v), _, scale) => { + let lit_scale_mul = 10_i128.pow(*scale as u32); + if mul >= lit_scale_mul { + // Example: + // lit is decimal(123,3,2) + // target type is decimal(5,3) + // the lit can be converted to the decimal(1230,5,3) + (*v).checked_mul(mul / lit_scale_mul) + } else { + if (*v) % (lit_scale_mul / mul) == 0 { + // Example: + // lit is decimal(123000,10,3) + // target type is int32: the lit can be converted to INT32(123) + // target type is decimal(10,2): the lit can be converted to decimal(12300,10,2) + Some(*v / (lit_scale_mul / mul)) + } else { + // can't convert the lit decimal to the target data type + None + } + } + } other_value => { return Err(DataFusionError::Internal(format!( "Invalid literal value {:?}", other_value - ))) + ))); } }; - Ok(lit_value >= target_min && lit_value <= target_max) + match lit_value_target_type { + None => Ok((false, None)), + Some(value) => { + match value >= target_min && value <= target_max { + // the value casted from lit to the target type is in the range of target type. + // return the target type of scalar value + true => { + let result_scalar = match target_type { + DataType::Int8 => ScalarValue::Int8(Some(value as i8)), + DataType::Int16 => ScalarValue::Int16(Some(value as i16)), + DataType::Int32 => ScalarValue::Int32(Some(value as i32)), + DataType::Int64 => ScalarValue::Int64(Some(value as i64)), + DataType::Decimal128(p, s) => { + ScalarValue::Decimal128(Some(value), *p, *s) + } + other_type => { + return Err(DataFusionError::Internal(format!( + "Error target data type {:?}", + other_type + ))); + } + }; + Ok((true, Some(result_scalar))) + } + false => Ok((false, None)), + } + } + } } #[cfg(test)] @@ -292,6 +317,67 @@ mod tests { assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected); } + #[test] + fn test_not_cast_with_decimal_lit_comparison() { + let schema = expr_test_schema(); + // integer to decimal: value is out of the bounds of the decimal + // c3 = INT64(100000000000000000) + let expr_eq = col("c3").eq(lit(ScalarValue::Int64(Some(100000000000000000)))); + let expected = col("c3").eq(lit(ScalarValue::Int64(Some(100000000000000000)))); + assert_eq!(optimize_test(expr_eq, &schema), expected); + // c4 = INT64(1000) will overflow the i128 + let expr_eq = col("c4").eq(lit(ScalarValue::Int64(Some(1000)))); + let expected = col("c4").eq(lit(ScalarValue::Int64(Some(1000)))); + assert_eq!(optimize_test(expr_eq, &schema), expected); + + // decimal to decimal: value will lose the scale when convert to the target data type + // c3 = DECIMAL(12340,20,4) + let expr_eq = col("c3").eq(lit(ScalarValue::Decimal128(Some(12340), 20, 4))); + let expected = col("c3").eq(lit(ScalarValue::Decimal128(Some(12340), 20, 4))); + assert_eq!(optimize_test(expr_eq, &schema), expected); + + // decimal to integer + // c1 = DECIMAL(123, 10, 1): value will lose the scale when convert to the target data type + let expr_eq = col("c1").eq(lit(ScalarValue::Decimal128(Some(123), 10, 1))); + let expected = col("c1").eq(lit(ScalarValue::Decimal128(Some(123), 10, 1))); + assert_eq!(optimize_test(expr_eq, &schema), expected); + // c1 = DECIMAL(1230, 10, 2): value will lose the scale when convert to the target data type + let expr_eq = col("c1").eq(lit(ScalarValue::Decimal128(Some(1230), 10, 2))); + let expected = col("c1").eq(lit(ScalarValue::Decimal128(Some(1230), 10, 2))); + assert_eq!(optimize_test(expr_eq, &schema), expected); + } + + #[test] + fn test_pre_cast_with_decimal_lit_comparison() { + let schema = expr_test_schema(); + // integer to decimal + // c3 < INT64(16) -> c3 < (CAST(INT64(16) AS DECIMAL(18,2)); + let expr_lt = col("c3").lt(lit(ScalarValue::Int64(Some(16)))); + let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(1600), 18, 2))); + assert_eq!(optimize_test(expr_lt, &schema), expected); + + // c3 < INT64(NULL) + let c1_lt_lit_null = col("c3").lt(lit(ScalarValue::Int64(None))); + let expected = col("c3").lt(lit(ScalarValue::Decimal128(None, 18, 2))); + assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected); + + // decimal to decimal + // c3 < Decimal(123,10,0) -> c3 < CAST(DECIMAL(123,10,0) AS DECIMAL(18,2)) -> c3 < DECIMAL(12300,18,2) + let expr_lt = col("c3").lt(lit(ScalarValue::Decimal128(Some(123), 10, 0))); + let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(12300), 18, 2))); + assert_eq!(optimize_test(expr_lt, &schema), expected); + // c3 < Decimal(1230,10,3) -> c3 < CAST(DECIMAL(1230,10,3) AS DECIMAL(18,2)) -> c3 < DECIMAL(123,18,2) + let expr_lt = col("c3").lt(lit(ScalarValue::Decimal128(Some(1230), 10, 3))); + let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(123), 18, 2))); + assert_eq!(optimize_test(expr_lt, &schema), expected); + + // decimal to integer + // c1 < Decimal(12300, 10, 2) -> c1 < CAST(DECIMAL(12300,10,2) AS INT32) -> c1 < INT32(123) + let expr_lt = col("c1").lt(lit(ScalarValue::Decimal128(Some(12300), 10, 2))); + let expected = col("c1").lt(lit(ScalarValue::Int32(Some(123)))); + assert_eq!(optimize_test(expr_lt, &schema), expected); + } + fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr { visit_expr(expr, schema).unwrap() } @@ -302,6 +388,8 @@ mod tests { vec![ DFField::new(None, "c1", DataType::Int32, false), DFField::new(None, "c2", DataType::Int64, false), + DFField::new(None, "c3", DataType::Decimal128(18, 2), false), + DFField::new(None, "c4", DataType::Decimal128(38, 37), false), ], HashMap::new(), )