diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs
index 9c9ed952625b..7299ca7ac504 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -106,6 +106,7 @@ use datafusion_optimizer::decorrelate_scalar_subquery::DecorrelateScalarSubquery
 use datafusion_optimizer::decorrelate_where_exists::DecorrelateWhereExists;
 use datafusion_optimizer::decorrelate_where_in::DecorrelateWhereIn;
 use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys;
+use datafusion_optimizer::pre_cast_lit_in_comparison::PreCastLitInComparisonExpressions;
 use datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
 use datafusion_sql::{
     parser::DFParser,
@@ -1358,6 +1359,7 @@ impl SessionState {
             // Simplify expressions first to maximize the chance
             // of applying other optimizations
             Arc::new(SimplifyExpressions::new()),
+            Arc::new(PreCastLitInComparisonExpressions::new()),
             Arc::new(DecorrelateWhereExists::new()),
             Arc::new(DecorrelateWhereIn::new()),
             Arc::new(DecorrelateScalarSubquery::new()),
diff --git a/datafusion/core/tests/provider_filter_pushdown.rs b/datafusion/core/tests/provider_filter_pushdown.rs
index 3ebfec996e64..8e6d695c9e9c 100644
--- a/datafusion/core/tests/provider_filter_pushdown.rs
+++ b/datafusion/core/tests/provider_filter_pushdown.rs
@@ -31,6 +31,8 @@ use datafusion::physical_plan::{
 };
 use datafusion::prelude::*;
 use datafusion::scalar::ScalarValue;
+use datafusion_common::DataFusionError;
+use std::ops::Deref;
 use std::sync::Arc;
 
 fn create_batch(value: i32, num_rows: usize) -> Result<RecordBatch> {
@@ -146,8 +148,36 @@ impl TableProvider for CustomProvider {
         match &filters[0] {
             Expr::BinaryExpr { right, .. } => {
                 let int_value = match &**right {
-                    Expr::Literal(ScalarValue::Int64(i)) => i.unwrap(),
-                    _ => unimplemented!(),
+                    Expr::Literal(ScalarValue::Int8(Some(i))) => *i as i64,
+                    Expr::Literal(ScalarValue::Int16(Some(i))) => *i as i64,
+                    Expr::Literal(ScalarValue::Int32(Some(i))) => *i as i64,
+                    Expr::Literal(ScalarValue::Int64(Some(i))) => *i,
+                    Expr::Cast { expr, data_type: _ } => match expr.deref() {
+                        Expr::Literal(lit_value) => match lit_value {
+                            ScalarValue::Int8(Some(v)) => *v as i64,
+                            ScalarValue::Int16(Some(v)) => *v as i64,
+                            ScalarValue::Int32(Some(v)) => *v as i64,
+                            ScalarValue::Int64(Some(v)) => *v,
+                            other_value => {
+                                return Err(DataFusionError::NotImplemented(format!(
+                                    "Unsupported value {:?}",
+                                    other_value
+                                )))
+                            }
+                        },
+                        other_expr => {
+                            return Err(DataFusionError::NotImplemented(format!(
+                                "Unsupported expr {:?}",
+                                other_expr
+                            )))
+                        }
+                    },
+                    other_expr => {
+                        return Err(DataFusionError::NotImplemented(format!(
+                            "Unsupported expr {:?}",
+                            other_expr
+                        )))
+                    }
                 };
 
                 Ok(Arc::new(CustomPlan {
diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs
index 02db3e873330..2b801ed01cb9 100644
--- a/datafusion/core/tests/sql/explain_analyze.rs
+++ b/datafusion/core/tests/sql/explain_analyze.rs
@@ -271,8 +271,8 @@ async fn csv_explain_plans() {
     let expected = vec![
         "Explain [plan_type:Utf8, plan:Utf8]",
         "  Projection: #aggregate_test_100.c1 [c1:Utf8]",
-        "    Filter: #aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int32]",
-        "      TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)] [c1:Utf8, c2:Int32]",
+        "    Filter: #aggregate_test_100.c2 > Int32(10) [c1:Utf8, c2:Int32]",
+        "      TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)] [c1:Utf8, c2:Int32]",
     ];
     let formatted = plan.display_indent_schema().to_string();
     let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -286,8 +286,8 @@ async fn csv_explain_plans() {
     let expected = vec![
         "Explain",
         "  Projection: #aggregate_test_100.c1",
-        "    Filter: #aggregate_test_100.c2 > Int64(10)",
-        "      TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]",
+        "    Filter: #aggregate_test_100.c2 > Int32(10)",
+        "      TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]",
     ];
     let formatted = plan.display_indent().to_string();
     let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -307,9 +307,9 @@ async fn csv_explain_plans() {
         "    2[shape=box label=\"Explain\"]",
         "    3[shape=box label=\"Projection: #aggregate_test_100.c1\"]",
         "    2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\"]",
+        "    4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\"]",
         "    3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]\"]",
+        "    5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\"]",
         "    4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]",
         "  }",
         "  subgraph cluster_6",
@@ -318,9 +318,9 @@ async fn csv_explain_plans() {
         "    7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]",
         "    8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]",
         "    7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int32]\"]",
+        "    9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\\nSchema: [c1:Utf8, c2:Int32]\"]",
         "    8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]",
+        "    10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]",
         "    9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]",
         "  }",
         "}",
@@ -349,7 +349,7 @@ async fn csv_explain_plans() {
     // Since the plan contains path that are environmentally dependant (e.g. full path of the test file), only verify important content
     assert_contains!(&actual, "logical_plan");
     assert_contains!(&actual, "Projection: #aggregate_test_100.c1");
-    assert_contains!(actual, "Filter: #aggregate_test_100.c2 > Int64(10)");
+    assert_contains!(actual, "Filter: #aggregate_test_100.c2 > Int32(10)");
 }
 
 #[tokio::test]
@@ -469,8 +469,8 @@ async fn csv_explain_verbose_plans() {
     let expected = vec![
         "Explain [plan_type:Utf8, plan:Utf8]",
         "  Projection: #aggregate_test_100.c1 [c1:Utf8]",
-        "    Filter: #aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int32]",
-        "      TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)] [c1:Utf8, c2:Int32]",
+        "    Filter: #aggregate_test_100.c2 > Int32(10) [c1:Utf8, c2:Int32]",
+        "      TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)] [c1:Utf8, c2:Int32]",
     ];
     let formatted = plan.display_indent_schema().to_string();
     let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -484,8 +484,8 @@ async fn csv_explain_verbose_plans() {
     let expected = vec![
         "Explain",
         "  Projection: #aggregate_test_100.c1",
-        "    Filter: #aggregate_test_100.c2 > Int64(10)",
-        "      TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]",
+        "    Filter: #aggregate_test_100.c2 > Int32(10)",
+        "      TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]",
     ];
     let formatted = plan.display_indent().to_string();
     let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -505,9 +505,9 @@ async fn csv_explain_verbose_plans() {
         "    2[shape=box label=\"Explain\"]",
         "    3[shape=box label=\"Projection: #aggregate_test_100.c1\"]",
         "    2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\"]",
+        "    4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\"]",
         "    3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]\"]",
+        "    5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\"]",
         "    4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]",
         "  }",
         "  subgraph cluster_6",
@@ -516,9 +516,9 @@ async fn csv_explain_verbose_plans() {
         "    7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]",
         "    8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]",
         "    7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int32]\"]",
+        "    9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\\nSchema: [c1:Utf8, c2:Int32]\"]",
         "    8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]",
+        "    10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]",
         "    9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]",
         "  }",
         "}",
@@ -549,7 +549,7 @@ async fn csv_explain_verbose_plans() {
     // important content
     assert_contains!(&actual, "logical_plan after projection_push_down");
    assert_contains!(&actual, "physical_plan");
-    assert_contains!(&actual, "FilterExec: CAST(c2@1 AS Int64) > 10");
+    assert_contains!(&actual, "FilterExec: c2@1 > 10");
     assert_contains!(actual, "ProjectionExec: expr=[c1@0 as c1]");
 }
 
@@ -745,7 +745,7 @@ async fn csv_explain() {
     // then execute the physical plan and return the final explain results
     let ctx = SessionContext::new();
     register_aggregate_csv_by_sql(&ctx).await;
-    let sql = "EXPLAIN SELECT c1 FROM aggregate_test_100 where c2 > 10";
+    let sql = "EXPLAIN SELECT c1 FROM aggregate_test_100 where c2 > cast(10 as int)";
     let actual = execute(&ctx, sql).await;
     let actual = normalize_vec_for_explain(actual);
 
@@ -755,13 +755,13 @@ async fn csv_explain() {
         vec![
             "logical_plan",
             "Projection: #aggregate_test_100.c1\
-            \n  Filter: #aggregate_test_100.c2 > Int64(10)\
-            \n    TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]"
+            \n  Filter: #aggregate_test_100.c2 > Int32(10)\
+            \n    TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]"
         ],
         vec!["physical_plan",
             "ProjectionExec: expr=[c1@0 as c1]\
             \n  CoalesceBatchesExec: target_batch_size=4096\
-            \n    FilterExec: CAST(c2@1 AS Int64) > 10\
+            \n    FilterExec: c2@1 > 10\
             \n      RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES)\
             \n        CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, limit=None, projection=[c1, c2]\
             \n"
diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs
index 4eaf921f6937..d85a2693253a 100644
--- a/datafusion/core/tests/sql/subqueries.rs
+++ b/datafusion/core/tests/sql/subqueries.rs
@@ -147,8 +147,8 @@ order by s_acctbal desc, n_name, s_name, p_partkey;"#;
         Inner Join: #supplier.s_nationkey = #nation.n_nationkey
           Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
             Inner Join: #part.p_partkey = #partsupp.ps_partkey
-              Filter: #part.p_size = Int64(15) AND #part.p_type LIKE Utf8("%BRASS")
-                TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[#part.p_size = Int64(15), #part.p_type LIKE Utf8("%BRASS")]
+              Filter: #part.p_size = Int32(15) AND #part.p_type LIKE Utf8("%BRASS")
+                TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[#part.p_size = Int32(15), #part.p_type LIKE Utf8("%BRASS")]
               TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost]
             TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]
           TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index 6da67b6fc132..60c450992de5 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -33,6 +33,7 @@ pub mod single_distinct_to_groupby;
 pub mod subquery_filter_to_join;
 pub mod utils;
 
+pub mod pre_cast_lit_in_comparison;
 pub mod rewrite_disjunctive_predicate;
 #[cfg(test)]
 pub mod test;
diff --git a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
new file mode 100644
index 000000000000..0c16f7921c32
--- /dev/null
+++ b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
@@ -0,0 +1,311 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Pre-cast literal binary comparison rule. It applies only to binary comparison
+//! expressions, and avoids wrapping the non-literal side in an `Expr::Cast` by
+//! adding the `Expr::Cast` to the literal expression instead.
+use crate::{OptimizerConfig, OptimizerRule};
+use arrow::datatypes::DataType;
+use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue};
+use datafusion_expr::utils::from_plan;
+use datafusion_expr::{binary_expr, lit, Expr, ExprSchemable, LogicalPlan, Operator};
+
+/// This rule applies only to numeric binary comparisons with a literal expression, i.e. the
+/// patterns `left_expr comparison_op literal_expr` or `literal_expr comparison_op right_expr`.
+/// For now the data types of both sides must be signed integer types; more data types will be
+/// supported later.
+///
+/// If a binary comparison expression matches the pattern above, the optimizer checks whether
+/// the value of the literal lies within the (min, max) range of the data type of the expression
+/// on the other side.
+///
+/// If it does, the literal expression is cast to the data type of the expression on the other
+/// side, and the comparison becomes `left_expr comparison_op cast(literal_expr, left_data_type)`
+/// or `cast(literal_expr, right_data_type) comparison_op right_expr`. To enable further
+/// optimizations, `cast(literal_expr, target_type)` is precomputed and folded into a new literal
+/// expression `new_literal_expr` whose data type is `target_type`.
+/// If it does not, the expression is left unchanged.
+///
+/// This is inspired by Spark's `UnwrapCastInBinaryComparison` optimizer rule.
+///
+/// # Example
+///
+/// If the data type of `c1` is INT32, `Filter: c1 > INT64(10)` is optimized to
+/// `Filter: c1 > CAST(INT64(10) AS INT32)`, which is then folded to `Filter: c1 > INT32(10)`.
+#[derive(Default)]
+pub struct PreCastLitInComparisonExpressions {}
+
+impl PreCastLitInComparisonExpressions {
+    pub fn new() -> Self {
+        Self::default()
+    }
+}
+
+impl OptimizerRule for PreCastLitInComparisonExpressions {
+    fn optimize(
+        &self,
+        plan: &LogicalPlan,
+        _optimizer_config: &mut OptimizerConfig,
+    ) -> Result<LogicalPlan> {
+        optimize(plan)
+    }
+
+    fn name(&self) -> &str {
+        "pre_cast_lit_in_comparison"
+    }
+}
+
+fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
+    let new_inputs = plan
+        .inputs()
+        .iter()
+        .map(|input| optimize(input))
+        .collect::<Result<Vec<_>>>()?;
+
+    let schema = plan.schema();
+    let new_exprs = plan
+        .expressions()
+        .into_iter()
+        .map(|expr| visit_expr(expr, schema))
+        .collect::<Result<Vec<_>>>()?;
+
+    from_plan(plan, new_exprs.as_slice(), new_inputs.as_slice())
+}
+
+// Visit all kinds of exprs; if the current expr has child exprs, the children are visited first.
+fn visit_expr(expr: Expr, schema: &DFSchemaRef) -> Result<Expr> {
+    // traverse the expr depth-first
+    match &expr {
+        Expr::BinaryExpr { left, op, right } => {
+            // visit the left and right child exprs first
+            let left = visit_expr(*left.clone(), schema)?;
+            let right = visit_expr(*right.clone(), schema)?;
+            let left_type = left.get_type(schema);
+            let right_type = right.get_type(schema);
+            // if we can't determine the data types, just return the expr unchanged
+            if left_type.is_err() || right_type.is_err() {
+                return Ok(expr.clone());
+            }
+            let left_type = left_type.unwrap();
+            let right_type = right_type.unwrap();
+            if !left_type.eq(&right_type)
+                && is_support_data_type(&left_type)
+                && is_support_data_type(&right_type)
+                && is_comparison_op(op)
+            {
+                match (&left, &right) {
+                    (Expr::Literal(_), Expr::Literal(_)) => {
+                        // do nothing
+                    }
+                    (Expr::Literal(left_lit_value), _)
+                        if can_integer_literal_cast_to_type(
+                            left_lit_value,
+                            &right_type,
+                        )? =>
+                    {
+                        // cast the left literal to the right type
+                        return Ok(binary_expr(
+                            cast_to_other_scalar_expr(left_lit_value, &right_type)?,
+                            *op,
+                            right,
+                        ));
+                    }
+                    (_, Expr::Literal(right_lit_value))
+                        if can_integer_literal_cast_to_type(
+                            right_lit_value,
+                            &left_type,
+                        )? =>
+                    {
+                        // cast the right literal to the left type
+                        return Ok(binary_expr(
+                            left,
+                            *op,
+                            cast_to_other_scalar_expr(right_lit_value, &left_type)?,
+                        ));
+                    }
+                    (_, _) => {
+                        // do nothing
+                    }
+                };
+            }
+            // return the binary expr with the (possibly rewritten) children
+            Ok(binary_expr(left, *op, right))
+        }
+        // TODO: optimize in list
+        // Expr::InList { .. } => {}
+        // TODO: handle other expr types and visit them depth-first
+        _ => Ok(expr),
+    }
+}
+
+fn cast_to_other_scalar_expr(
+    origin_value: &ScalarValue,
+    target_type: &DataType,
+) -> Result<Expr> {
+    // null case
+    if origin_value.is_null() {
+        // if the origin value is null, just convert it to a null value of the target type;
+        // the target type has already passed `is_support_data_type`, so we can unwrap safely
+        return Ok(lit(ScalarValue::try_from(target_type).unwrap()));
+    }
+    // non-null case
+    let value: i64 = match origin_value {
+        ScalarValue::Int8(Some(v)) => *v as i64,
+        ScalarValue::Int16(Some(v)) => *v as i64,
+        ScalarValue::Int32(Some(v)) => *v as i64,
+        ScalarValue::Int64(Some(v)) => *v,
+        other_value => {
+            return Err(DataFusionError::Internal(format!(
+                "Invalid type and value {}",
+                other_value
+            )))
+        }
+    };
+    Ok(lit(match target_type {
+        DataType::Int8 => ScalarValue::Int8(Some(value as i8)),
+        DataType::Int16 => ScalarValue::Int16(Some(value as i16)),
+        DataType::Int32 => ScalarValue::Int32(Some(value as i32)),
+        DataType::Int64 => ScalarValue::Int64(Some(value)),
+        other_type => {
+            return Err(DataFusionError::Internal(format!(
+                "Invalid target data type {:?}",
+                other_type
+            )))
+        }
+    }))
+}
+
+fn is_comparison_op(op: &Operator) -> bool {
+    matches!(
+        op,
+        Operator::Eq
+            | Operator::NotEq
+            | Operator::Gt
+            | Operator::GtEq
+            | Operator::Lt
+            | Operator::LtEq
+    )
+}
+
+fn is_support_data_type(data_type: &DataType) -> bool {
+    // TODO: support decimal and other data types
+    matches!(
+        data_type,
+        DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64
+    )
+}
+
+fn can_integer_literal_cast_to_type(
+    integer_lit_value: &ScalarValue,
+    target_type: &DataType,
+) -> Result<bool> {
+    if integer_lit_value.is_null() {
+        // a null value can be cast to a null value of any type
+        return Ok(true);
+    }
+    let (target_min, target_max) = match target_type {
+        DataType::Int8 => (i8::MIN as i128, i8::MAX as i128),
+        DataType::Int16 => (i16::MIN as i128, i16::MAX as i128),
+        DataType::Int32 => (i32::MIN as i128, i32::MAX as i128),
+        DataType::Int64 => (i64::MIN as i128, i64::MAX as i128),
+        other_type => {
+            return Err(DataFusionError::Internal(format!(
+                "Unsupported target data type {:?}",
+                other_type
+            )))
+        }
+    };
+    let lit_value = match integer_lit_value {
+        ScalarValue::Int8(Some(v)) => *v as i128,
+        ScalarValue::Int16(Some(v)) => *v as i128,
+        ScalarValue::Int32(Some(v)) => *v as i128,
+        ScalarValue::Int64(Some(v)) => *v as i128,
+        other_value => {
+            return Err(DataFusionError::Internal(format!(
+                "Invalid literal value {:?}",
+                other_value
+            )))
+        }
+    };
+
+    Ok(lit_value >= target_min && lit_value <= target_max)
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::pre_cast_lit_in_comparison::visit_expr;
+    use arrow::datatypes::DataType;
+    use datafusion_common::{DFField, DFSchema, DFSchemaRef, ScalarValue};
+    use datafusion_expr::{col, lit, Expr};
+    use std::collections::HashMap;
+    use std::sync::Arc;
+
+    #[test]
+    fn test_not_cast_lit_comparison() {
+        let schema = expr_test_schema();
+        // INT8(NULL) < INT32(12)
+        let lit_lt_lit =
+            lit(ScalarValue::Int8(None)).lt(lit(ScalarValue::Int32(Some(12))));
+        assert_eq!(optimize_test(lit_lt_lit.clone(), &schema), lit_lt_lit);
+
+        // INT32(c1) > INT64(c2)
+        let c1_gt_c2 = col("c1").gt(col("c2"));
+        assert_eq!(optimize_test(c1_gt_c2.clone(), &schema), c1_gt_c2);
+
+        // INT32(c1) < INT32(16): the types are already the same
+        let expr_lt = col("c1").lt(lit(ScalarValue::Int32(Some(16))));
+        assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
+
+        // 99999999999 is outside the range [i32::MIN, i32::MAX], so we don't cast
+        // lit(99999999999) to the Int32 type
+        let expr_lt = col("c1").lt(lit(ScalarValue::Int64(Some(99999999999))));
+        assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
+    }
+
+    #[test]
+    fn test_pre_cast_lit_comparison() {
+        let schema = expr_test_schema();
+        // c1 < INT64(16) -> c1 < cast(INT32(16))
+        // 16 is within the range [i32::MIN, i32::MAX], so we can cast it to Int32(16)
+        let expr_lt = col("c1").lt(lit(ScalarValue::Int64(Some(16))));
+        let expected = col("c1").lt(lit(ScalarValue::Int32(Some(16))));
+        assert_eq!(optimize_test(expr_lt, &schema), expected);
+
+        // INT64(c2) = INT32(16) => INT64(c2) = INT64(16)
+        let c2_eq_lit = col("c2").eq(lit(ScalarValue::Int32(Some(16))));
+        let expected = col("c2").eq(lit(ScalarValue::Int64(Some(16))));
+        assert_eq!(optimize_test(c2_eq_lit, &schema), expected);
+
+        // INT32(c1) < INT64(NULL) => INT32(c1) < INT32(NULL)
+        let c1_lt_lit_null = col("c1").lt(lit(ScalarValue::Int64(None)));
+        let expected = col("c1").lt(lit(ScalarValue::Int32(None)));
+        assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected);
+    }
+
+    fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr {
+        visit_expr(expr, schema).unwrap()
+    }
+
+    fn expr_test_schema() -> DFSchemaRef {
+        Arc::new(
+            DFSchema::new_with_metadata(
+                vec![
+                    DFField::new(None, "c1", DataType::Int32, false),
+                    DFField::new(None, "c2", DataType::Int64, false),
+                ],
+                HashMap::new(),
+            )
+            .unwrap(),
+        )
+    }
+}
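Reviewer note: to see the rewrite end to end, here is a minimal sketch (not part of the patch) of how the effect surfaces through the public `SessionContext` API at this revision. The in-memory table `t` and its single `Int32` column `c2` are made up for illustration.

```rust
use std::sync::Arc;

use datafusion::arrow::array::Int32Array;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    // Hypothetical single-column Int32 table, for illustration only.
    let schema = Arc::new(Schema::new(vec![Field::new("c2", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 11, 21]))],
    )?;
    let table = MemTable::try_new(schema, vec![vec![batch]])?;

    let ctx = SessionContext::new();
    ctx.register_table("t", Arc::new(table))?;

    // The SQL literal 10 is parsed as Int64. Previously the planner cast the
    // column instead: `FilterExec: CAST(c2@1 AS Int64) > 10`. With
    // PreCastLitInComparisonExpressions the literal is folded to Int32(10),
    // so the plan keeps the comparison on the column's native type:
    // `FilterExec: c2@1 > 10`.
    let df = ctx.sql("EXPLAIN SELECT * FROM t WHERE c2 > 10").await?;
    df.show().await?;
    Ok(())
}
```

Casting the literal rather than the column is also what lets the `provider_filter_pushdown.rs` test above match a plain integer literal (or a cast of a literal) again, instead of a cast wrapping the column.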