From 2b5a10f18c48c52f7fd69bc1675f869c2d414587 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Tue, 18 Nov 2025 16:44:57 +0100 Subject: [PATCH 1/7] Move GuaranteeRewriter to `expr` --- .../src/expr_rewriter}/guarantees.rs | 20 +++++++------------ datafusion/expr/src/expr_rewriter/mod.rs | 2 ++ .../simplify_expressions/expr_simplifier.rs | 2 +- .../optimizer/src/simplify_expressions/mod.rs | 3 +-- 4 files changed, 11 insertions(+), 16 deletions(-) rename datafusion/{optimizer/src/simplify_expressions => expr/src/expr_rewriter}/guarantees.rs (95%) diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/expr/src/expr_rewriter/guarantees.rs similarity index 95% rename from datafusion/optimizer/src/simplify_expressions/guarantees.rs rename to datafusion/expr/src/expr_rewriter/guarantees.rs index 515fd29003af..1b1cd8707a74 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/expr/src/expr_rewriter/guarantees.rs @@ -15,30 +15,24 @@ // specific language governing permissions and limitations // under the License. -//! Simplifier implementation for [`ExprSimplifier::with_guarantees()`] -//! -//! [`ExprSimplifier::with_guarantees()`]: crate::simplify_expressions::expr_simplifier::ExprSimplifier::with_guarantees +//! Rewrite expressions based on external expression value range guarantees. use std::{borrow::Cow, collections::HashMap}; -use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; +use crate::{expr::InList, lit, Between, BinaryExpr, Expr, LogicalPlan}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; -use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr}; +use datafusion_expr_common::interval_arithmetic::{Interval, NullableInterval}; +use crate::expr::Sort; /// Rewrite expressions to incorporate guarantees. /// /// Guarantees are a mapping from an expression (which currently is always a /// column reference) to a [NullableInterval]. The interval represents the known -/// possible values of the column. Using these known values, expressions are -/// rewritten so they can be simplified using `ConstEvaluator` and `Simplifier`. +/// possible values of the column. /// /// For example, if we know that a column is not null and has values in the /// range [1, 10), we can rewrite `x IS NULL` to `false` or `x < 10` to `true`. -/// -/// See a full example in [`ExprSimplifier::with_guarantees()`]. -/// -/// [`ExprSimplifier::with_guarantees()`]: crate::simplify_expressions::expr_simplifier::ExprSimplifier::with_guarantees pub struct GuaranteeRewriter<'a> { guarantees: HashMap<&'a Expr, &'a NullableInterval>, } @@ -203,10 +197,10 @@ impl TreeNodeRewriter for GuaranteeRewriter<'_> { mod tests { use super::*; + use crate::{col, Operator}; use arrow::datatypes::DataType; use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::ScalarValue; - use datafusion_expr::{col, Operator}; #[test] fn test_null_handling() { diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 9c3c5df7007f..163553479c95 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -31,7 +31,9 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::TableReference; use datafusion_common::{Column, DFSchema, Result}; +mod guarantees; mod order_by; + pub use order_by::rewrite_sort_cols_by_aggs; /// Trait for rewriting [`Expr`]s into function calls. diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index c7912bbf70b0..2543292ab670 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -50,7 +50,6 @@ use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionP use super::inlist_simplifier::ShortenInListSimplifier; use super::utils::*; use crate::analyzer::type_coercion::TypeCoercionRewriter; -use crate::simplify_expressions::guarantees::GuaranteeRewriter; use crate::simplify_expressions::regex::simplify_regex_expr; use crate::simplify_expressions::unwrap_cast::{ is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary, @@ -58,6 +57,7 @@ use crate::simplify_expressions::unwrap_cast::{ unwrap_cast_in_comparison_for_binary, }; use crate::simplify_expressions::SimplifyInfo; +use datafusion_expr::expr_rewriter::guarantees::GuaranteeRewriter; use datafusion_expr_common::casts::try_cast_literal_to_type; use indexmap::IndexSet; use regex::Regex; diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs b/datafusion/optimizer/src/simplify_expressions/mod.rs index 7ae38eec9a3a..4112f2598671 100644 --- a/datafusion/optimizer/src/simplify_expressions/mod.rs +++ b/datafusion/optimizer/src/simplify_expressions/mod.rs @@ -19,7 +19,6 @@ //! [`ExprSimplifier`] simplifies individual `Expr`s. pub mod expr_simplifier; -mod guarantees; mod inlist_simplifier; mod regex; pub mod simplify_exprs; @@ -35,4 +34,4 @@ pub use simplify_exprs::*; pub use simplify_predicates::simplify_predicates; // Export for test in datafusion/core/tests/optimizer_integration.rs -pub use guarantees::GuaranteeRewriter; +pub use datafusion_expr::expr_rewriter::guarantees::GuaranteeRewriter; From 701909ce3f77f4c679101c78349739fd8210ae3e Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Tue, 18 Nov 2025 18:01:50 +0100 Subject: [PATCH 2/7] Make GuaranteeRewriter implementation private --- datafusion/core/tests/optimizer/mod.rs | 26 ++-- .../expr/src/expr_rewriter/guarantees.rs | 114 ++++++++++++------ datafusion/expr/src/expr_rewriter/mod.rs | 2 + .../simplify_expressions/expr_simplifier.rs | 10 +- .../optimizer/src/simplify_expressions/mod.rs | 3 - 5 files changed, 100 insertions(+), 55 deletions(-) diff --git a/datafusion/core/tests/optimizer/mod.rs b/datafusion/core/tests/optimizer/mod.rs index 9b2a5596827d..b288706a54c9 100644 --- a/datafusion/core/tests/optimizer/mod.rs +++ b/datafusion/core/tests/optimizer/mod.rs @@ -27,7 +27,7 @@ use arrow::datatypes::{ DataType, Field, Fields, Schema, SchemaBuilder, SchemaRef, TimeUnit, }; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{TransformedResult, TreeNode}; +use datafusion_common::tree_node::TransformedResult; use datafusion_common::{plan_err, DFSchema, Result, ScalarValue, TableReference}; use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; use datafusion_expr::{ @@ -37,7 +37,6 @@ use datafusion_expr::{ use datafusion_functions::core::expr_ext::FieldAccessor; use datafusion_optimizer::analyzer::Analyzer; use datafusion_optimizer::optimizer::Optimizer; -use datafusion_optimizer::simplify_expressions::GuaranteeRewriter; use datafusion_optimizer::{OptimizerConfig, OptimizerContext}; use datafusion_sql::planner::{ContextProvider, SqlToRel}; use datafusion_sql::sqlparser::ast::Statement; @@ -45,6 +44,7 @@ use datafusion_sql::sqlparser::dialect::GenericDialect; use datafusion_sql::sqlparser::parser::Parser; use chrono::DateTime; +use datafusion_expr::expr_rewriter::rewrite_with_guarantees; use datafusion_functions::datetime; #[cfg(test)] @@ -304,8 +304,6 @@ fn test_inequalities_non_null_bounded() { ), ]; - let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); - // (original_expr, expected_simplification) let simplified_cases = &[ (col("x").lt(lit(0)), false), @@ -337,7 +335,7 @@ fn test_inequalities_non_null_bounded() { ), ]; - validate_simplified_cases(&mut rewriter, simplified_cases); + validate_simplified_cases(&guarantees, simplified_cases); let unchanged_cases = &[ col("x").gt(lit(2)), @@ -348,16 +346,20 @@ fn test_inequalities_non_null_bounded() { col("x").not_between(lit(3), lit(10)), ]; - validate_unchanged_cases(&mut rewriter, unchanged_cases); + validate_unchanged_cases(&guarantees, unchanged_cases); } -fn validate_simplified_cases(rewriter: &mut GuaranteeRewriter, cases: &[(Expr, T)]) -where +fn validate_simplified_cases( + guarantees: &[(Expr, NullableInterval)], + cases: &[(Expr, T)], +) where ScalarValue: From, T: Clone, { for (expr, expected_value) in cases { - let output = expr.clone().rewrite(rewriter).data().unwrap(); + let output = rewrite_with_guarantees(expr.clone(), guarantees) + .data() + .unwrap(); let expected = lit(ScalarValue::from(expected_value.clone())); assert_eq!( output, expected, @@ -365,9 +367,11 @@ where ); } } -fn validate_unchanged_cases(rewriter: &mut GuaranteeRewriter, cases: &[Expr]) { +fn validate_unchanged_cases(guarantees: &[(Expr, NullableInterval)], cases: &[Expr]) { for expr in cases { - let output = expr.clone().rewrite(rewriter).data().unwrap(); + let output = rewrite_with_guarantees(expr.clone(), guarantees) + .data() + .unwrap(); assert_eq!( &output, expr, "{expr} was simplified to {output}, but expected it to be unchanged" diff --git a/datafusion/expr/src/expr_rewriter/guarantees.rs b/datafusion/expr/src/expr_rewriter/guarantees.rs index 1b1cd8707a74..19e2987e23ad 100644 --- a/datafusion/expr/src/expr_rewriter/guarantees.rs +++ b/datafusion/expr/src/expr_rewriter/guarantees.rs @@ -17,13 +17,18 @@ //! Rewrite expressions based on external expression value range guarantees. -use std::{borrow::Cow, collections::HashMap}; +use std::borrow::Cow; -use crate::{expr::InList, lit, Between, BinaryExpr, Expr, LogicalPlan}; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; -use datafusion_common::{DataFusionError, Result}; +use crate::{expr::InList, lit, Between, BinaryExpr, Expr}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, +}; +use datafusion_common::{DataFusionError, HashMap, Result, ScalarValue}; use datafusion_expr_common::interval_arithmetic::{Interval, NullableInterval}; -use crate::expr::Sort; + +struct GuaranteeRewriter<'a> { + guarantees: &'a HashMap<&'a Expr, &'a NullableInterval>, +} /// Rewrite expressions to incorporate guarantees. /// @@ -33,27 +38,51 @@ use crate::expr::Sort; /// /// For example, if we know that a column is not null and has values in the /// range [1, 10), we can rewrite `x IS NULL` to `false` or `x < 10` to `true`. -pub struct GuaranteeRewriter<'a> { - guarantees: HashMap<&'a Expr, &'a NullableInterval>, +/// +/// If the set of guarantees will be used to rewrite multiple expressions consider using +/// [rewrite_with_guarantees_map] instead. +pub fn rewrite_with_guarantees<'a>( + expr: Expr, + guarantees: impl IntoIterator, +) -> Result> { + let guarantees_map: HashMap<&Expr, &NullableInterval> = + guarantees.into_iter().map(|(k, v)| (k, v)).collect(); + rewrite_with_guarantees_map(expr, &guarantees_map) } -impl<'a> GuaranteeRewriter<'a> { - pub fn new( - guarantees: impl IntoIterator, - ) -> Self { - Self { - // TODO: Clippy wants the "map" call removed, but doing so generates - // a compilation error. Remove the clippy directive once this - // issue is fixed. - #[allow(clippy::map_identity)] - guarantees: guarantees.into_iter().map(|(k, v)| (k, v)).collect(), - } - } +/// Rewrite expressions to incorporate guarantees. +/// +/// Guarantees are a mapping from an expression (which currently is always a +/// column reference) to a [NullableInterval]. The interval represents the known +/// possible values of the column. +/// +/// For example, if we know that a column is not null and has values in the +/// range [1, 10), we can rewrite `x IS NULL` to `false` or `x < 10` to `true`. +pub fn rewrite_with_guarantees_map<'a>( + expr: Expr, + guarantees: &'a HashMap<&'a Expr, &'a NullableInterval>, +) -> Result> { + let mut rewriter = GuaranteeRewriter { guarantees }; + expr.rewrite(&mut rewriter) } impl TreeNodeRewriter for GuaranteeRewriter<'_> { type Node = Expr; + fn f_down(&mut self, expr: Expr) -> Result> { + if self.guarantees.is_empty() { + return Ok(Transformed::no(expr)); + } + + match self.guarantees.get(&expr) { + Some(NullableInterval::Null { datatype }) => { + let null = lit(ScalarValue::try_new_null(datatype)?); + Ok(Transformed::new(null, true, TreeNodeRecursion::Jump)) + } + _ => Ok(Transformed::no(expr)), + } + } + fn f_up(&mut self, expr: Expr) -> Result> { if self.guarantees.is_empty() { return Ok(Transformed::no(expr)); @@ -199,7 +228,7 @@ mod tests { use crate::{col, Operator}; use arrow::datatypes::DataType; - use datafusion_common::tree_node::{TransformedResult, TreeNode}; + use datafusion_common::tree_node::TransformedResult; use datafusion_common::ScalarValue; #[test] @@ -215,26 +244,33 @@ mod tests { }, ), ]; - let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); // x IS NULL => guaranteed false let expr = col("x").is_null(); - let output = expr.rewrite(&mut rewriter).data().unwrap(); + let output = rewrite_with_guarantees(expr, guarantees.iter()) + .data() + .unwrap(); assert_eq!(output, lit(false)); // x IS NOT NULL => guaranteed true let expr = col("x").is_not_null(); - let output = expr.rewrite(&mut rewriter).data().unwrap(); + let output = rewrite_with_guarantees(expr, guarantees.iter()) + .data() + .unwrap(); assert_eq!(output, lit(true)); } - fn validate_simplified_cases(rewriter: &mut GuaranteeRewriter, cases: &[(Expr, T)]) - where + fn validate_simplified_cases( + guarantees: &[(Expr, NullableInterval)], + cases: &[(Expr, T)], + ) where ScalarValue: From, T: Clone, { for (expr, expected_value) in cases { - let output = expr.clone().rewrite(rewriter).data().unwrap(); + let output = rewrite_with_guarantees(expr.clone(), guarantees.iter()) + .data() + .unwrap(); let expected = lit(ScalarValue::from(expected_value.clone())); assert_eq!( output, expected, @@ -243,9 +279,11 @@ mod tests { } } - fn validate_unchanged_cases(rewriter: &mut GuaranteeRewriter, cases: &[Expr]) { + fn validate_unchanged_cases(guarantees: &[(Expr, NullableInterval)], cases: &[Expr]) { for expr in cases { - let output = expr.clone().rewrite(rewriter).data().unwrap(); + let output = rewrite_with_guarantees(expr.clone(), guarantees.iter()) + .data() + .unwrap(); assert_eq!( &output, expr, "{expr} was simplified to {output}, but expected it to be unchanged" @@ -268,7 +306,6 @@ mod tests { }, ), ]; - let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); // (original_expr, expected_simplification) let simplified_cases = &[ @@ -310,7 +347,7 @@ mod tests { ), ]; - validate_simplified_cases(&mut rewriter, simplified_cases); + validate_simplified_cases(&guarantees, simplified_cases); let unchanged_cases = &[ col("x").lt(lit(ScalarValue::Date32(Some(19000)))), @@ -329,7 +366,7 @@ mod tests { ), ]; - validate_unchanged_cases(&mut rewriter, unchanged_cases); + validate_unchanged_cases(&guarantees, unchanged_cases); } #[test] @@ -347,7 +384,6 @@ mod tests { }, ), ]; - let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); // (original_expr, expected_simplification) let simplified_cases = &[ @@ -369,7 +405,7 @@ mod tests { ), ]; - validate_simplified_cases(&mut rewriter, simplified_cases); + validate_simplified_cases(&guarantees, simplified_cases); let unchanged_cases = &[ col("x").lt(lit("z")), @@ -387,7 +423,7 @@ mod tests { }), ]; - validate_unchanged_cases(&mut rewriter, unchanged_cases); + validate_unchanged_cases(&guarantees, unchanged_cases); } #[test] @@ -406,9 +442,10 @@ mod tests { for scalar in scalars { let guarantees = [(col("x"), NullableInterval::from(scalar.clone()))]; - let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); - let output = col("x").rewrite(&mut rewriter).data().unwrap(); + let output = rewrite_with_guarantees(col("x"), guarantees.iter()) + .data() + .unwrap(); assert_eq!(output, Expr::Literal(scalar.clone(), None)); } } @@ -428,7 +465,6 @@ mod tests { }, ), ]; - let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); // These cases should be simplified so the list doesn't contain any // values the guarantee says are outside the range. @@ -452,7 +488,9 @@ mod tests { .collect(), *negated, ); - let output = expr.clone().rewrite(&mut rewriter).data().unwrap(); + let output = rewrite_with_guarantees(expr.clone(), guarantees.iter()) + .data() + .unwrap(); let expected_list = expected_list .iter() .map(|v| lit(ScalarValue::Int32(Some(*v)))) diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 163553479c95..b084be59b1c7 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -32,6 +32,8 @@ use datafusion_common::TableReference; use datafusion_common::{Column, DFSchema, Result}; mod guarantees; +pub use guarantees::rewrite_with_guarantees; +pub use guarantees::rewrite_with_guarantees_map; mod order_by; pub use order_by::rewrite_sort_cols_by_aggs; diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 2543292ab670..366c99ce8f28 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -31,6 +31,7 @@ use datafusion_common::{ cast::{as_large_list_array, as_list_array}, metadata::FieldMetadata, tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, + HashMap, }; use datafusion_common::{ exec_datafusion_err, internal_err, DFSchema, DataFusionError, Result, ScalarValue, @@ -57,7 +58,7 @@ use crate::simplify_expressions::unwrap_cast::{ unwrap_cast_in_comparison_for_binary, }; use crate::simplify_expressions::SimplifyInfo; -use datafusion_expr::expr_rewriter::guarantees::GuaranteeRewriter; +use datafusion_expr::expr_rewriter::rewrite_with_guarantees_map; use datafusion_expr_common::casts::try_cast_literal_to_type; use indexmap::IndexSet; use regex::Regex; @@ -226,7 +227,8 @@ impl ExprSimplifier { let mut simplifier = Simplifier::new(&self.info); let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?; let mut shorten_in_list_simplifier = ShortenInListSimplifier::new(); - let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees); + let guarantees_map: HashMap<&Expr, &NullableInterval> = + self.guarantees.iter().map(|(k, v)| (k, v)).collect(); if self.canonicalize { expr = expr.rewrite(&mut Canonicalizer::new()).data()? @@ -243,7 +245,9 @@ impl ExprSimplifier { } = expr .rewrite(&mut const_evaluator)? .transform_data(|expr| expr.rewrite(&mut simplifier))? - .transform_data(|expr| expr.rewrite(&mut guarantee_rewriter))?; + .transform_data(|expr| { + rewrite_with_guarantees_map(expr, &guarantees_map) + })?; expr = data; num_cycles += 1; // Track if any transformation occurred diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs b/datafusion/optimizer/src/simplify_expressions/mod.rs index 4112f2598671..52a3be3652c8 100644 --- a/datafusion/optimizer/src/simplify_expressions/mod.rs +++ b/datafusion/optimizer/src/simplify_expressions/mod.rs @@ -32,6 +32,3 @@ pub use datafusion_expr::simplify::{SimplifyContext, SimplifyInfo}; pub use expr_simplifier::*; pub use simplify_exprs::*; pub use simplify_predicates::simplify_predicates; - -// Export for test in datafusion/core/tests/optimizer_integration.rs -pub use datafusion_expr::expr_rewriter::guarantees::GuaranteeRewriter; From b647c8f3f4afae6212fb4f10e942dcacb2209e68 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Tue, 18 Nov 2025 19:00:49 +0100 Subject: [PATCH 3/7] Make null replacement the fallback branch of f_up --- .../expr/src/expr_rewriter/guarantees.rs | 71 +++++++++---------- 1 file changed, 32 insertions(+), 39 deletions(-) diff --git a/datafusion/expr/src/expr_rewriter/guarantees.rs b/datafusion/expr/src/expr_rewriter/guarantees.rs index 19e2987e23ad..fd9cbcdc20ae 100644 --- a/datafusion/expr/src/expr_rewriter/guarantees.rs +++ b/datafusion/expr/src/expr_rewriter/guarantees.rs @@ -20,9 +20,7 @@ use std::borrow::Cow; use crate::{expr::InList, lit, Between, BinaryExpr, Expr}; -use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, -}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_common::{DataFusionError, HashMap, Result, ScalarValue}; use datafusion_expr_common::interval_arithmetic::{Interval, NullableInterval}; @@ -69,37 +67,21 @@ pub fn rewrite_with_guarantees_map<'a>( impl TreeNodeRewriter for GuaranteeRewriter<'_> { type Node = Expr; - fn f_down(&mut self, expr: Expr) -> Result> { - if self.guarantees.is_empty() { - return Ok(Transformed::no(expr)); - } - - match self.guarantees.get(&expr) { - Some(NullableInterval::Null { datatype }) => { - let null = lit(ScalarValue::try_new_null(datatype)?); - Ok(Transformed::new(null, true, TreeNodeRecursion::Jump)) - } - _ => Ok(Transformed::no(expr)), - } - } - fn f_up(&mut self, expr: Expr) -> Result> { if self.guarantees.is_empty() { return Ok(Transformed::no(expr)); } - match &expr { + let new_expr = match &expr { Expr::IsNull(inner) => match self.guarantees.get(inner.as_ref()) { - Some(NullableInterval::Null { .. }) => Ok(Transformed::yes(lit(true))), - Some(NullableInterval::NotNull { .. }) => { - Ok(Transformed::yes(lit(false))) - } - _ => Ok(Transformed::no(expr)), + Some(NullableInterval::Null { .. }) => Some(lit(true)), + Some(NullableInterval::NotNull { .. }) => Some(lit(false)), + _ => None, }, Expr::IsNotNull(inner) => match self.guarantees.get(inner.as_ref()) { - Some(NullableInterval::Null { .. }) => Ok(Transformed::yes(lit(false))), - Some(NullableInterval::NotNull { .. }) => Ok(Transformed::yes(lit(true))), - _ => Ok(Transformed::no(expr)), + Some(NullableInterval::Null { .. }) => Some(lit(false)), + Some(NullableInterval::NotNull { .. }) => Some(lit(true)), + _ => None, }, Expr::Between(Between { expr: inner, @@ -119,14 +101,14 @@ impl TreeNodeRewriter for GuaranteeRewriter<'_> { let contains = expr_interval.contains(*interval)?; if contains.is_certainly_true() { - Ok(Transformed::yes(lit(!negated))) + Some(lit(!negated)) } else if contains.is_certainly_false() { - Ok(Transformed::yes(lit(*negated))) + Some(lit(*negated)) } else { - Ok(Transformed::no(expr)) + None } } else { - Ok(Transformed::no(expr)) + None } } @@ -161,23 +143,23 @@ impl TreeNodeRewriter for GuaranteeRewriter<'_> { let result = left_interval.apply_operator(op, right_interval.as_ref())?; if result.is_certainly_true() { - Ok(Transformed::yes(lit(true))) + Some(lit(true)) } else if result.is_certainly_false() { - Ok(Transformed::yes(lit(false))) + Some(lit(false)) } else { - Ok(Transformed::no(expr)) + None } } - _ => Ok(Transformed::no(expr)), + _ => None, } } // Columns (if interval is collapsed to a single value) Expr::Column(_) => { if let Some(interval) = self.guarantees.get(&expr) { - Ok(Transformed::yes(interval.single_value().map_or(expr, lit))) + interval.single_value().map(lit) } else { - Ok(Transformed::no(expr)) + None } } @@ -207,16 +189,27 @@ impl TreeNodeRewriter for GuaranteeRewriter<'_> { }) .collect::>()?; - Ok(Transformed::yes(Expr::InList(InList { + Some(Expr::InList(InList { expr: inner.clone(), list: new_list, negated: *negated, - }))) + })) } else { - Ok(Transformed::no(expr)) + None } } + _ => None, + }; + + if let Some(e) = new_expr { + return Ok(Transformed::yes(e)); + } + + match self.guarantees.get(&expr) { + Some(NullableInterval::Null { datatype }) => { + Ok(Transformed::yes(lit(ScalarValue::try_new_null(datatype)?))) + } _ => Ok(Transformed::no(expr)), } } From 720d5c2bd4ede26d35771de0b7ae3ae71f1507f2 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Wed, 19 Nov 2025 09:12:41 +0100 Subject: [PATCH 4/7] Restructure GuaranteeRewriter for readability --- .../expr/src/expr_rewriter/guarantees.rs | 259 +++++++++--------- 1 file changed, 130 insertions(+), 129 deletions(-) diff --git a/datafusion/expr/src/expr_rewriter/guarantees.rs b/datafusion/expr/src/expr_rewriter/guarantees.rs index fd9cbcdc20ae..8190daeef99a 100644 --- a/datafusion/expr/src/expr_rewriter/guarantees.rs +++ b/datafusion/expr/src/expr_rewriter/guarantees.rs @@ -21,7 +21,7 @@ use std::borrow::Cow; use crate::{expr::InList, lit, Between, BinaryExpr, Expr}; use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; -use datafusion_common::{DataFusionError, HashMap, Result, ScalarValue}; +use datafusion_common::{DataFusionError, HashMap, Result}; use datafusion_expr_common::interval_arithmetic::{Interval, NullableInterval}; struct GuaranteeRewriter<'a> { @@ -83,135 +83,131 @@ impl TreeNodeRewriter for GuaranteeRewriter<'_> { Some(NullableInterval::NotNull { .. }) => Some(lit(true)), _ => None, }, - Expr::Between(Between { - expr: inner, - negated, - low, - high, - }) => { - if let (Some(interval), Expr::Literal(low, _), Expr::Literal(high, _)) = ( - self.guarantees.get(inner.as_ref()), - low.as_ref(), - high.as_ref(), - ) { - let expr_interval = NullableInterval::NotNull { - values: Interval::try_new(low.clone(), high.clone())?, - }; - - let contains = expr_interval.contains(*interval)?; - - if contains.is_certainly_true() { - Some(lit(!negated)) - } else if contains.is_certainly_false() { - Some(lit(*negated)) - } else { - None - } + Expr::Between(b) => self.rewrite_between(b)?, + Expr::BinaryExpr(b) => self.rewrite_binary_expr(&b)?, + Expr::InList(i) => self.rewrite_inlist(i)?, + _ => None, + }; + + if let Some(e) = new_expr { + return Ok(Transformed::yes(e)); + } + + match self.guarantees.get(&expr) { + Some(interval) => { + // If an expression collapses to a single value, replace it with a literal + if let Some(value) = interval.single_value() { + Ok(Transformed::yes(lit(value))) } else { - None + Ok(Transformed::no(expr)) } } + _ => Ok(Transformed::no(expr)), + } + } +} - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - // The left or right side of expression might either have a guarantee - // or be a literal. Either way, we can resolve them to a NullableInterval. - let left_interval = self - .guarantees - .get(left.as_ref()) - .map(|interval| Cow::Borrowed(*interval)) - .or_else(|| { - if let Expr::Literal(value, _) = left.as_ref() { - Some(Cow::Owned(value.clone().into())) - } else { - None - } - }); - let right_interval = self - .guarantees - .get(right.as_ref()) - .map(|interval| Cow::Borrowed(*interval)) - .or_else(|| { - if let Expr::Literal(value, _) = right.as_ref() { - Some(Cow::Owned(value.clone().into())) - } else { - None - } - }); - - match (left_interval, right_interval) { - (Some(left_interval), Some(right_interval)) => { - let result = - left_interval.apply_operator(op, right_interval.as_ref())?; - if result.is_certainly_true() { - Some(lit(true)) - } else if result.is_certainly_false() { - Some(lit(false)) - } else { - None - } - } - _ => None, - } - } +impl GuaranteeRewriter<'_> { + fn rewrite_between( + &mut self, + between: &Between, + ) -> Result, DataFusionError> { + let (Some(interval), Expr::Literal(low, _), Expr::Literal(high, _)) = ( + self.guarantees.get(between.expr.as_ref()), + between.low.as_ref(), + between.high.as_ref(), + ) else { + return Ok(None); + }; - // Columns (if interval is collapsed to a single value) - Expr::Column(_) => { - if let Some(interval) = self.guarantees.get(&expr) { - interval.single_value().map(lit) + let values = Interval::try_new(low.clone(), high.clone())?; + let expr_interval = NullableInterval::NotNull { values }; + + let contains = expr_interval.contains(*interval)?; + + if contains.is_certainly_true() { + Ok(Some(lit(!between.negated))) + } else if contains.is_certainly_false() { + Ok(Some(lit(between.negated))) + } else { + Ok(None) + } + } + + fn rewrite_binary_expr( + &mut self, + b: &&BinaryExpr, + ) -> Result, DataFusionError> { + // The left or right side of expression might either have a guarantee + // or be a literal. Either way, we can resolve them to a NullableInterval. + let left_interval = self + .guarantees + .get(b.left.as_ref()) + .map(|interval| Cow::Borrowed(*interval)) + .or_else(|| { + if let Expr::Literal(value, _) = b.left.as_ref() { + Some(Cow::Owned(value.clone().into())) } else { None } - } - - Expr::InList(InList { - expr: inner, - list, - negated, - }) => { - if let Some(interval) = self.guarantees.get(inner.as_ref()) { - // Can remove items from the list that don't match the guarantee - let new_list: Vec = list - .iter() - .filter_map(|expr| { - if let Expr::Literal(item, _) = expr { - match interval - .contains(NullableInterval::from(item.clone())) - { - // If we know for certain the value isn't in the column's interval, - // we can skip checking it. - Ok(interval) if interval.is_certainly_false() => None, - Ok(_) => Some(Ok(expr.clone())), - Err(e) => Some(Err(e)), - } - } else { - Some(Ok(expr.clone())) - } - }) - .collect::>()?; - - Some(Expr::InList(InList { - expr: inner.clone(), - list: new_list, - negated: *negated, - })) + }); + let right_interval = self + .guarantees + .get(b.right.as_ref()) + .map(|interval| Cow::Borrowed(*interval)) + .or_else(|| { + if let Expr::Literal(value, _) = b.right.as_ref() { + Some(Cow::Owned(value.clone().into())) + } else { + None + } + }); + + Ok(match (left_interval, right_interval) { + (Some(left_interval), Some(right_interval)) => { + let result = + left_interval.apply_operator(&b.op, right_interval.as_ref())?; + if result.is_certainly_true() { + Some(lit(true)) + } else if result.is_certainly_false() { + Some(lit(false)) } else { None } } - _ => None, - }; + }) + } - if let Some(e) = new_expr { - return Ok(Transformed::yes(e)); - } + fn rewrite_inlist(&mut self, i: &InList) -> Result, DataFusionError> { + let Some(interval) = self.guarantees.get(i.expr.as_ref()) else { + return Ok(None); + }; - match self.guarantees.get(&expr) { - Some(NullableInterval::Null { datatype }) => { - Ok(Transformed::yes(lit(ScalarValue::try_new_null(datatype)?))) - } - _ => Ok(Transformed::no(expr)), - } + // Can remove items from the list that don't match the guarantee + let new_list: Vec = i + .list + .iter() + .filter_map(|expr| { + if let Expr::Literal(item, _) = expr { + match interval.contains(NullableInterval::from(item.clone())) { + // If we know for certain the value isn't in the column's interval, + // we can skip checking it. + Ok(interval) if interval.is_certainly_false() => None, + Ok(_) => Some(Ok(expr.clone())), + Err(e) => Some(Err(e)), + } + } else { + Some(Ok(expr.clone())) + } + }) + .collect::>()?; + + Ok(Some(Expr::InList(InList { + expr: i.expr.clone(), + list: new_list, + negated: i.negated, + }))) } } @@ -225,7 +221,7 @@ mod tests { use datafusion_common::ScalarValue; #[test] - fn test_null_handling() { + fn test_not_null_guarantee() { // IsNull / IsNotNull can be rewritten to true / false let guarantees = [ // Note: AlwaysNull case handled by test_column_single_value test, @@ -233,24 +229,29 @@ mod tests { ( col("x"), NullableInterval::NotNull { - values: Interval::make_unbounded(&DataType::Boolean).unwrap(), + values: Interval::make_unbounded(&DataType::Int32).unwrap(), }, ), ]; // x IS NULL => guaranteed false - let expr = col("x").is_null(); - let output = rewrite_with_guarantees(expr, guarantees.iter()) - .data() - .unwrap(); - assert_eq!(output, lit(false)); - - // x IS NOT NULL => guaranteed true - let expr = col("x").is_not_null(); - let output = rewrite_with_guarantees(expr, guarantees.iter()) - .data() - .unwrap(); - assert_eq!(output, lit(true)); + let is_null_cases = vec![ + (col("x").is_null(), Some(lit(false))), + (col("x").is_not_null(), Some(lit(true))), + (col("x").between(lit(1), lit(2)), None), + ]; + + for case in is_null_cases { + let output = rewrite_with_guarantees(case.0.clone(), guarantees.iter()) + .data() + .unwrap(); + let expected = match case.1 { + None => case.0, + Some(expected) => expected, + }; + + assert_eq!(output, expected); + } } fn validate_simplified_cases( From a3ebf3328c1dcb7a6103764e81ccdfb03324be76 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Wed, 19 Nov 2025 09:13:23 +0100 Subject: [PATCH 5/7] Do not error out when rewriting 'between' expressions with empty value ranges --- datafusion/expr/src/expr_rewriter/guarantees.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/datafusion/expr/src/expr_rewriter/guarantees.rs b/datafusion/expr/src/expr_rewriter/guarantees.rs index 8190daeef99a..b9c5caeca518 100644 --- a/datafusion/expr/src/expr_rewriter/guarantees.rs +++ b/datafusion/expr/src/expr_rewriter/guarantees.rs @@ -120,7 +120,12 @@ impl GuaranteeRewriter<'_> { return Ok(None); }; - let values = Interval::try_new(low.clone(), high.clone())?; + let Ok(values) = Interval::try_new(low.clone(), high.clone()) else { + // If we can't create an interval from the literals, be conservative and simply leave + // the expression unmodified. + return Ok(None); + }; + let expr_interval = NullableInterval::NotNull { values }; let contains = expr_interval.contains(*interval)?; @@ -239,6 +244,7 @@ mod tests { (col("x").is_null(), Some(lit(false))), (col("x").is_not_null(), Some(lit(true))), (col("x").between(lit(1), lit(2)), None), + (col("x").between(lit(1), lit(-2)), None), ]; for case in is_null_cases { From 0a7962eca41e1b126ecb1c810b54559962bae2aa Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Wed, 19 Nov 2025 16:08:50 +0100 Subject: [PATCH 6/7] Further simplification --- .../expr/src/expr_rewriter/guarantees.rs | 317 ++++++++++-------- datafusion/expr/src/expr_rewriter/mod.rs | 1 + .../optimizer/src/simplify_expressions/mod.rs | 3 + 3 files changed, 182 insertions(+), 139 deletions(-) diff --git a/datafusion/expr/src/expr_rewriter/guarantees.rs b/datafusion/expr/src/expr_rewriter/guarantees.rs index b9c5caeca518..3a8dbbd36e2b 100644 --- a/datafusion/expr/src/expr_rewriter/guarantees.rs +++ b/datafusion/expr/src/expr_rewriter/guarantees.rs @@ -21,24 +21,41 @@ use std::borrow::Cow; use crate::{expr::InList, lit, Between, BinaryExpr, Expr}; use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; -use datafusion_common::{DataFusionError, HashMap, Result}; +use datafusion_common::{DataFusionError, HashMap, Result, ScalarValue}; use datafusion_expr_common::interval_arithmetic::{Interval, NullableInterval}; -struct GuaranteeRewriter<'a> { - guarantees: &'a HashMap<&'a Expr, &'a NullableInterval>, +/// Rewrite expressions to incorporate guarantees. +pub struct GuaranteeRewriter<'a> { + guarantees: HashMap<&'a Expr, &'a NullableInterval>, +} + +impl<'a> GuaranteeRewriter<'a> { + pub fn new( + guarantees: impl IntoIterator, + ) -> Self { + Self { + guarantees: guarantees.into_iter().map(|(k, v)| (k, v)).collect(), + } + } } /// Rewrite expressions to incorporate guarantees. /// /// Guarantees are a mapping from an expression (which currently is always a -/// column reference) to a [NullableInterval]. The interval represents the known -/// possible values of the column. +/// column reference) to a [NullableInterval] that represents the known possible +/// values of the expression. +/// +/// Rewriting expressions using this type of guarantee can make the work of other expression +/// simplifications, like const evaluation, easier. /// /// For example, if we know that a column is not null and has values in the /// range [1, 10), we can rewrite `x IS NULL` to `false` or `x < 10` to `true`. /// -/// If the set of guarantees will be used to rewrite multiple expressions consider using +/// If the set of guarantees will be used to rewrite more than one expression, consider using /// [rewrite_with_guarantees_map] instead. +/// +/// A full example of using this rewrite rule can be found in +/// [`ExprSimplifier::with_guarantees()`](https://docs.rs/datafusion/latest/datafusion/optimizer/simplify_expressions/struct.ExprSimplifier.html#method.with_guarantees). pub fn rewrite_with_guarantees<'a>( expr: Expr, guarantees: impl IntoIterator, @@ -60,160 +77,182 @@ pub fn rewrite_with_guarantees_map<'a>( expr: Expr, guarantees: &'a HashMap<&'a Expr, &'a NullableInterval>, ) -> Result> { - let mut rewriter = GuaranteeRewriter { guarantees }; - expr.rewrite(&mut rewriter) + expr.transform_up(|e| rewrite_expr(e, guarantees)) } impl TreeNodeRewriter for GuaranteeRewriter<'_> { type Node = Expr; fn f_up(&mut self, expr: Expr) -> Result> { - if self.guarantees.is_empty() { - return Ok(Transformed::no(expr)); - } + rewrite_expr(expr, &self.guarantees) + } +} - let new_expr = match &expr { - Expr::IsNull(inner) => match self.guarantees.get(inner.as_ref()) { - Some(NullableInterval::Null { .. }) => Some(lit(true)), - Some(NullableInterval::NotNull { .. }) => Some(lit(false)), - _ => None, - }, - Expr::IsNotNull(inner) => match self.guarantees.get(inner.as_ref()) { - Some(NullableInterval::Null { .. }) => Some(lit(false)), - Some(NullableInterval::NotNull { .. }) => Some(lit(true)), - _ => None, - }, - Expr::Between(b) => self.rewrite_between(b)?, - Expr::BinaryExpr(b) => self.rewrite_binary_expr(&b)?, - Expr::InList(i) => self.rewrite_inlist(i)?, - _ => None, - }; +fn rewrite_expr( + expr: Expr, + guarantees: &HashMap<&Expr, &NullableInterval>, +) -> Result> { + if guarantees.is_empty() { + return Ok(Transformed::no(expr)); + } - if let Some(e) = new_expr { - return Ok(Transformed::yes(e)); - } + let new_expr = match &expr { + Expr::IsNull(inner) => match guarantees.get(inner.as_ref()) { + Some(NullableInterval::Null { .. }) => Some(lit(true)), + Some(NullableInterval::NotNull { .. }) => Some(lit(false)), + _ => None, + }, + Expr::IsNotNull(inner) => match guarantees.get(inner.as_ref()) { + Some(NullableInterval::Null { .. }) => Some(lit(false)), + Some(NullableInterval::NotNull { .. }) => Some(lit(true)), + _ => None, + }, + Expr::Between(b) => rewrite_between(b, guarantees)?, + Expr::BinaryExpr(b) => rewrite_binary_expr(b, guarantees)?, + Expr::InList(i) => rewrite_inlist(i, guarantees)?, + _ => None, + }; + + if let Some(e) = new_expr { + return Ok(Transformed::yes(e)); + } - match self.guarantees.get(&expr) { - Some(interval) => { - // If an expression collapses to a single value, replace it with a literal - if let Some(value) = interval.single_value() { - Ok(Transformed::yes(lit(value))) - } else { - Ok(Transformed::no(expr)) - } + match guarantees.get(&expr) { + Some(interval) => { + // If an expression collapses to a single value, replace it with a literal + if let Some(value) = interval.single_value() { + Ok(Transformed::yes(lit(value))) + } else { + Ok(Transformed::no(expr)) } - _ => Ok(Transformed::no(expr)), } + _ => Ok(Transformed::no(expr)), + } +} + +fn rewrite_between( + between: &Between, + guarantees: &HashMap<&Expr, &NullableInterval>, +) -> Result, DataFusionError> { + let (Some(interval), Expr::Literal(low, _), Expr::Literal(high, _)) = ( + guarantees.get(between.expr.as_ref()), + between.low.as_ref(), + between.high.as_ref(), + ) else { + return Ok(None); + }; + + // Ensure that, if low or high are null, their type matches the other bound + let low = ensure_typed_null(low, high)?; + let high = ensure_typed_null(high, &low)?; + + let Ok(values) = Interval::try_new(low, high) else { + // If we can't create an interval from the literals, be conservative and simply leave + // the expression unmodified. + return Ok(None); + }; + + let expr_interval = NullableInterval::NotNull { values }; + + let contains = expr_interval.contains(*interval)?; + + if contains.is_certainly_true() { + Ok(Some(lit(!between.negated))) + } else if contains.is_certainly_false() { + Ok(Some(lit(between.negated))) + } else { + Ok(None) } } -impl GuaranteeRewriter<'_> { - fn rewrite_between( - &mut self, - between: &Between, - ) -> Result, DataFusionError> { - let (Some(interval), Expr::Literal(low, _), Expr::Literal(high, _)) = ( - self.guarantees.get(between.expr.as_ref()), - between.low.as_ref(), - between.high.as_ref(), - ) else { - return Ok(None); - }; - - let Ok(values) = Interval::try_new(low.clone(), high.clone()) else { - // If we can't create an interval from the literals, be conservative and simply leave - // the expression unmodified. - return Ok(None); - }; - - let expr_interval = NullableInterval::NotNull { values }; - - let contains = expr_interval.contains(*interval)?; - - if contains.is_certainly_true() { - Ok(Some(lit(!between.negated))) - } else if contains.is_certainly_false() { - Ok(Some(lit(between.negated))) +fn ensure_typed_null( + value: &ScalarValue, + other: &ScalarValue, +) -> Result { + Ok( + if value.data_type().is_null() && !other.data_type().is_null() { + ScalarValue::try_new_null(&other.data_type())? } else { - Ok(None) + value.clone() + }, + ) +} + +fn rewrite_binary_expr( + binary: &BinaryExpr, + guarantees: &HashMap<&Expr, &NullableInterval>, +) -> Result, DataFusionError> { + // The left or right side of expression might either have a guarantee + // or be a literal. Either way, we can resolve them to a NullableInterval. + let left_interval = guarantees + .get(binary.left.as_ref()) + .map(|interval| Cow::Borrowed(*interval)) + .or_else(|| { + if let Expr::Literal(value, _) = binary.left.as_ref() { + Some(Cow::Owned(value.clone().into())) + } else { + None + } + }); + let right_interval = guarantees + .get(binary.right.as_ref()) + .map(|interval| Cow::Borrowed(*interval)) + .or_else(|| { + if let Expr::Literal(value, _) = binary.right.as_ref() { + Some(Cow::Owned(value.clone().into())) + } else { + None + } + }); + + Ok(match (left_interval, right_interval) { + (Some(left_interval), Some(right_interval)) => { + let result = + left_interval.apply_operator(&binary.op, right_interval.as_ref())?; + if result.is_certainly_true() { + Some(lit(true)) + } else if result.is_certainly_false() { + Some(lit(false)) + } else { + None + } } - } + _ => None, + }) +} - fn rewrite_binary_expr( - &mut self, - b: &&BinaryExpr, - ) -> Result, DataFusionError> { - // The left or right side of expression might either have a guarantee - // or be a literal. Either way, we can resolve them to a NullableInterval. - let left_interval = self - .guarantees - .get(b.left.as_ref()) - .map(|interval| Cow::Borrowed(*interval)) - .or_else(|| { - if let Expr::Literal(value, _) = b.left.as_ref() { - Some(Cow::Owned(value.clone().into())) - } else { - None - } - }); - let right_interval = self - .guarantees - .get(b.right.as_ref()) - .map(|interval| Cow::Borrowed(*interval)) - .or_else(|| { - if let Expr::Literal(value, _) = b.right.as_ref() { - Some(Cow::Owned(value.clone().into())) - } else { - None - } - }); - - Ok(match (left_interval, right_interval) { - (Some(left_interval), Some(right_interval)) => { - let result = - left_interval.apply_operator(&b.op, right_interval.as_ref())?; - if result.is_certainly_true() { - Some(lit(true)) - } else if result.is_certainly_false() { - Some(lit(false)) - } else { - None +fn rewrite_inlist( + inlist: &InList, + guarantees: &HashMap<&Expr, &NullableInterval>, +) -> Result, DataFusionError> { + let Some(interval) = guarantees.get(inlist.expr.as_ref()) else { + return Ok(None); + }; + + // Can remove items from the list that don't match the guarantee + let new_list: Vec = inlist + .list + .iter() + .filter_map(|expr| { + if let Expr::Literal(item, _) = expr { + match interval.contains(NullableInterval::from(item.clone())) { + // If we know for certain the value isn't in the column's interval, + // we can skip checking it. + Ok(interval) if interval.is_certainly_false() => None, + Ok(_) => Some(Ok(expr.clone())), + Err(e) => Some(Err(e)), } + } else { + Some(Ok(expr.clone())) } - _ => None, }) - } + .collect::>()?; - fn rewrite_inlist(&mut self, i: &InList) -> Result, DataFusionError> { - let Some(interval) = self.guarantees.get(i.expr.as_ref()) else { - return Ok(None); - }; - - // Can remove items from the list that don't match the guarantee - let new_list: Vec = i - .list - .iter() - .filter_map(|expr| { - if let Expr::Literal(item, _) = expr { - match interval.contains(NullableInterval::from(item.clone())) { - // If we know for certain the value isn't in the column's interval, - // we can skip checking it. - Ok(interval) if interval.is_certainly_false() => None, - Ok(_) => Some(Ok(expr.clone())), - Err(e) => Some(Err(e)), - } - } else { - Some(Ok(expr.clone())) - } - }) - .collect::>()?; - - Ok(Some(Expr::InList(InList { - expr: i.expr.clone(), - list: new_list, - negated: i.negated, - }))) - } + Ok(Some(Expr::InList(InList { + expr: inlist.expr.clone(), + list: new_list, + negated: inlist.negated, + }))) } #[cfg(test)] diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index b084be59b1c7..31759f1cc9cf 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -34,6 +34,7 @@ use datafusion_common::{Column, DFSchema, Result}; mod guarantees; pub use guarantees::rewrite_with_guarantees; pub use guarantees::rewrite_with_guarantees_map; +pub use guarantees::GuaranteeRewriter; mod order_by; pub use order_by::rewrite_sort_cols_by_aggs; diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs b/datafusion/optimizer/src/simplify_expressions/mod.rs index 52a3be3652c8..e238fca32689 100644 --- a/datafusion/optimizer/src/simplify_expressions/mod.rs +++ b/datafusion/optimizer/src/simplify_expressions/mod.rs @@ -32,3 +32,6 @@ pub use datafusion_expr::simplify::{SimplifyContext, SimplifyInfo}; pub use expr_simplifier::*; pub use simplify_exprs::*; pub use simplify_predicates::simplify_predicates; + +// Export for test in datafusion/core/tests/optimizer_integration.rs +pub use datafusion_expr::expr_rewriter::GuaranteeRewriter; From b53d2cfca9f9136da70a39f1896f4885253dbd92 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 19 Nov 2025 10:58:39 -0500 Subject: [PATCH 7/7] reduce cloning in GuaranteeRewriter --- .../expr/src/expr_rewriter/guarantees.rs | 110 ++++++++---------- 1 file changed, 51 insertions(+), 59 deletions(-) diff --git a/datafusion/expr/src/expr_rewriter/guarantees.rs b/datafusion/expr/src/expr_rewriter/guarantees.rs index 3a8dbbd36e2b..e2b136081bd1 100644 --- a/datafusion/expr/src/expr_rewriter/guarantees.rs +++ b/datafusion/expr/src/expr_rewriter/guarantees.rs @@ -96,50 +96,42 @@ fn rewrite_expr( return Ok(Transformed::no(expr)); } - let new_expr = match &expr { + // If an expression collapses to a single value, replace it with a literal + if let Some(interval) = guarantees.get(&expr) { + if let Some(value) = interval.single_value() { + return Ok(Transformed::yes(lit(value))); + } + } + + let result = match expr { Expr::IsNull(inner) => match guarantees.get(inner.as_ref()) { - Some(NullableInterval::Null { .. }) => Some(lit(true)), - Some(NullableInterval::NotNull { .. }) => Some(lit(false)), - _ => None, + Some(NullableInterval::Null { .. }) => Transformed::yes(lit(true)), + Some(NullableInterval::NotNull { .. }) => Transformed::yes(lit(false)), + _ => Transformed::no(Expr::IsNull(inner)), }, Expr::IsNotNull(inner) => match guarantees.get(inner.as_ref()) { - Some(NullableInterval::Null { .. }) => Some(lit(false)), - Some(NullableInterval::NotNull { .. }) => Some(lit(true)), - _ => None, + Some(NullableInterval::Null { .. }) => Transformed::yes(lit(false)), + Some(NullableInterval::NotNull { .. }) => Transformed::yes(lit(true)), + _ => Transformed::no(Expr::IsNotNull(inner)), }, Expr::Between(b) => rewrite_between(b, guarantees)?, Expr::BinaryExpr(b) => rewrite_binary_expr(b, guarantees)?, Expr::InList(i) => rewrite_inlist(i, guarantees)?, - _ => None, + expr => Transformed::no(expr), }; - - if let Some(e) = new_expr { - return Ok(Transformed::yes(e)); - } - - match guarantees.get(&expr) { - Some(interval) => { - // If an expression collapses to a single value, replace it with a literal - if let Some(value) = interval.single_value() { - Ok(Transformed::yes(lit(value))) - } else { - Ok(Transformed::no(expr)) - } - } - _ => Ok(Transformed::no(expr)), - } + Ok(result) } fn rewrite_between( - between: &Between, + between: Between, guarantees: &HashMap<&Expr, &NullableInterval>, -) -> Result, DataFusionError> { +) -> Result> { let (Some(interval), Expr::Literal(low, _), Expr::Literal(high, _)) = ( guarantees.get(between.expr.as_ref()), between.low.as_ref(), between.high.as_ref(), ) else { - return Ok(None); + return Ok(Transformed::no(Expr::Between(between))); }; // Ensure that, if low or high are null, their type matches the other bound @@ -149,7 +141,7 @@ fn rewrite_between( let Ok(values) = Interval::try_new(low, high) else { // If we can't create an interval from the literals, be conservative and simply leave // the expression unmodified. - return Ok(None); + return Ok(Transformed::no(Expr::Between(between))); }; let expr_interval = NullableInterval::NotNull { values }; @@ -157,11 +149,11 @@ fn rewrite_between( let contains = expr_interval.contains(*interval)?; if contains.is_certainly_true() { - Ok(Some(lit(!between.negated))) + Ok(Transformed::yes(lit(!between.negated))) } else if contains.is_certainly_false() { - Ok(Some(lit(between.negated))) + Ok(Transformed::yes(lit(between.negated))) } else { - Ok(None) + Ok(Transformed::no(Expr::Between(between))) } } @@ -179,9 +171,9 @@ fn ensure_typed_null( } fn rewrite_binary_expr( - binary: &BinaryExpr, + binary: BinaryExpr, guarantees: &HashMap<&Expr, &NullableInterval>, -) -> Result, DataFusionError> { +) -> Result, DataFusionError> { // The left or right side of expression might either have a guarantee // or be a literal. Either way, we can resolve them to a NullableInterval. let left_interval = guarantees @@ -205,53 +197,53 @@ fn rewrite_binary_expr( } }); - Ok(match (left_interval, right_interval) { - (Some(left_interval), Some(right_interval)) => { - let result = - left_interval.apply_operator(&binary.op, right_interval.as_ref())?; - if result.is_certainly_true() { - Some(lit(true)) - } else if result.is_certainly_false() { - Some(lit(false)) - } else { - None - } + if let (Some(left_interval), Some(right_interval)) = (left_interval, right_interval) { + let result = left_interval.apply_operator(&binary.op, right_interval.as_ref())?; + if result.is_certainly_true() { + return Ok(Transformed::yes(lit(true))); + } else if result.is_certainly_false() { + return Ok(Transformed::yes(lit(false))); } - _ => None, - }) + } + Ok(Transformed::no(Expr::BinaryExpr(binary))) } fn rewrite_inlist( - inlist: &InList, + inlist: InList, guarantees: &HashMap<&Expr, &NullableInterval>, -) -> Result, DataFusionError> { +) -> Result, DataFusionError> { let Some(interval) = guarantees.get(inlist.expr.as_ref()) else { - return Ok(None); + return Ok(Transformed::no(Expr::InList(inlist))); }; + let InList { + expr, + list, + negated, + } = inlist; + // Can remove items from the list that don't match the guarantee - let new_list: Vec = inlist - .list - .iter() + let list: Vec = list + .into_iter() .filter_map(|expr| { - if let Expr::Literal(item, _) = expr { + if let Expr::Literal(item, _) = &expr { match interval.contains(NullableInterval::from(item.clone())) { // If we know for certain the value isn't in the column's interval, // we can skip checking it. Ok(interval) if interval.is_certainly_false() => None, - Ok(_) => Some(Ok(expr.clone())), + Ok(_) => Some(Ok(expr)), Err(e) => Some(Err(e)), } } else { - Some(Ok(expr.clone())) + Some(Ok(expr)) } }) .collect::>()?; - Ok(Some(Expr::InList(InList { - expr: inlist.expr.clone(), - list: new_list, - negated: inlist.negated, + Ok(Transformed::yes(Expr::InList(InList { + expr, + list, + negated, }))) }