Skip to content

Commit

Permalink
support decimal for the PreCastLitInComparisonExpressions rule
Browse files Browse the repository at this point in the history
  • Loading branch information
liukun4515 committed Aug 24, 2022
1 parent c574269 commit 51029d5
Showing 1 changed file with 166 additions and 80 deletions.
246 changes: 166 additions & 80 deletions datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
//! Pre-cast literal binary comparison rule can be only used to the binary comparison expr.
//! It can reduce adding the `Expr::Cast` to the expr instead of adding the `Expr::Cast` to literal expr.
use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::DataType;
use arrow::datatypes::{
DataType, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION,
};
use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue};
use datafusion_expr::utils::from_plan;
use datafusion_expr::{binary_expr, lit, Expr, ExprSchemable, LogicalPlan, Operator};
Expand Down Expand Up @@ -97,8 +99,8 @@ fn visit_expr(expr: Expr, schema: &DFSchemaRef) -> Result<Expr> {
if left_type.is_err() || right_type.is_err() {
return Ok(expr.clone());
}
let left_type = left_type.unwrap();
let right_type = right_type.unwrap();
let left_type = left_type?;
let right_type = right_type?;
if !left_type.eq(&right_type)
&& is_support_data_type(&left_type)
&& is_support_data_type(&right_type)
Expand All @@ -108,32 +110,27 @@ fn visit_expr(expr: Expr, schema: &DFSchemaRef) -> Result<Expr> {
(Expr::Literal(_), Expr::Literal(_)) => {
// do nothing
}
(Expr::Literal(left_lit_value), _)
if can_integer_literal_cast_to_type(
left_lit_value,
&right_type,
)? =>
{
// cast the left literal to the right type
return Ok(binary_expr(
cast_to_other_scalar_expr(left_lit_value, &right_type)?,
*op,
right,
));
(Expr::Literal(left_lit_value), _) => {
let (can_cast, casted_scalar_value) =
try_cast_literal_to_type(left_lit_value, &right_type)?;
if can_cast {
return Ok(binary_expr(
lit(casted_scalar_value.unwrap()),
*op,
right,
));
}
}
(_, Expr::Literal(right_lit_value))
if can_integer_literal_cast_to_type(
right_lit_value,
&left_type,
)
.unwrap() =>
{
// cast the right literal to the left type
return Ok(binary_expr(
left,
*op,
cast_to_other_scalar_expr(right_lit_value, &left_type)?,
));
(_, Expr::Literal(right_lit_value)) => {
let (can_cast, casted_scalar_value) =
try_cast_literal_to_type(right_lit_value, &left_type)?;
if can_cast {
return Ok(binary_expr(
left,
*op,
lit(casted_scalar_value.unwrap()),
));
}
}
(_, _) => {
// do nothing
Expand All @@ -150,43 +147,6 @@ fn visit_expr(expr: Expr, schema: &DFSchemaRef) -> Result<Expr> {
}
}

fn cast_to_other_scalar_expr(
origin_value: &ScalarValue,
target_type: &DataType,
) -> Result<Expr> {
// null case
if origin_value.is_null() {
// if the origin value is null, just convert to another type of null value
// The target type must be satisfied `is_support_data_type` method, we can unwrap safely
return Ok(lit(ScalarValue::try_from(target_type).unwrap()));
}
// no null case
let value: i64 = match origin_value {
ScalarValue::Int8(Some(v)) => *v as i64,
ScalarValue::Int16(Some(v)) => *v as i64,
ScalarValue::Int32(Some(v)) => *v as i64,
ScalarValue::Int64(Some(v)) => *v as i64,
other_value => {
return Err(DataFusionError::Internal(format!(
"Invalid type and value {}",
other_value
)))
}
};
Ok(lit(match target_type {
DataType::Int8 => ScalarValue::Int8(Some(value as i8)),
DataType::Int16 => ScalarValue::Int16(Some(value as i16)),
DataType::Int32 => ScalarValue::Int32(Some(value as i32)),
DataType::Int64 => ScalarValue::Int64(Some(value)),
other_type => {
return Err(DataFusionError::Internal(format!(
"Invalid target data type {:?}",
other_type
)))
}
}))
}

fn is_comparison_op(op: &Operator) -> bool {
matches!(
op,
Expand All @@ -200,47 +160,110 @@ fn is_comparison_op(op: &Operator) -> bool {
}

fn is_support_data_type(data_type: &DataType) -> bool {
// TODO support decimal with other data type
matches!(
data_type,
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64
DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Decimal128(_, _)
)
}

fn can_integer_literal_cast_to_type(
integer_lit_value: &ScalarValue,
fn try_cast_literal_to_type(
lit_value: &ScalarValue,
target_type: &DataType,
) -> Result<bool> {
if integer_lit_value.is_null() {
) -> Result<(bool, Option<ScalarValue>)> {
if lit_value.is_null() {
// null value can be cast to any type of null value
return Ok(true);
return Ok((true, Some(ScalarValue::try_from(target_type)?)));
}
let mul = match target_type {
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => 1_i128,
DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32),
other_type => {
return Err(DataFusionError::Internal(format!(
"Error target data type {:?}",
other_type
)));
}
};
let (target_min, target_max) = match target_type {
DataType::Int8 => (i8::MIN as i128, i8::MAX as i128),
DataType::Int16 => (i16::MIN as i128, i16::MAX as i128),
DataType::Int32 => (i32::MIN as i128, i32::MAX as i128),
DataType::Int64 => (i64::MIN as i128, i64::MAX as i128),
DataType::Decimal128(precision, _) => (
MIN_DECIMAL_FOR_EACH_PRECISION[*precision - 1],
MAX_DECIMAL_FOR_EACH_PRECISION[*precision - 1],
),
other_type => {
return Err(DataFusionError::Internal(format!(
"Error target data type {:?}",
other_type
)))
)));
}
};
let lit_value = match integer_lit_value {
ScalarValue::Int8(Some(v)) => *v as i128,
ScalarValue::Int16(Some(v)) => *v as i128,
ScalarValue::Int32(Some(v)) => *v as i128,
ScalarValue::Int64(Some(v)) => *v as i128,
let lit_value_target_type = match lit_value {
ScalarValue::Int8(Some(v)) => (*v as i128).checked_mul(mul),
ScalarValue::Int16(Some(v)) => (*v as i128).checked_mul(mul),
ScalarValue::Int32(Some(v)) => (*v as i128).checked_mul(mul),
ScalarValue::Int64(Some(v)) => (*v as i128).checked_mul(mul),
ScalarValue::Decimal128(Some(v), _, scale) => {
let lit_scale_mul = 10_i128.pow(*scale as u32);
if mul >= lit_scale_mul {
// Example:
// lit is decimal(123,3,2)
// target type is decimal(5,3)
// the lit can be converted to the decimal(1230,5,3)
(*v).checked_mul(mul / lit_scale_mul)
} else if (*v) % (lit_scale_mul / mul) == 0 {
// Example:
// lit is decimal(123000,10,3)
// target type is int32: the lit can be converted to INT32(123)
// target type is decimal(10,2): the lit can be converted to decimal(12300,10,2)
Some(*v / (lit_scale_mul / mul))
} else {
// can't convert the lit decimal to the target data type
None
}
}
other_value => {
return Err(DataFusionError::Internal(format!(
"Invalid literal value {:?}",
other_value
)))
)));
}
};

Ok(lit_value >= target_min && lit_value <= target_max)
match lit_value_target_type {
None => Ok((false, None)),
Some(value) => {
match value >= target_min && value <= target_max {
// the value casted from lit to the target type is in the range of target type.
// return the target type of scalar value
true => {
let result_scalar = match target_type {
DataType::Int8 => ScalarValue::Int8(Some(value as i8)),
DataType::Int16 => ScalarValue::Int16(Some(value as i16)),
DataType::Int32 => ScalarValue::Int32(Some(value as i32)),
DataType::Int64 => ScalarValue::Int64(Some(value as i64)),
DataType::Decimal128(p, s) => {
ScalarValue::Decimal128(Some(value), *p, *s)
}
other_type => {
return Err(DataFusionError::Internal(format!(
"Error target data type {:?}",
other_type
)));
}
};
Ok((true, Some(result_scalar)))
}
false => Ok((false, None)),
}
}
}
}

#[cfg(test)]
Expand Down Expand Up @@ -292,6 +315,67 @@ mod tests {
assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected);
}

#[test]
fn test_not_cast_with_decimal_lit_comparison() {
let schema = expr_test_schema();
// integer to decimal: value is out of the bounds of the decimal
// c3 = INT64(100000000000000000)
let expr_eq = col("c3").eq(lit(ScalarValue::Int64(Some(100000000000000000))));
let expected = col("c3").eq(lit(ScalarValue::Int64(Some(100000000000000000))));
assert_eq!(optimize_test(expr_eq, &schema), expected);
// c4 = INT64(1000) will overflow the i128
let expr_eq = col("c4").eq(lit(ScalarValue::Int64(Some(1000))));
let expected = col("c4").eq(lit(ScalarValue::Int64(Some(1000))));
assert_eq!(optimize_test(expr_eq, &schema), expected);

// decimal to decimal: value will lose the scale when convert to the target data type
// c3 = DECIMAL(12340,20,4)
let expr_eq = col("c3").eq(lit(ScalarValue::Decimal128(Some(12340), 20, 4)));
let expected = col("c3").eq(lit(ScalarValue::Decimal128(Some(12340), 20, 4)));
assert_eq!(optimize_test(expr_eq, &schema), expected);

// decimal to integer
// c1 = DECIMAL(123, 10, 1): value will lose the scale when convert to the target data type
let expr_eq = col("c1").eq(lit(ScalarValue::Decimal128(Some(123), 10, 1)));
let expected = col("c1").eq(lit(ScalarValue::Decimal128(Some(123), 10, 1)));
assert_eq!(optimize_test(expr_eq, &schema), expected);
// c1 = DECIMAL(1230, 10, 2): value will lose the scale when convert to the target data type
let expr_eq = col("c1").eq(lit(ScalarValue::Decimal128(Some(1230), 10, 2)));
let expected = col("c1").eq(lit(ScalarValue::Decimal128(Some(1230), 10, 2)));
assert_eq!(optimize_test(expr_eq, &schema), expected);
}

#[test]
fn test_pre_cast_with_decimal_lit_comparison() {
let schema = expr_test_schema();
// integer to decimal
// c3 < INT64(16) -> c3 < (CAST(INT64(16) AS DECIMAL(18,2));
let expr_lt = col("c3").lt(lit(ScalarValue::Int64(Some(16))));
let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(1600), 18, 2)));
assert_eq!(optimize_test(expr_lt, &schema), expected);

// c3 < INT64(NULL)
let c1_lt_lit_null = col("c3").lt(lit(ScalarValue::Int64(None)));
let expected = col("c3").lt(lit(ScalarValue::Decimal128(None, 18, 2)));
assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected);

// decimal to decimal
// c3 < Decimal(123,10,0) -> c3 < CAST(DECIMAL(123,10,0) AS DECIMAL(18,2)) -> c3 < DECIMAL(12300,18,2)
let expr_lt = col("c3").lt(lit(ScalarValue::Decimal128(Some(123), 10, 0)));
let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(12300), 18, 2)));
assert_eq!(optimize_test(expr_lt, &schema), expected);
// c3 < Decimal(1230,10,3) -> c3 < CAST(DECIMAL(1230,10,3) AS DECIMAL(18,2)) -> c3 < DECIMAL(123,18,2)
let expr_lt = col("c3").lt(lit(ScalarValue::Decimal128(Some(1230), 10, 3)));
let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(123), 18, 2)));
assert_eq!(optimize_test(expr_lt, &schema), expected);

// decimal to integer
// c1 < Decimal(12300, 10, 2) -> c1 < CAST(DECIMAL(12300,10,2) AS INT32) -> c1 < INT32(123)
let expr_lt = col("c1").lt(lit(ScalarValue::Decimal128(Some(12300), 10, 2)));
let expected = col("c1").lt(lit(ScalarValue::Int32(Some(123))));
assert_eq!(optimize_test(expr_lt, &schema), expected);
}

fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr {
visit_expr(expr, schema).unwrap()
}
Expand All @@ -302,6 +386,8 @@ mod tests {
vec![
DFField::new(None, "c1", DataType::Int32, false),
DFField::new(None, "c2", DataType::Int64, false),
DFField::new(None, "c3", DataType::Decimal128(18, 2), false),
DFField::new(None, "c4", DataType::Decimal128(38, 37), false),
],
HashMap::new(),
)
Expand Down

0 comments on commit 51029d5

Please sign in to comment.