diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 9ba866a4c919..0de4a87b941b 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -33,7 +33,9 @@ use crate::{ use crate::{window_frame, Volatility}; use arrow::datatypes::{DataType, FieldRef}; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; use datafusion_common::{ internal_err, plan_err, Column, DFSchema, Result, ScalarValue, TableReference, }; @@ -1333,6 +1335,46 @@ impl Expr { Ok(using_columns) } + /// Return all references to columns in this expression. + /// + /// # Example + /// ``` + /// # use std::collections::HashSet; + /// # use datafusion_common::Column; + /// # use datafusion_expr::col; + /// // For an expression `a + (b * a)` + /// let expr = col("a") + (col("b") * col("a")); + /// let refs = expr.column_refs(); + /// // refs contains "a" and "b" + /// assert_eq!(refs.len(), 2); + /// assert!(refs.contains(&Column::new_unqualified("a"))); + /// assert!(refs.contains(&Column::new_unqualified("b"))); + /// ``` + pub fn column_refs(&self) -> HashSet<&Column> { + let mut using_columns = HashSet::new(); + self.add_column_refs(&mut using_columns); + using_columns + } + + /// Adds references to all columns in this expression to the set + /// + /// See [`Self::column_refs`] for details + pub fn add_column_refs<'a>(&'a self, set: &mut HashSet<&'a Column>) { + self.apply(|expr| { + if let Expr::Column(col) = expr { + set.insert(col); + } + Ok(TreeNodeRecursion::Continue) + }) + .expect("traversal is infallable"); + } + + /// Returns true if there are any column references in this Expr + pub fn any_column_refs(&self) -> bool { + self.exists(|expr| Ok(matches!(expr, Expr::Column(_)))) + .unwrap() + } + /// Return true when the expression contains out reference(correlated) expressions. pub fn contains_outer(&self) -> bool { self.exists(|expr| Ok(matches!(expr, Expr::OuterReferenceColumn { .. }))) @@ -2038,7 +2080,7 @@ mod test { // single column { let expr = &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)); - let columns = expr.to_columns()?; + let columns = expr.column_refs(); assert_eq!(1, columns.len()); assert!(columns.contains(&Column::from_name("a"))); } @@ -2046,7 +2088,7 @@ mod test { // multiple columns { let expr = col("a") + col("b") + lit(1); - let columns = expr.to_columns()?; + let columns = expr.column_refs(); assert_eq!(2, columns.len()); assert!(columns.contains(&Column::from_name("a"))); assert!(columns.contains(&Column::from_name("b"))); diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 3ab0c180dcba..6baabfcc7130 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -46,6 +46,7 @@ pub const COUNT_STAR_EXPANSION: ScalarValue = ScalarValue::Int64(Some(1)); /// Recursively walk a list of expression trees, collecting the unique set of columns /// referenced in the expression +#[deprecated(since = "40.0.0", note = "Expr::add_column_refs instead")] pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet) -> Result<()> { for e in expr { expr_to_columns(e, accum)?; diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index 444ee94c4292..7d1290204eb7 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -300,19 +300,17 @@ fn can_pullup_over_aggregation(expr: &Expr) -> bool { }) = expr { match (left.deref(), right.deref()) { - (Expr::Column(_), right) if right.to_columns().unwrap().is_empty() => true, - (left, Expr::Column(_)) if left.to_columns().unwrap().is_empty() => true, + (Expr::Column(_), right) => !right.any_column_refs(), + (left, Expr::Column(_)) => !left.any_column_refs(), (Expr::Cast(Cast { expr, .. }), right) - if matches!(expr.deref(), Expr::Column(_)) - && right.to_columns().unwrap().is_empty() => + if matches!(expr.deref(), Expr::Column(_)) => { - true + !right.any_column_refs() } (left, Expr::Cast(Cast { expr, .. })) - if matches!(expr.deref(), Expr::Column(_)) - && left.to_columns().unwrap().is_empty() => + if matches!(expr.deref(), Expr::Column(_)) => { - true + !left.any_column_refs() } (_, _) => false, } @@ -323,9 +321,10 @@ fn can_pullup_over_aggregation(expr: &Expr) -> bool { /// Check whether the window expressions contain a mixture of out reference columns and inner columns fn check_mixed_out_refer_in_window(window: &Window) -> Result<()> { - let mixed = window.window_expr.iter().any(|win_expr| { - win_expr.contains_outer() && !win_expr.to_columns().unwrap().is_empty() - }); + let mixed = window + .window_expr + .iter() + .any(|win_expr| win_expr.contains_outer() && win_expr.any_column_refs()); if mixed { plan_err!( "Window expressions should not contain a mixed of outer references and inner columns" diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index 7806a622ac0f..5f8e0a85215a 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -370,11 +370,14 @@ impl PullUpCorrelatedExpr { } } if let Some(pull_up_having) = &self.pull_up_having_expr { - let filter_apply_columns = pull_up_having.to_columns()?; + let filter_apply_columns = pull_up_having.column_refs(); for col in filter_apply_columns { - let col_expr = Expr::Column(col); - if !missing_exprs.contains(&col_expr) { - missing_exprs.push(col_expr) + // add to missing_exprs if not already there + let contains = missing_exprs + .iter() + .any(|expr| matches!(expr, Expr::Column(c) if c == col)); + if !contains { + missing_exprs.push(Expr::Column(col.clone())) } } } diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 5749469f2ddc..50b2b1efad40 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -479,10 +479,10 @@ fn merge_consecutive_projections(proj: Projection) -> Result::new(); - for columns in expr.iter().flat_map(|expr| expr.to_columns()) { + let mut column_referral_map = HashMap::<&Column, usize>::new(); + for columns in expr.iter().map(|expr| expr.column_refs()) { for col in columns.into_iter() { - *column_referral_map.entry(col.clone()).or_default() += 1; + *column_referral_map.entry(col).or_default() += 1; } } @@ -493,7 +493,7 @@ fn merge_consecutive_projections(proj: Projection) -> Result 1 && !is_expr_trivial( &prev_projection.expr - [prev_projection.schema.index_of_column(&col).unwrap()], + [prev_projection.schema.index_of_column(col).unwrap()], ) }) { // no change @@ -625,12 +625,12 @@ fn rewrite_expr(expr: Expr, input: &Projection) -> Result> { /// * `expr` - The expression to analyze for outer-referenced columns. /// * `columns` - A mutable reference to a `HashSet` where detected /// columns are collected. -fn outer_columns(expr: &Expr, columns: &mut HashSet) { +fn outer_columns<'a>(expr: &'a Expr, columns: &mut HashSet<&'a Column>) { // inspect_expr_pre doesn't handle subquery references, so find them explicitly expr.apply(|expr| { match expr { Expr::OuterReferenceColumn(_, col) => { - columns.insert(col.clone()); + columns.insert(col); } Expr::ScalarSubquery(subquery) => { outer_columns_helper_multi(&subquery.outer_ref_columns, columns); @@ -660,9 +660,9 @@ fn outer_columns(expr: &Expr, columns: &mut HashSet) { /// * `exprs` - The expressions to analyze for outer-referenced columns. /// * `columns` - A mutable reference to a `HashSet` where detected /// columns are collected. -fn outer_columns_helper_multi<'a>( +fn outer_columns_helper_multi<'a, 'b>( exprs: impl IntoIterator, - columns: &mut HashSet, + columns: &'b mut HashSet<&'a Column>, ) { exprs.into_iter().for_each(|e| outer_columns(e, columns)); } diff --git a/datafusion/optimizer/src/optimize_projections/required_indices.rs b/datafusion/optimizer/src/optimize_projections/required_indices.rs index 113c100bbd9b..3f32a0c36a9a 100644 --- a/datafusion/optimizer/src/optimize_projections/required_indices.rs +++ b/datafusion/optimizer/src/optimize_projections/required_indices.rs @@ -113,12 +113,12 @@ impl RequiredIndicies { /// * `expr`: An expression for which we want to find necessary field indices. fn add_expr(&mut self, input_schema: &DFSchemaRef, expr: &Expr) -> Result<()> { // TODO could remove these clones (and visit the expression directly) - let mut cols = expr.to_columns()?; + let mut cols = expr.column_refs(); // Get outer-referenced (subquery) columns: outer_columns(expr, &mut cols); self.indices.reserve(cols.len()); for col in cols { - if let Some(idx) = input_schema.maybe_index_of_column(&col) { + if let Some(idx) = input_schema.maybe_index_of_column(col) { self.indices.push(idx); } } diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 68339a84649d..89bcd6085cca 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -561,12 +561,9 @@ fn infer_join_predicates( .filter_map(|predicate| { let mut join_cols_to_replace = HashMap::new(); - let columns = match predicate.to_columns() { - Ok(columns) => columns, - Err(e) => return Some(Err(e)), - }; + let columns = predicate.column_refs(); - for col in columns.iter() { + for &col in columns.iter() { for (l, r) in join_col_keys.iter() { if col == *l { join_cols_to_replace.insert(col, *r); @@ -798,7 +795,7 @@ impl OptimizerRule for PushDownFilter { let mut keep_predicates = vec![]; let mut push_predicates = vec![]; for expr in predicates { - let cols = expr.to_columns()?; + let cols = expr.column_refs(); if cols.iter().all(|c| group_expr_columns.contains(c)) { push_predicates.push(expr); } else { @@ -899,7 +896,7 @@ impl OptimizerRule for PushDownFilter { let predicate_push_or_keep = split_conjunction(&filter.predicate) .iter() .map(|expr| { - let cols = expr.to_columns()?; + let cols = expr.column_refs(); if cols.iter().any(|c| prevent_cols.contains(&c.name)) { Ok(false) // No push (keep) } else { diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 6218140409b5..00aaff196c3b 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -72,9 +72,9 @@ pub(crate) fn collect_subquery_cols( ) -> Result> { exprs.iter().try_fold(BTreeSet::new(), |mut cols, expr| { let mut using_cols: Vec = vec![]; - for col in expr.to_columns()?.into_iter() { - if subquery_schema.has_column(&col) { - using_cols.push(col); + for col in expr.column_refs().into_iter() { + if subquery_schema.has_column(col) { + using_cols.push(col.clone()); } } diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index d10956efb66c..cb492b390c76 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -964,7 +964,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { self.order_by_to_sort_expr(&expr, schema, planner_context, true, None)?; // Verify that columns of all SortExprs exist in the schema: for expr in expr_vec.iter() { - for column in expr.to_columns()?.iter() { + for column in expr.column_refs().iter() { if !schema.has_column(column) { // Return an error if any column is not in the schema: return plan_err!("Column {column} is not in schema");