Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 53 additions & 24 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::DataType;
use datafusion_common::{DFField, DFSchema, Result};
use datafusion_common::{DFField, DFSchema, DataFusionError, Result};
use datafusion_expr::{
col,
expr::GroupingSet,
Expand Down Expand Up @@ -107,7 +107,7 @@ fn optimize(
)?;

Ok(LogicalPlan::Projection(Projection::try_new_with_schema(
new_expr.pop().unwrap(),
pop_expr(&mut new_expr)?,
Arc::new(new_input),
schema.clone(),
alias.clone(),
Expand Down Expand Up @@ -139,10 +139,16 @@ fn optimize(
optimizer_config,
)?;

Ok(LogicalPlan::Filter(Filter {
predicate: new_expr.pop().unwrap().pop().unwrap(),
input: Arc::new(new_input),
}))
if let Some(predicate) = pop_expr(&mut new_expr)?.pop() {
Ok(LogicalPlan::Filter(Filter {
predicate,
input: Arc::new(new_input),
}))
} else {
Err(DataFusionError::Internal(
"Failed to pop predicate expr".to_string(),
))
}
}
LogicalPlan::Window(Window {
input,
Expand All @@ -161,7 +167,7 @@ fn optimize(

Ok(LogicalPlan::Window(Window {
input: Arc::new(new_input),
window_expr: new_expr.pop().unwrap(),
window_expr: pop_expr(&mut new_expr)?,
schema: schema.clone(),
}))
}
Expand All @@ -182,8 +188,8 @@ fn optimize(
optimizer_config,
)?;
// note the reversed pop order.
let new_aggr_expr = new_expr.pop().unwrap();
let new_group_expr = new_expr.pop().unwrap();
let new_aggr_expr = pop_expr(&mut new_expr)?;
let new_group_expr = pop_expr(&mut new_expr)?;

Ok(LogicalPlan::Aggregate(Aggregate::try_new(
Arc::new(new_input),
Expand All @@ -204,7 +210,7 @@ fn optimize(
)?;

Ok(LogicalPlan::Sort(Sort {
expr: new_expr.pop().unwrap(),
expr: pop_expr(&mut new_expr)?,
input: Arc::new(new_input),
}))
}
Expand Down Expand Up @@ -241,6 +247,12 @@ fn optimize(
}
}

fn pop_expr(new_expr: &mut Vec<Vec<Expr>>) -> Result<Vec<Expr>> {
new_expr
.pop()
.ok_or_else(|| DataFusionError::Internal("Failed to pop expression".to_string()))
}

fn to_arrays(
expr: &[Expr],
input: &LogicalPlan,
Expand Down Expand Up @@ -268,12 +280,20 @@ fn build_project_plan(
let mut fields_set = HashSet::new();

for id in affected_id {
let (expr, _, data_type) = expr_set.get(&id).unwrap();
// todo: check `nullable`
let field = DFField::new(None, &id, data_type.clone(), true);
fields_set.insert(field.name().to_owned());
fields.push(field);
project_exprs.push(expr.clone().alias(&id));
match expr_set.get(&id) {
Some((expr, _, data_type)) => {
// todo: check `nullable`
let field = DFField::new(None, &id, data_type.clone(), true);
fields_set.insert(field.name().to_owned());
fields.push(field);
project_exprs.push(expr.clone().alias(&id));
}
_ => {
return Err(DataFusionError::Internal(
"expr_set invalid state".to_string(),
))
}
}
}

for field in input.schema().fields() {
Expand Down Expand Up @@ -639,13 +659,19 @@ impl ExprRewriter for CommonSubexprRewriter<'_> {
self.curr_index += 1;
return Ok(RewriteRecursion::Skip);
}
let (_, counter, _) = self.expr_set.get(curr_id).unwrap();
if *counter > 1 {
self.affected_id.insert(curr_id.clone());
Ok(RewriteRecursion::Mutate)
} else {
self.curr_index += 1;
Ok(RewriteRecursion::Skip)
match self.expr_set.get(curr_id) {
Some((_, counter, _)) => {
if *counter > 1 {
self.affected_id.insert(curr_id.clone());
Ok(RewriteRecursion::Mutate)
} else {
self.curr_index += 1;
Ok(RewriteRecursion::Skip)
}
}
_ => Err(DataFusionError::Internal(
"expr_set invalid state".to_string(),
)),
}
}

Expand All @@ -658,9 +684,12 @@ impl ExprRewriter for CommonSubexprRewriter<'_> {
let (series_number, id) = &self.id_array[self.curr_index];
self.curr_index += 1;
// Skip sub-node of a replaced tree, or without identifier, or is not repeated expr.
let expr_set_item = self.expr_set.get(id).ok_or_else(|| {
DataFusionError::Internal("expr_set invalid state".to_string())
})?;
if *series_number < self.max_series_number
|| id.is_empty()
|| self.expr_set.get(id).unwrap().1 <= 1
|| expr_set_item.1 <= 1
{
return Ok(expr);
}
Expand Down