Skip to content

Commit

Permalink
support inlist for pre cast literal expression (#3270)
Browse files Browse the repository at this point in the history
* support decimal for the PreCastLitInComparisonExpressions rule

* address comments

* support list
  • Loading branch information
liukun4515 authored Aug 30, 2022
1 parent 3effee8 commit a4e74c0
Showing 1 changed file with 178 additions and 3 deletions.
181 changes: 178 additions & 3 deletions datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ use arrow::datatypes::{
use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue};
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion};
use datafusion_expr::utils::from_plan;
use datafusion_expr::{binary_expr, lit, Expr, ExprSchemable, LogicalPlan, Operator};
use datafusion_expr::{
binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator,
};

/// The rule can be only used to the numeric binary comparison with literal expr, like below pattern:
/// `left_expr comparison_op literal_expr` or `literal_expr comparison_op right_expr`.
Expand Down Expand Up @@ -144,8 +146,57 @@ impl ExprRewriter for PreCastLitExprRewriter {
// return the new binary op
Ok(binary_expr(left, *op, right))
}
// TODO: optimize in list
// Expr::InList { .. } => {}
Expr::InList {
expr: left_expr,
list,
negated,
} => {
let left = left_expr.as_ref().clone();
let left_type = left.get_type(&self.schema);
if left_type.is_err() {
// error data type
return Ok(expr);
}
let left_type = left_type?;
if !is_support_data_type(&left_type) {
// not supported data type
return Ok(expr);
}
let right_exprs = list
.iter()
.map(|right| {
let right_type = right.get_type(&self.schema)?;
if !is_support_data_type(&right_type) {
return Err(DataFusionError::Internal(format!(
"The type of list expr {} not support",
&right_type
)));
}
match right {
Expr::Literal(right_lit_value) => {
let casted_scalar_value =
try_cast_literal_to_type(right_lit_value, &left_type)?;
if let Some(value) = casted_scalar_value {
Ok(lit(value))
} else {
Err(DataFusionError::Internal(format!(
"Can't cast the list expr {:?} to type {:?}",
right_lit_value, &left_type
)))
}
}
other_expr => Err(DataFusionError::Internal(format!(
"Only support literal expr to optimize, but the expr is {:?}",
&other_expr
))),
}
})
.collect::<Result<Vec<_>>>();
match right_exprs {
Ok(right_exprs) => Ok(in_list(left, right_exprs, *negated)),
Err(_) => Ok(expr),
}
}
// TODO: handle other expr type and dfs visit them
_ => Ok(expr),
}
Expand Down Expand Up @@ -384,6 +435,129 @@ mod tests {
assert_eq!(optimize_test(expr_lt, &schema), expected);
}

#[test]
fn test_not_list_cast_lit_comparison() {
let schema = expr_test_schema();
// left type is not supported
// FLOAT32(C5) in ...
let expr_lt = col("c5").in_list(
vec![
lit(ScalarValue::Int64(Some(12))),
lit(ScalarValue::Int32(Some(12))),
],
false,
);
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);

// INT32(C1) in (FLOAT32(1.23), INT32(12), INT64(12))
let expr_lt = col("c1").in_list(
vec![
lit(ScalarValue::Int32(Some(12))),
lit(ScalarValue::Int64(Some(12))),
lit(ScalarValue::Float32(Some(1.23))),
],
false,
);
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);

// INT32(C1) in (INT64(99999999999), INT64(12))
let expr_lt = col("c1").in_list(
vec![
lit(ScalarValue::Int32(Some(12))),
lit(ScalarValue::Int64(Some(99999999999))),
],
false,
);
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);

// DECIMAL(C3) in (INT64(12), INT32(12), DECIMAL(128,12,3))
let expr_lt = col("c3").in_list(
vec![
lit(ScalarValue::Int32(Some(12))),
lit(ScalarValue::Int64(Some(12))),
lit(ScalarValue::Decimal128(Some(128), 12, 3)),
],
false,
);
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
}

#[test]
fn test_pre_list_cast_lit_comparison() {
let schema = expr_test_schema();
// INT32(C1) IN (INT32(12),INT64(24)) -> INT32(C1) IN (INT32(12),INT32(24))
let expr_lt = col("c1").in_list(
vec![
lit(ScalarValue::Int32(Some(12))),
lit(ScalarValue::Int64(Some(24))),
],
false,
);
let expected = col("c1").in_list(
vec![
lit(ScalarValue::Int32(Some(12))),
lit(ScalarValue::Int32(Some(24))),
],
false,
);
assert_eq!(optimize_test(expr_lt, &schema), expected);
// INT32(C2) IN (INT64(NULL),INT64(24)) -> INT32(C1) IN (INT32(12),INT32(24))
let expr_lt = col("c2").in_list(
vec![
lit(ScalarValue::Int64(None)),
lit(ScalarValue::Int32(Some(14))),
],
false,
);
let expected = col("c2").in_list(
vec![
lit(ScalarValue::Int64(None)),
lit(ScalarValue::Int64(Some(14))),
],
false,
);

assert_eq!(optimize_test(expr_lt, &schema), expected);

// decimal test case
let expr_lt = col("c3").in_list(
vec![
lit(ScalarValue::Int32(Some(12))),
lit(ScalarValue::Int64(Some(24))),
lit(ScalarValue::Decimal128(Some(128), 10, 2)),
lit(ScalarValue::Decimal128(Some(1280), 10, 3)),
],
false,
);
let expected = col("c3").in_list(
vec![
lit(ScalarValue::Decimal128(Some(1200), 18, 2)),
lit(ScalarValue::Decimal128(Some(2400), 18, 2)),
lit(ScalarValue::Decimal128(Some(128), 18, 2)),
lit(ScalarValue::Decimal128(Some(128), 18, 2)),
],
false,
);
assert_eq!(optimize_test(expr_lt, &schema), expected);

// INT32(12) IN (.....)
let expr_lt = lit(ScalarValue::Int32(Some(12))).in_list(
vec![
lit(ScalarValue::Int32(Some(12))),
lit(ScalarValue::Int64(Some(12))),
],
false,
);
let expected = lit(ScalarValue::Int32(Some(12))).in_list(
vec![
lit(ScalarValue::Int32(Some(12))),
lit(ScalarValue::Int32(Some(12))),
],
false,
);
assert_eq!(optimize_test(expr_lt, &schema), expected);
}

#[test]
fn aliased() {
let schema = expr_test_schema();
Expand Down Expand Up @@ -423,6 +597,7 @@ mod tests {
DFField::new(None, "c2", DataType::Int64, false),
DFField::new(None, "c3", DataType::Decimal128(18, 2), false),
DFField::new(None, "c4", DataType::Decimal128(38, 37), false),
DFField::new(None, "c5", DataType::Float32, false),
],
HashMap::new(),
)
Expand Down

0 comments on commit a4e74c0

Please sign in to comment.