Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions datafusion/expr/src/expr_rewriter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,21 +318,20 @@ impl NamePreserver {
Self { use_alias: true }
}

pub fn save(&self, expr: &Expr) -> Result<SavedName> {
let original_name = if self.use_alias {
pub fn save(&self, expr: &Expr) -> SavedName {
if self.use_alias {
let (relation, name) = expr.qualified_name();
SavedName::Saved { relation, name }
} else {
SavedName::None
};
Ok(original_name)
}
}
}

impl SavedName {
/// Ensures the qualified name of the rewritten expression is preserved
pub fn restore(self, expr: Expr) -> Result<Expr> {
let expr = match self {
pub fn restore(self, expr: Expr) -> Expr {
match self {
SavedName::Saved { relation, name } => {
let (new_relation, new_name) = expr.qualified_name();
if new_relation != relation || new_name != name {
Expand All @@ -342,8 +341,7 @@ impl SavedName {
}
}
SavedName::None => expr,
};
Ok(expr)
}
}
}

Expand Down Expand Up @@ -543,9 +541,9 @@ mod test {
let mut rewriter = TestRewriter {
rewrite_to: rewrite_to.clone(),
};
let saved_name = NamePreserver { use_alias: true }.save(&expr_from).unwrap();
let saved_name = NamePreserver { use_alias: true }.save(&expr_from);
let new_expr = expr_from.clone().rewrite(&mut rewriter).unwrap().data;
let new_expr = saved_name.restore(new_expr).unwrap();
let new_expr = saved_name.restore(new_expr);

let original_name = expr_from.qualified_name();
let new_name = new_expr.qualified_name();
Expand Down
4 changes: 2 additions & 2 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1441,7 +1441,7 @@ impl LogicalPlan {
let schema = Arc::clone(plan.schema());
let name_preserver = NamePreserver::new(&plan);
plan.map_expressions(|e| {
let original_name = name_preserver.save(&e)?;
let original_name = name_preserver.save(&e);
let transformed_expr =
e.infer_placeholder_types(&schema)?.transform_up(|e| {
if let Expr::Placeholder(Placeholder { id, .. }) = e {
Expand All @@ -1452,7 +1452,7 @@ impl LogicalPlan {
}
})?;
// Preserve name to avoid breaking column references to this expression
transformed_expr.map_data(|expr| original_name.restore(expr))
Ok(transformed_expr.update_data(|expr| original_name.restore(expr)))
})
})
.map(|res| res.data)
Expand Down
4 changes: 2 additions & 2 deletions datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool {
fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
let name_preserver = NamePreserver::new(&plan);
plan.map_expressions(|expr| {
let original_name = name_preserver.save(&expr)?;
let original_name = name_preserver.save(&expr);
let transformed_expr = expr.transform_up(|expr| match expr {
Expr::WindowFunction(mut window_function)
if is_count_star_window_aggregate(&window_function) =>
Expand All @@ -94,7 +94,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
}
_ => Ok(Transformed::no(expr)),
})?;
transformed_expr.map_data(|data| original_name.restore(data))
Ok(transformed_expr.update_data(|data| original_name.restore(data)))
})
}

Expand Down
4 changes: 2 additions & 2 deletions datafusion/optimizer/src/analyzer/function_rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ impl ApplyFunctionRewrites {
let name_preserver = NamePreserver::new(&plan);

plan.map_expressions(|expr| {
let original_name = name_preserver.save(&expr)?;
let original_name = name_preserver.save(&expr);

// recursively transform the expression, applying the rewrites at each step
let transformed_expr = expr.transform_up(|expr| {
Expand All @@ -74,7 +74,7 @@ impl ApplyFunctionRewrites {
Ok(result)
})?;

transformed_expr.map_data(|expr| original_name.restore(expr))
Ok(transformed_expr.update_data(|expr| original_name.restore(expr)))
})
}
}
Expand Down
6 changes: 3 additions & 3 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,9 @@ fn analyze_internal(
let name_preserver = NamePreserver::new(&plan);
// apply coercion rewrite all expressions in the plan individually
plan.map_expressions(|expr| {
let original_name = name_preserver.save(&expr)?;
expr.rewrite(&mut expr_rewrite)?
.map_data(|expr| original_name.restore(expr))
let original_name = name_preserver.save(&expr);
expr.rewrite(&mut expr_rewrite)
.map(|transformed| transformed.update_data(|e| original_name.restore(e)))
})?
// some plans need extra coercion after their expressions are coerced
.map_data(|plan| expr_rewrite.coerce_plan(plan))?
Expand Down
12 changes: 6 additions & 6 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -414,9 +414,9 @@ impl CommonSubexprEliminate {
exprs
.iter()
.map(|expr| name_preserver.save(expr))
.collect::<Result<Vec<_>>>()
.collect::<Vec<_>>()
})
.collect::<Result<Vec<_>>>()?;
.collect::<Vec<_>>();
new_window_expr_list.into_iter().zip(saved_names).try_rfold(
new_input,
|plan, (new_window_expr, saved_names)| {
Expand All @@ -426,7 +426,7 @@ impl CommonSubexprEliminate {
.map(|(new_window_expr, saved_name)| {
saved_name.restore(new_window_expr)
})
.collect::<Result<Vec<_>>>()?;
.collect::<Vec<_>>();
Window::try_new(new_window_expr, Arc::new(plan))
.map(LogicalPlan::Window)
},
Expand Down Expand Up @@ -604,14 +604,14 @@ impl CommonSubexprEliminate {
let saved_names = aggr_expr
.iter()
.map(|expr| name_perserver.save(expr))
.collect::<Result<Vec<_>>>()?;
.collect::<Vec<_>>();
let new_aggr_expr = rewritten_aggr_expr
.into_iter()
.zip(saved_names.into_iter())
.zip(saved_names)
.map(|(new_expr, saved_name)| {
saved_name.restore(new_expr)
})
.collect::<Result<Vec<Expr>>>()?;
.collect::<Vec<Expr>>();

// Since `group_expr` may have changed, schema may also.
// Use `try_new()` method.
Expand Down
6 changes: 3 additions & 3 deletions datafusion/optimizer/src/optimize_projections/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ fn merge_consecutive_projections(proj: Projection) -> Result<Transformed<Project
let name_preserver = NamePreserver::new_for_projection();
let mut original_names = vec![];
let new_exprs = expr.into_iter().map_until_stop_and_collect(|expr| {
original_names.push(name_preserver.save(&expr)?);
original_names.push(name_preserver.save(&expr));

// do not rewrite top level Aliases (rewriter will remove all aliases within exprs)
match expr {
Expand All @@ -519,9 +519,9 @@ fn merge_consecutive_projections(proj: Projection) -> Result<Transformed<Project
let new_exprs = new_exprs
.data
.into_iter()
.zip(original_names.into_iter())
.zip(original_names)
.map(|(expr, original_name)| original_name.restore(expr))
.collect::<Result<Vec<_>>>()?;
.collect::<Vec<_>>();
Projection::try_new(new_exprs, prev_projection.input).map(Transformed::yes)
} else {
// not rewritten, so put the projection back together
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,10 @@ impl SimplifyExpressions {
// Preserve expression names to avoid changing the schema of the plan.
let name_preserver = NamePreserver::new(&plan);
plan.map_expressions(|e| {
let original_name = name_preserver.save(&e)?;
let original_name = name_preserver.save(&e);
let new_e = simplifier
.simplify(e)
.and_then(|expr| original_name.restore(expr))?;
.map(|expr| original_name.restore(expr))?;
// TODO it would be nice to have a way to know if the expression was simplified
// or not. For now conservatively return Transformed::yes
Ok(Transformed::yes(new_e))
Expand Down
6 changes: 3 additions & 3 deletions datafusion/optimizer/src/unwrap_cast_in_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ impl OptimizerRule for UnwrapCastInComparison {

let name_preserver = NamePreserver::new(&plan);
plan.map_expressions(|expr| {
let original_name = name_preserver.save(&expr)?;
expr.rewrite(&mut expr_rewriter)?
.map_data(|expr| original_name.restore(expr))
let original_name = name_preserver.save(&expr);
expr.rewrite(&mut expr_rewriter)
.map(|transformed| transformed.update_data(|e| original_name.restore(e)))
})
}
}
Expand Down