From 51029d5151162834a495db258f4f43fb764f90aa Mon Sep 17 00:00:00 2001 From: liukun4515 Date: Wed, 24 Aug 2022 17:04:24 +0800 Subject: [PATCH 1/4] support decimal for the PreCastLitInComparisonExpressions rule --- .../src/pre_cast_lit_in_comparison.rs | 246 ++++++++++++------ 1 file changed, 166 insertions(+), 80 deletions(-) diff --git a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs index 0c16f7921c32..b27c2f7c6b27 100644 --- a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs +++ b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs @@ -18,7 +18,9 @@ //! Pre-cast literal binary comparison rule can be only used to the binary comparison expr. //! It can reduce adding the `Expr::Cast` to the expr instead of adding the `Expr::Cast` to literal expr. use crate::{OptimizerConfig, OptimizerRule}; -use arrow::datatypes::DataType; +use arrow::datatypes::{ + DataType, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, +}; use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue}; use datafusion_expr::utils::from_plan; use datafusion_expr::{binary_expr, lit, Expr, ExprSchemable, LogicalPlan, Operator}; @@ -97,8 +99,8 @@ fn visit_expr(expr: Expr, schema: &DFSchemaRef) -> Result { if left_type.is_err() || right_type.is_err() { return Ok(expr.clone()); } - let left_type = left_type.unwrap(); - let right_type = right_type.unwrap(); + let left_type = left_type?; + let right_type = right_type?; if !left_type.eq(&right_type) && is_support_data_type(&left_type) && is_support_data_type(&right_type) @@ -108,32 +110,27 @@ fn visit_expr(expr: Expr, schema: &DFSchemaRef) -> Result { (Expr::Literal(_), Expr::Literal(_)) => { // do nothing } - (Expr::Literal(left_lit_value), _) - if can_integer_literal_cast_to_type( - left_lit_value, - &right_type, - )? => - { - // cast the left literal to the right type - return Ok(binary_expr( - cast_to_other_scalar_expr(left_lit_value, &right_type)?, - *op, - right, - )); + (Expr::Literal(left_lit_value), _) => { + let (can_cast, casted_scalar_value) = + try_cast_literal_to_type(left_lit_value, &right_type)?; + if can_cast { + return Ok(binary_expr( + lit(casted_scalar_value.unwrap()), + *op, + right, + )); + } } - (_, Expr::Literal(right_lit_value)) - if can_integer_literal_cast_to_type( - right_lit_value, - &left_type, - ) - .unwrap() => - { - // cast the right literal to the left type - return Ok(binary_expr( - left, - *op, - cast_to_other_scalar_expr(right_lit_value, &left_type)?, - )); + (_, Expr::Literal(right_lit_value)) => { + let (can_cast, casted_scalar_value) = + try_cast_literal_to_type(right_lit_value, &left_type)?; + if can_cast { + return Ok(binary_expr( + left, + *op, + lit(casted_scalar_value.unwrap()), + )); + } } (_, _) => { // do nothing @@ -150,43 +147,6 @@ fn visit_expr(expr: Expr, schema: &DFSchemaRef) -> Result { } } -fn cast_to_other_scalar_expr( - origin_value: &ScalarValue, - target_type: &DataType, -) -> Result { - // null case - if origin_value.is_null() { - // if the origin value is null, just convert to another type of null value - // The target type must be satisfied `is_support_data_type` method, we can unwrap safely - return Ok(lit(ScalarValue::try_from(target_type).unwrap())); - } - // no null case - let value: i64 = match origin_value { - ScalarValue::Int8(Some(v)) => *v as i64, - ScalarValue::Int16(Some(v)) => *v as i64, - ScalarValue::Int32(Some(v)) => *v as i64, - ScalarValue::Int64(Some(v)) => *v as i64, - other_value => { - return Err(DataFusionError::Internal(format!( - "Invalid type and value {}", - other_value - ))) - } - }; - Ok(lit(match target_type { - DataType::Int8 => ScalarValue::Int8(Some(value as i8)), - DataType::Int16 => ScalarValue::Int16(Some(value as i16)), - DataType::Int32 => ScalarValue::Int32(Some(value as i32)), - DataType::Int64 => ScalarValue::Int64(Some(value)), - other_type => { - return Err(DataFusionError::Internal(format!( - "Invalid target data type {:?}", - other_type - ))) - } - })) -} - fn is_comparison_op(op: &Operator) -> bool { matches!( op, @@ -200,47 +160,110 @@ fn is_comparison_op(op: &Operator) -> bool { } fn is_support_data_type(data_type: &DataType) -> bool { - // TODO support decimal with other data type matches!( data_type, - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Decimal128(_, _) ) } -fn can_integer_literal_cast_to_type( - integer_lit_value: &ScalarValue, +fn try_cast_literal_to_type( + lit_value: &ScalarValue, target_type: &DataType, -) -> Result { - if integer_lit_value.is_null() { +) -> Result<(bool, Option)> { + if lit_value.is_null() { // null value can be cast to any type of null value - return Ok(true); + return Ok((true, Some(ScalarValue::try_from(target_type)?))); } + let mul = match target_type { + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => 1_i128, + DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32), + other_type => { + return Err(DataFusionError::Internal(format!( + "Error target data type {:?}", + other_type + ))); + } + }; let (target_min, target_max) = match target_type { DataType::Int8 => (i8::MIN as i128, i8::MAX as i128), DataType::Int16 => (i16::MIN as i128, i16::MAX as i128), DataType::Int32 => (i32::MIN as i128, i32::MAX as i128), DataType::Int64 => (i64::MIN as i128, i64::MAX as i128), + DataType::Decimal128(precision, _) => ( + MIN_DECIMAL_FOR_EACH_PRECISION[*precision - 1], + MAX_DECIMAL_FOR_EACH_PRECISION[*precision - 1], + ), other_type => { return Err(DataFusionError::Internal(format!( "Error target data type {:?}", other_type - ))) + ))); } }; - let lit_value = match integer_lit_value { - ScalarValue::Int8(Some(v)) => *v as i128, - ScalarValue::Int16(Some(v)) => *v as i128, - ScalarValue::Int32(Some(v)) => *v as i128, - ScalarValue::Int64(Some(v)) => *v as i128, + let lit_value_target_type = match lit_value { + ScalarValue::Int8(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::Int16(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::Int32(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::Int64(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::Decimal128(Some(v), _, scale) => { + let lit_scale_mul = 10_i128.pow(*scale as u32); + if mul >= lit_scale_mul { + // Example: + // lit is decimal(123,3,2) + // target type is decimal(5,3) + // the lit can be converted to the decimal(1230,5,3) + (*v).checked_mul(mul / lit_scale_mul) + } else if (*v) % (lit_scale_mul / mul) == 0 { + // Example: + // lit is decimal(123000,10,3) + // target type is int32: the lit can be converted to INT32(123) + // target type is decimal(10,2): the lit can be converted to decimal(12300,10,2) + Some(*v / (lit_scale_mul / mul)) + } else { + // can't convert the lit decimal to the target data type + None + } + } other_value => { return Err(DataFusionError::Internal(format!( "Invalid literal value {:?}", other_value - ))) + ))); } }; - Ok(lit_value >= target_min && lit_value <= target_max) + match lit_value_target_type { + None => Ok((false, None)), + Some(value) => { + match value >= target_min && value <= target_max { + // the value casted from lit to the target type is in the range of target type. + // return the target type of scalar value + true => { + let result_scalar = match target_type { + DataType::Int8 => ScalarValue::Int8(Some(value as i8)), + DataType::Int16 => ScalarValue::Int16(Some(value as i16)), + DataType::Int32 => ScalarValue::Int32(Some(value as i32)), + DataType::Int64 => ScalarValue::Int64(Some(value as i64)), + DataType::Decimal128(p, s) => { + ScalarValue::Decimal128(Some(value), *p, *s) + } + other_type => { + return Err(DataFusionError::Internal(format!( + "Error target data type {:?}", + other_type + ))); + } + }; + Ok((true, Some(result_scalar))) + } + false => Ok((false, None)), + } + } + } } #[cfg(test)] @@ -292,6 +315,67 @@ mod tests { assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected); } + #[test] + fn test_not_cast_with_decimal_lit_comparison() { + let schema = expr_test_schema(); + // integer to decimal: value is out of the bounds of the decimal + // c3 = INT64(100000000000000000) + let expr_eq = col("c3").eq(lit(ScalarValue::Int64(Some(100000000000000000)))); + let expected = col("c3").eq(lit(ScalarValue::Int64(Some(100000000000000000)))); + assert_eq!(optimize_test(expr_eq, &schema), expected); + // c4 = INT64(1000) will overflow the i128 + let expr_eq = col("c4").eq(lit(ScalarValue::Int64(Some(1000)))); + let expected = col("c4").eq(lit(ScalarValue::Int64(Some(1000)))); + assert_eq!(optimize_test(expr_eq, &schema), expected); + + // decimal to decimal: value will lose the scale when convert to the target data type + // c3 = DECIMAL(12340,20,4) + let expr_eq = col("c3").eq(lit(ScalarValue::Decimal128(Some(12340), 20, 4))); + let expected = col("c3").eq(lit(ScalarValue::Decimal128(Some(12340), 20, 4))); + assert_eq!(optimize_test(expr_eq, &schema), expected); + + // decimal to integer + // c1 = DECIMAL(123, 10, 1): value will lose the scale when convert to the target data type + let expr_eq = col("c1").eq(lit(ScalarValue::Decimal128(Some(123), 10, 1))); + let expected = col("c1").eq(lit(ScalarValue::Decimal128(Some(123), 10, 1))); + assert_eq!(optimize_test(expr_eq, &schema), expected); + // c1 = DECIMAL(1230, 10, 2): value will lose the scale when convert to the target data type + let expr_eq = col("c1").eq(lit(ScalarValue::Decimal128(Some(1230), 10, 2))); + let expected = col("c1").eq(lit(ScalarValue::Decimal128(Some(1230), 10, 2))); + assert_eq!(optimize_test(expr_eq, &schema), expected); + } + + #[test] + fn test_pre_cast_with_decimal_lit_comparison() { + let schema = expr_test_schema(); + // integer to decimal + // c3 < INT64(16) -> c3 < (CAST(INT64(16) AS DECIMAL(18,2)); + let expr_lt = col("c3").lt(lit(ScalarValue::Int64(Some(16)))); + let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(1600), 18, 2))); + assert_eq!(optimize_test(expr_lt, &schema), expected); + + // c3 < INT64(NULL) + let c1_lt_lit_null = col("c3").lt(lit(ScalarValue::Int64(None))); + let expected = col("c3").lt(lit(ScalarValue::Decimal128(None, 18, 2))); + assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected); + + // decimal to decimal + // c3 < Decimal(123,10,0) -> c3 < CAST(DECIMAL(123,10,0) AS DECIMAL(18,2)) -> c3 < DECIMAL(12300,18,2) + let expr_lt = col("c3").lt(lit(ScalarValue::Decimal128(Some(123), 10, 0))); + let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(12300), 18, 2))); + assert_eq!(optimize_test(expr_lt, &schema), expected); + // c3 < Decimal(1230,10,3) -> c3 < CAST(DECIMAL(1230,10,3) AS DECIMAL(18,2)) -> c3 < DECIMAL(123,18,2) + let expr_lt = col("c3").lt(lit(ScalarValue::Decimal128(Some(1230), 10, 3))); + let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(123), 18, 2))); + assert_eq!(optimize_test(expr_lt, &schema), expected); + + // decimal to integer + // c1 < Decimal(12300, 10, 2) -> c1 < CAST(DECIMAL(12300,10,2) AS INT32) -> c1 < INT32(123) + let expr_lt = col("c1").lt(lit(ScalarValue::Decimal128(Some(12300), 10, 2))); + let expected = col("c1").lt(lit(ScalarValue::Int32(Some(123)))); + assert_eq!(optimize_test(expr_lt, &schema), expected); + } + fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr { visit_expr(expr, schema).unwrap() } @@ -302,6 +386,8 @@ mod tests { vec![ DFField::new(None, "c1", DataType::Int32, false), DFField::new(None, "c2", DataType::Int64, false), + DFField::new(None, "c3", DataType::Decimal128(18, 2), false), + DFField::new(None, "c4", DataType::Decimal128(38, 37), false), ], HashMap::new(), ) From 6a82273fc01d6dca01fba63378b08889644c297a Mon Sep 17 00:00:00 2001 From: liukun4515 Date: Thu, 25 Aug 2022 10:26:52 +0800 Subject: [PATCH 2/4] address comments --- .../src/pre_cast_lit_in_comparison.rs | 65 ++++++++----------- 1 file changed, 28 insertions(+), 37 deletions(-) diff --git a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs index b27c2f7c6b27..f625edfca413 100644 --- a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs +++ b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs @@ -111,25 +111,17 @@ fn visit_expr(expr: Expr, schema: &DFSchemaRef) -> Result { // do nothing } (Expr::Literal(left_lit_value), _) => { - let (can_cast, casted_scalar_value) = + let casted_scalar_value = try_cast_literal_to_type(left_lit_value, &right_type)?; - if can_cast { - return Ok(binary_expr( - lit(casted_scalar_value.unwrap()), - *op, - right, - )); + if let Some(value) = casted_scalar_value { + return Ok(binary_expr(lit(value), *op, right)); } } (_, Expr::Literal(right_lit_value)) => { - let (can_cast, casted_scalar_value) = + let casted_scalar_value = try_cast_literal_to_type(right_lit_value, &left_type)?; - if can_cast { - return Ok(binary_expr( - left, - *op, - lit(casted_scalar_value.unwrap()), - )); + if let Some(value) = casted_scalar_value { + return Ok(binary_expr(left, *op, lit(value))); } } (_, _) => { @@ -173,10 +165,10 @@ fn is_support_data_type(data_type: &DataType) -> bool { fn try_cast_literal_to_type( lit_value: &ScalarValue, target_type: &DataType, -) -> Result<(bool, Option)> { +) -> Result> { if lit_value.is_null() { // null value can be cast to any type of null value - return Ok((true, Some(ScalarValue::try_from(target_type)?))); + return Ok(Some(ScalarValue::try_from(target_type)?)); } let mul = match target_type { DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => 1_i128, @@ -237,30 +229,29 @@ fn try_cast_literal_to_type( }; match lit_value_target_type { - None => Ok((false, None)), + None => Ok(None), Some(value) => { - match value >= target_min && value <= target_max { + if value >= target_min && value <= target_max { // the value casted from lit to the target type is in the range of target type. // return the target type of scalar value - true => { - let result_scalar = match target_type { - DataType::Int8 => ScalarValue::Int8(Some(value as i8)), - DataType::Int16 => ScalarValue::Int16(Some(value as i16)), - DataType::Int32 => ScalarValue::Int32(Some(value as i32)), - DataType::Int64 => ScalarValue::Int64(Some(value as i64)), - DataType::Decimal128(p, s) => { - ScalarValue::Decimal128(Some(value), *p, *s) - } - other_type => { - return Err(DataFusionError::Internal(format!( - "Error target data type {:?}", - other_type - ))); - } - }; - Ok((true, Some(result_scalar))) - } - false => Ok((false, None)), + let result_scalar = match target_type { + DataType::Int8 => ScalarValue::Int8(Some(value as i8)), + DataType::Int16 => ScalarValue::Int16(Some(value as i16)), + DataType::Int32 => ScalarValue::Int32(Some(value as i32)), + DataType::Int64 => ScalarValue::Int64(Some(value as i64)), + DataType::Decimal128(p, s) => { + ScalarValue::Decimal128(Some(value), *p, *s) + } + other_type => { + return Err(DataFusionError::Internal(format!( + "Error target data type {:?}", + other_type + ))); + } + }; + Ok(Some(result_scalar)) + } else { + Ok(None) } } } From 6170817466d4a5c7a82a46eed4295fcf3ae6e7eb Mon Sep 17 00:00:00 2001 From: liukun4515 Date: Fri, 26 Aug 2022 14:01:56 +0800 Subject: [PATCH 3/4] fix the lint --- .../src/pre_cast_lit_in_comparison.rs | 47 +------------------ 1 file changed, 1 insertion(+), 46 deletions(-) diff --git a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs index b69c26372916..aea16480c5dc 100644 --- a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs +++ b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs @@ -91,51 +91,6 @@ fn optimize(plan: &LogicalPlan) -> Result { from_plan(plan, new_exprs.as_slice(), new_inputs.as_slice()) } -// <<<<<<< HEAD -// // Visit all type of expr, if the current has child expr, the child expr needed to visit first. -// fn visit_expr(expr: Expr, schema: &DFSchemaRef) -> Result { -// // traverse the expr by dfs -// match &expr { -// Expr::BinaryExpr { left, op, right } => { -// // dfs visit the left and right expr -// let left = visit_expr(*left.clone(), schema)?; -// let right = visit_expr(*right.clone(), schema)?; -// let left_type = left.get_type(schema); -// let right_type = right.get_type(schema); -// // can't get the data type, just return the expr -// if left_type.is_err() || right_type.is_err() { -// return Ok(expr.clone()); -// } -// let left_type = left_type?; -// let right_type = right_type?; -// if !left_type.eq(&right_type) -// && is_support_data_type(&left_type) -// && is_support_data_type(&right_type) -// && is_comparison_op(op) -// { -// match (&left, &right) { -// (Expr::Literal(_), Expr::Literal(_)) => { -// // do nothing -// } -// (Expr::Literal(left_lit_value), _) => { -// let casted_scalar_value = -// try_cast_literal_to_type(left_lit_value, &right_type)?; -// if let Some(value) = casted_scalar_value { -// return Ok(binary_expr(lit(value), *op, right)); -// } -// } -// (_, Expr::Literal(right_lit_value)) => { -// let casted_scalar_value = -// try_cast_literal_to_type(right_lit_value, &left_type)?; -// if let Some(value) = casted_scalar_value { -// return Ok(binary_expr(left, *op, lit(value))); -// } -// } -// (_, _) => { -// // do nothing -// } -// }; -// ======= struct PreCastLitExprRewriter { schema: DFSchemaRef, } @@ -468,7 +423,7 @@ mod tests { ], HashMap::new(), ) - .unwrap(), + .unwrap(), ) } } From 3c64aef34e2c7118c7ceafdb523d0ed8f3cd8f90 Mon Sep 17 00:00:00 2001 From: liukun4515 Date: Sat, 27 Aug 2022 11:38:06 +0800 Subject: [PATCH 4/4] add comments --- datafusion/optimizer/src/pre_cast_lit_in_comparison.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs index aea16480c5dc..6e89afd600be 100644 --- a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs +++ b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs @@ -199,6 +199,9 @@ fn try_cast_literal_to_type( DataType::Int32 => (i32::MIN as i128, i32::MAX as i128), DataType::Int64 => (i64::MIN as i128, i64::MAX as i128), DataType::Decimal128(precision, _) => ( + // Different precision for decimal128 can store different range of value. + // For example, the precision is 3, the max of value is `999` and the min + // value is `-999` MIN_DECIMAL_FOR_EACH_PRECISION[*precision - 1], MAX_DECIMAL_FOR_EACH_PRECISION[*precision - 1], ),