Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support decimal data type for the optimizer rule of PreCastLitInComparisonExpressions #3245

Merged
merged 5 commits into from
Aug 27, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
237 changes: 158 additions & 79 deletions datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
//! Pre-cast literal binary comparison rule can be only used to the binary comparison expr.
//! It can reduce adding the `Expr::Cast` to the expr instead of adding the `Expr::Cast` to literal expr.
use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::DataType;
use arrow::datatypes::{
DataType, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION,
};
use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue};
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion};
use datafusion_expr::utils::from_plan;
Expand Down Expand Up @@ -99,7 +101,6 @@ impl ExprRewriter for PreCastLitExprRewriter {
}

fn mutate(&mut self, expr: Expr) -> Result<Expr> {
// traverse the expr by dfs
match &expr {
Expr::BinaryExpr { left, op, right } => {
let left = left.as_ref().clone();
Expand All @@ -121,32 +122,19 @@ impl ExprRewriter for PreCastLitExprRewriter {
(Expr::Literal(_), Expr::Literal(_)) => {
// do nothing
}
(Expr::Literal(left_lit_value), _)
if can_integer_literal_cast_to_type(
left_lit_value,
&right_type,
)? =>
{
// cast the left literal to the right type
return Ok(binary_expr(
cast_to_other_scalar_expr(left_lit_value, &right_type)?,
*op,
right,
));
(Expr::Literal(left_lit_value), _) => {
let casted_scalar_value =
try_cast_literal_to_type(left_lit_value, &right_type)?;
if let Some(value) = casted_scalar_value {
return Ok(binary_expr(lit(value), *op, right));
}
}
(_, Expr::Literal(right_lit_value))
if can_integer_literal_cast_to_type(
right_lit_value,
&left_type,
)
.unwrap() =>
{
// cast the right literal to the left type
return Ok(binary_expr(
left,
*op,
cast_to_other_scalar_expr(right_lit_value, &left_type)?,
));
(_, Expr::Literal(right_lit_value)) => {
let casted_scalar_value =
try_cast_literal_to_type(right_lit_value, &left_type)?;
if let Some(value) = casted_scalar_value {
return Ok(binary_expr(left, *op, lit(value)));
}
}
(_, _) => {
// do nothing
Expand All @@ -164,43 +152,6 @@ impl ExprRewriter for PreCastLitExprRewriter {
}
}

fn cast_to_other_scalar_expr(
origin_value: &ScalarValue,
target_type: &DataType,
) -> Result<Expr> {
// null case
if origin_value.is_null() {
// if the origin value is null, just convert to another type of null value
// The target type must be satisfied `is_support_data_type` method, we can unwrap safely
return Ok(lit(ScalarValue::try_from(target_type).unwrap()));
}
// no null case
let value: i64 = match origin_value {
ScalarValue::Int8(Some(v)) => *v as i64,
ScalarValue::Int16(Some(v)) => *v as i64,
ScalarValue::Int32(Some(v)) => *v as i64,
ScalarValue::Int64(Some(v)) => *v as i64,
other_value => {
return Err(DataFusionError::Internal(format!(
"Invalid type and value {}",
other_value
)))
}
};
Ok(lit(match target_type {
DataType::Int8 => ScalarValue::Int8(Some(value as i8)),
DataType::Int16 => ScalarValue::Int16(Some(value as i16)),
DataType::Int32 => ScalarValue::Int32(Some(value as i32)),
DataType::Int64 => ScalarValue::Int64(Some(value)),
other_type => {
return Err(DataFusionError::Internal(format!(
"Invalid target data type {:?}",
other_type
)))
}
}))
}

fn is_comparison_op(op: &Operator) -> bool {
matches!(
op,
Expand All @@ -214,47 +165,112 @@ fn is_comparison_op(op: &Operator) -> bool {
}

fn is_support_data_type(data_type: &DataType) -> bool {
// TODO support decimal with other data type
matches!(
data_type,
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64
DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Decimal128(_, _)
)
}

fn can_integer_literal_cast_to_type(
integer_lit_value: &ScalarValue,
fn try_cast_literal_to_type(
lit_value: &ScalarValue,
target_type: &DataType,
) -> Result<bool> {
if integer_lit_value.is_null() {
) -> Result<Option<ScalarValue>> {
if lit_value.is_null() {
// null value can be cast to any type of null value
return Ok(true);
return Ok(Some(ScalarValue::try_from(target_type)?));
}
let mul = match target_type {
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => 1_i128,
DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32),
other_type => {
return Err(DataFusionError::Internal(format!(
"Error target data type {:?}",
other_type
)));
}
};
let (target_min, target_max) = match target_type {
DataType::Int8 => (i8::MIN as i128, i8::MAX as i128),
DataType::Int16 => (i16::MIN as i128, i16::MAX as i128),
DataType::Int32 => (i32::MIN as i128, i32::MAX as i128),
DataType::Int64 => (i64::MIN as i128, i64::MAX as i128),
DataType::Decimal128(precision, _) => (
// Different precision for decimal128 can store different range of value.
// For example, the precision is 3, the max of value is `999` and the min
// value is `-999`
MIN_DECIMAL_FOR_EACH_PRECISION[*precision - 1],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a comment here to explain what is going on?

MAX_DECIMAL_FOR_EACH_PRECISION[*precision - 1],
),
other_type => {
return Err(DataFusionError::Internal(format!(
"Error target data type {:?}",
other_type
)))
)));
}
};
let lit_value = match integer_lit_value {
ScalarValue::Int8(Some(v)) => *v as i128,
ScalarValue::Int16(Some(v)) => *v as i128,
ScalarValue::Int32(Some(v)) => *v as i128,
ScalarValue::Int64(Some(v)) => *v as i128,
let lit_value_target_type = match lit_value {
ScalarValue::Int8(Some(v)) => (*v as i128).checked_mul(mul),
ScalarValue::Int16(Some(v)) => (*v as i128).checked_mul(mul),
ScalarValue::Int32(Some(v)) => (*v as i128).checked_mul(mul),
ScalarValue::Int64(Some(v)) => (*v as i128).checked_mul(mul),
ScalarValue::Decimal128(Some(v), _, scale) => {
let lit_scale_mul = 10_i128.pow(*scale as u32);
if mul >= lit_scale_mul {
// Example:
// lit is decimal(123,3,2)
// target type is decimal(5,3)
// the lit can be converted to the decimal(1230,5,3)
(*v).checked_mul(mul / lit_scale_mul)
} else if (*v) % (lit_scale_mul / mul) == 0 {
// Example:
// lit is decimal(123000,10,3)
// target type is int32: the lit can be converted to INT32(123)
// target type is decimal(10,2): the lit can be converted to decimal(12300,10,2)
Some(*v / (lit_scale_mul / mul))
} else {
// can't convert the lit decimal to the target data type
None
}
}
other_value => {
return Err(DataFusionError::Internal(format!(
"Invalid literal value {:?}",
other_value
)))
)));
}
};

Ok(lit_value >= target_min && lit_value <= target_max)
match lit_value_target_type {
None => Ok(None),
Some(value) => {
if value >= target_min && value <= target_max {
// the value casted from lit to the target type is in the range of target type.
// return the target type of scalar value
let result_scalar = match target_type {
DataType::Int8 => ScalarValue::Int8(Some(value as i8)),
DataType::Int16 => ScalarValue::Int16(Some(value as i16)),
DataType::Int32 => ScalarValue::Int32(Some(value as i32)),
DataType::Int64 => ScalarValue::Int64(Some(value as i64)),
DataType::Decimal128(p, s) => {
ScalarValue::Decimal128(Some(value), *p, *s)
}
other_type => {
return Err(DataFusionError::Internal(format!(
"Error target data type {:?}",
other_type
)));
}
};
Ok(Some(result_scalar))
} else {
Ok(None)
}
}
}
}

#[cfg(test)]
Expand Down Expand Up @@ -307,6 +323,67 @@ mod tests {
assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected);
}

#[test]
fn test_not_cast_with_decimal_lit_comparison() {
let schema = expr_test_schema();
// integer to decimal: value is out of the bounds of the decimal
// c3 = INT64(100000000000000000)
let expr_eq = col("c3").eq(lit(ScalarValue::Int64(Some(100000000000000000))));
let expected = col("c3").eq(lit(ScalarValue::Int64(Some(100000000000000000))));
assert_eq!(optimize_test(expr_eq, &schema), expected);
// c4 = INT64(1000) will overflow the i128
let expr_eq = col("c4").eq(lit(ScalarValue::Int64(Some(1000))));
let expected = col("c4").eq(lit(ScalarValue::Int64(Some(1000))));
assert_eq!(optimize_test(expr_eq, &schema), expected);

// decimal to decimal: value will lose the scale when convert to the target data type
// c3 = DECIMAL(12340,20,4)
let expr_eq = col("c3").eq(lit(ScalarValue::Decimal128(Some(12340), 20, 4)));
let expected = col("c3").eq(lit(ScalarValue::Decimal128(Some(12340), 20, 4)));
assert_eq!(optimize_test(expr_eq, &schema), expected);

// decimal to integer
// c1 = DECIMAL(123, 10, 1): value will lose the scale when convert to the target data type
let expr_eq = col("c1").eq(lit(ScalarValue::Decimal128(Some(123), 10, 1)));
let expected = col("c1").eq(lit(ScalarValue::Decimal128(Some(123), 10, 1)));
assert_eq!(optimize_test(expr_eq, &schema), expected);
// c1 = DECIMAL(1230, 10, 2): value will lose the scale when convert to the target data type
let expr_eq = col("c1").eq(lit(ScalarValue::Decimal128(Some(1230), 10, 2)));
let expected = col("c1").eq(lit(ScalarValue::Decimal128(Some(1230), 10, 2)));
assert_eq!(optimize_test(expr_eq, &schema), expected);
}

#[test]
fn test_pre_cast_with_decimal_lit_comparison() {
let schema = expr_test_schema();
// integer to decimal
// c3 < INT64(16) -> c3 < (CAST(INT64(16) AS DECIMAL(18,2));
let expr_lt = col("c3").lt(lit(ScalarValue::Int64(Some(16))));
let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(1600), 18, 2)));
assert_eq!(optimize_test(expr_lt, &schema), expected);

// c3 < INT64(NULL)
let c1_lt_lit_null = col("c3").lt(lit(ScalarValue::Int64(None)));
let expected = col("c3").lt(lit(ScalarValue::Decimal128(None, 18, 2)));
assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected);

// decimal to decimal
// c3 < Decimal(123,10,0) -> c3 < CAST(DECIMAL(123,10,0) AS DECIMAL(18,2)) -> c3 < DECIMAL(12300,18,2)
let expr_lt = col("c3").lt(lit(ScalarValue::Decimal128(Some(123), 10, 0)));
let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(12300), 18, 2)));
assert_eq!(optimize_test(expr_lt, &schema), expected);
// c3 < Decimal(1230,10,3) -> c3 < CAST(DECIMAL(1230,10,3) AS DECIMAL(18,2)) -> c3 < DECIMAL(123,18,2)
let expr_lt = col("c3").lt(lit(ScalarValue::Decimal128(Some(1230), 10, 3)));
let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(123), 18, 2)));
assert_eq!(optimize_test(expr_lt, &schema), expected);

// decimal to integer
// c1 < Decimal(12300, 10, 2) -> c1 < CAST(DECIMAL(12300,10,2) AS INT32) -> c1 < INT32(123)
let expr_lt = col("c1").lt(lit(ScalarValue::Decimal128(Some(12300), 10, 2)));
let expected = col("c1").lt(lit(ScalarValue::Int32(Some(123))));
assert_eq!(optimize_test(expr_lt, &schema), expected);
}

#[test]
fn aliased() {
let schema = expr_test_schema();
Expand Down Expand Up @@ -344,6 +421,8 @@ mod tests {
vec![
DFField::new(None, "c1", DataType::Int32, false),
DFField::new(None, "c2", DataType::Int64, false),
DFField::new(None, "c3", DataType::Decimal128(18, 2), false),
DFField::new(None, "c4", DataType::Decimal128(38, 37), false),
],
HashMap::new(),
)
Expand Down