Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 45 additions & 3 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ use crate::{
use crate::{window_frame, Volatility};

use arrow::datatypes::{DataType, FieldRef};
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
};
use datafusion_common::{
internal_err, plan_err, Column, DFSchema, Result, ScalarValue, TableReference,
};
Expand Down Expand Up @@ -1333,6 +1335,46 @@ impl Expr {
Ok(using_columns)
}

/// Return all references to columns in this expression.
///
/// # Example
/// ```
/// # use std::collections::HashSet;
/// # use datafusion_common::Column;
/// # use datafusion_expr::col;
/// // For an expression `a + (b * a)`
/// let expr = col("a") + (col("b") * col("a"));
/// let refs = expr.column_refs();
/// // refs contains "a" and "b"
/// assert_eq!(refs.len(), 2);
/// assert!(refs.contains(&Column::new_unqualified("a")));
/// assert!(refs.contains(&Column::new_unqualified("b")));
/// ```
pub fn column_refs(&self) -> HashSet<&Column> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The implementation of this function is quite nice now compared to the expr_to_columns one: https://github.com/alamb/datafusion/blob/58d0c34d77c9a5202e62b9281cdbf1046abaa096/datafusion/expr/src/utils.rs#L264-L309

let mut using_columns = HashSet::new();
self.add_column_refs(&mut using_columns);
using_columns
}

/// Adds a reference to every column used in this expression to `set`.
///
/// The expression tree is walked with `apply`; each `Expr::Column` node
/// encountered has a borrow of its `Column` inserted into `set`. No columns
/// are cloned: the inserted references borrow from `self` (lifetime `'a`).
///
/// See [`Self::column_refs`] for details
pub fn add_column_refs<'a>(&'a self, set: &mut HashSet<&'a Column>) {
    self.apply(|expr| {
        if let Expr::Column(col) = expr {
            set.insert(col);
        }
        // Always keep walking; the closure never returns an error.
        Ok(TreeNodeRecursion::Continue)
    })
    .expect("traversal is infallible");
}

/// Returns true if there are any column references in this Expr
///
/// Unlike [`Self::column_refs`], this does not build a set of the referenced
/// columns; it only reports whether some node in the expression tree is an
/// `Expr::Column`.
pub fn any_column_refs(&self) -> bool {
    self.exists(|expr| Ok(matches!(expr, Expr::Column(_))))
        // The predicate above only ever returns `Ok`.
        .expect("exists closure is infallible")
}

/// Return true when the expression contains outer reference (correlated) expressions.
pub fn contains_outer(&self) -> bool {
self.exists(|expr| Ok(matches!(expr, Expr::OuterReferenceColumn { .. })))
Expand Down Expand Up @@ -2038,15 +2080,15 @@ mod test {
// single column
{
let expr = &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64));
let columns = expr.to_columns()?;
let columns = expr.column_refs();
assert_eq!(1, columns.len());
assert!(columns.contains(&Column::from_name("a")));
}

// multiple columns
{
let expr = col("a") + col("b") + lit(1);
let columns = expr.to_columns()?;
let columns = expr.column_refs();
assert_eq!(2, columns.len());
assert!(columns.contains(&Column::from_name("a")));
assert!(columns.contains(&Column::from_name("b")));
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ pub const COUNT_STAR_EXPANSION: ScalarValue = ScalarValue::Int64(Some(1));

/// Recursively walk a list of expression trees, collecting the unique set of columns
/// referenced in the expression
#[deprecated(since = "40.0.0", note = "Expr::add_column_refs instead")]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this function is not used anywhere in the datafusion codebase

pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet<Column>) -> Result<()> {
for e in expr {
expr_to_columns(e, accum)?;
Expand Down
21 changes: 10 additions & 11 deletions datafusion/optimizer/src/analyzer/subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,19 +300,17 @@ fn can_pullup_over_aggregation(expr: &Expr) -> bool {
}) = expr
{
match (left.deref(), right.deref()) {
(Expr::Column(_), right) if right.to_columns().unwrap().is_empty() => true,
(left, Expr::Column(_)) if left.to_columns().unwrap().is_empty() => true,
(Expr::Column(_), right) => !right.any_column_refs(),
(left, Expr::Column(_)) => !left.any_column_refs(),
(Expr::Cast(Cast { expr, .. }), right)
if matches!(expr.deref(), Expr::Column(_))
&& right.to_columns().unwrap().is_empty() =>
if matches!(expr.deref(), Expr::Column(_)) =>
{
true
!right.any_column_refs()
}
(left, Expr::Cast(Cast { expr, .. }))
if matches!(expr.deref(), Expr::Column(_))
&& left.to_columns().unwrap().is_empty() =>
if matches!(expr.deref(), Expr::Column(_)) =>
{
true
!left.any_column_refs()
}
(_, _) => false,
}
Expand All @@ -323,9 +321,10 @@ fn can_pullup_over_aggregation(expr: &Expr) -> bool {

/// Check whether the window expressions contain a mixture of outer reference columns and inner columns
fn check_mixed_out_refer_in_window(window: &Window) -> Result<()> {
let mixed = window.window_expr.iter().any(|win_expr| {
win_expr.contains_outer() && !win_expr.to_columns().unwrap().is_empty()
});
let mixed = window
.window_expr
.iter()
.any(|win_expr| win_expr.contains_outer() && win_expr.any_column_refs());
if mixed {
plan_err!(
"Window expressions should not contain a mixed of outer references and inner columns"
Expand Down
11 changes: 7 additions & 4 deletions datafusion/optimizer/src/decorrelate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,11 +370,14 @@ impl PullUpCorrelatedExpr {
}
}
if let Some(pull_up_having) = &self.pull_up_having_expr {
let filter_apply_columns = pull_up_having.to_columns()?;
let filter_apply_columns = pull_up_having.column_refs();
for col in filter_apply_columns {
let col_expr = Expr::Column(col);
if !missing_exprs.contains(&col_expr) {
missing_exprs.push(col_expr)
// add to missing_exprs if not already there
let contains = missing_exprs
.iter()
.any(|expr| matches!(expr, Expr::Column(c) if c == col));
if !contains {
missing_exprs.push(Expr::Column(col.clone()))
}
}
}
Expand Down
16 changes: 8 additions & 8 deletions datafusion/optimizer/src/optimize_projections/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -479,10 +479,10 @@ fn merge_consecutive_projections(proj: Projection) -> Result<Transformed<Project
};

// Count usages (referrals) of each projection expression in its input fields:
let mut column_referral_map = HashMap::<Column, usize>::new();
for columns in expr.iter().flat_map(|expr| expr.to_columns()) {
let mut column_referral_map = HashMap::<&Column, usize>::new();
for columns in expr.iter().map(|expr| expr.column_refs()) {
for col in columns.into_iter() {
*column_referral_map.entry(col.clone()).or_default() += 1;
*column_referral_map.entry(col).or_default() += 1;
}
}

Expand All @@ -493,7 +493,7 @@ fn merge_consecutive_projections(proj: Projection) -> Result<Transformed<Project
usage > 1
&& !is_expr_trivial(
&prev_projection.expr
[prev_projection.schema.index_of_column(&col).unwrap()],
[prev_projection.schema.index_of_column(col).unwrap()],
)
}) {
// no change
Expand Down Expand Up @@ -625,12 +625,12 @@ fn rewrite_expr(expr: Expr, input: &Projection) -> Result<Transformed<Expr>> {
/// * `expr` - The expression to analyze for outer-referenced columns.
/// * `columns` - A mutable reference to a `HashSet<&Column>` where references
/// to the detected columns are collected.
fn outer_columns(expr: &Expr, columns: &mut HashSet<Column>) {
fn outer_columns<'a>(expr: &'a Expr, columns: &mut HashSet<&'a Column>) {
// inspect_expr_pre doesn't handle subquery references, so find them explicitly
expr.apply(|expr| {
match expr {
Expr::OuterReferenceColumn(_, col) => {
columns.insert(col.clone());
columns.insert(col);
}
Expr::ScalarSubquery(subquery) => {
outer_columns_helper_multi(&subquery.outer_ref_columns, columns);
Expand Down Expand Up @@ -660,9 +660,9 @@ fn outer_columns(expr: &Expr, columns: &mut HashSet<Column>) {
/// * `exprs` - The expressions to analyze for outer-referenced columns.
/// * `columns` - A mutable reference to a `HashSet<&Column>` where references
/// to the detected columns are collected.
fn outer_columns_helper_multi<'a>(
fn outer_columns_helper_multi<'a, 'b>(
exprs: impl IntoIterator<Item = &'a Expr>,
columns: &mut HashSet<Column>,
columns: &'b mut HashSet<&'a Column>,
) {
exprs.into_iter().for_each(|e| outer_columns(e, columns));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,12 @@ impl RequiredIndicies {
/// * `expr`: An expression for which we want to find necessary field indices.
fn add_expr(&mut self, input_schema: &DFSchemaRef, expr: &Expr) -> Result<()> {
// TODO could remove these clones (and visit the expression directly)
let mut cols = expr.to_columns()?;
let mut cols = expr.column_refs();
// Get outer-referenced (subquery) columns:
outer_columns(expr, &mut cols);
self.indices.reserve(cols.len());
for col in cols {
if let Some(idx) = input_schema.maybe_index_of_column(&col) {
if let Some(idx) = input_schema.maybe_index_of_column(col) {
self.indices.push(idx);
}
}
Expand Down
11 changes: 4 additions & 7 deletions datafusion/optimizer/src/push_down_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -561,12 +561,9 @@ fn infer_join_predicates(
.filter_map(|predicate| {
let mut join_cols_to_replace = HashMap::new();

let columns = match predicate.to_columns() {
Ok(columns) => columns,
Err(e) => return Some(Err(e)),
};
let columns = predicate.column_refs();

for col in columns.iter() {
for &col in columns.iter() {
for (l, r) in join_col_keys.iter() {
if col == *l {
join_cols_to_replace.insert(col, *r);
Expand Down Expand Up @@ -798,7 +795,7 @@ impl OptimizerRule for PushDownFilter {
let mut keep_predicates = vec![];
let mut push_predicates = vec![];
for expr in predicates {
let cols = expr.to_columns()?;
let cols = expr.column_refs();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a pretty good example where there is no need to copy Columns simply to check if they are referenced.

if cols.iter().all(|c| group_expr_columns.contains(c)) {
push_predicates.push(expr);
} else {
Expand Down Expand Up @@ -899,7 +896,7 @@ impl OptimizerRule for PushDownFilter {
let predicate_push_or_keep = split_conjunction(&filter.predicate)
.iter()
.map(|expr| {
let cols = expr.to_columns()?;
let cols = expr.column_refs();
if cols.iter().any(|c| prevent_cols.contains(&c.name)) {
Ok(false) // No push (keep)
} else {
Expand Down
6 changes: 3 additions & 3 deletions datafusion/optimizer/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ pub(crate) fn collect_subquery_cols(
) -> Result<BTreeSet<Column>> {
exprs.iter().try_fold(BTreeSet::new(), |mut cols, expr| {
let mut using_cols: Vec<Column> = vec![];
for col in expr.to_columns()?.into_iter() {
if subquery_schema.has_column(&col) {
using_cols.push(col);
for col in expr.column_refs().into_iter() {
if subquery_schema.has_column(col) {
using_cols.push(col.clone());
}
}

Expand Down
2 changes: 1 addition & 1 deletion datafusion/sql/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -964,7 +964,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
self.order_by_to_sort_expr(&expr, schema, planner_context, true, None)?;
// Verify that columns of all SortExprs exist in the schema:
for expr in expr_vec.iter() {
for column in expr.to_columns()?.iter() {
for column in expr.column_refs().iter() {
if !schema.has_column(column) {
// Return an error if any column is not in the schema:
return plan_err!("Column {column} is not in schema");
Expand Down