Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions datafusion/optimizer/src/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,16 +178,15 @@ impl Optimizer {
F: FnMut(&LogicalPlan, &dyn OptimizerRule),
{
let mut new_plan = plan.clone();
debug!("Input logical plan:\n{}\n", plan.display_indent());
trace!("Full input logical plan:\n{:?}", plan);
log_plan("Optimizer input", plan);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a driveby cleanup to improve logging (specifically, also add trace! to log schema)


for rule in &self.rules {
let result = rule.optimize(&new_plan, optimizer_config);
match result {
Ok(plan) => {
new_plan = plan;
observer(&new_plan, rule.as_ref());
debug!("After apply {} rule:\n", rule.name());
debug!("Optimized logical plan:\n{}\n", new_plan.display_indent());
log_plan(rule.name(), &new_plan);
}
Err(ref e) => {
if optimizer_config.skip_failing_rules {
Expand All @@ -209,12 +208,17 @@ impl Optimizer {
}
}
}
debug!("Optimized logical plan:\n{}\n", new_plan.display_indent());
trace!("Full Optimized logical plan:\n {:?}", new_plan);
log_plan("Optimized plan", &new_plan);
Ok(new_plan)
}
}

/// Log the plan in debug/tracing mode after some part of the optimizer runs
fn log_plan(description: &str, plan: &LogicalPlan) {
debug!("{description}:\n{}\n", plan.display_indent());
trace!("{description}::\n{}\n", plan.display_indent_schema());
}

#[cfg(test)]
mod tests {
use crate::optimizer::Optimizer;
Expand Down
37 changes: 36 additions & 1 deletion datafusion/optimizer/src/unwrap_cast_in_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,47 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
let new_exprs = plan
.expressions()
.into_iter()
.map(|expr| expr.rewrite(&mut expr_rewriter))
.map(|expr| {
let original_name = name_for_alias(&expr)?;
let expr = expr.rewrite(&mut expr_rewriter)?;
add_alias_if_changed(&original_name, expr)
})
.collect::<Result<Vec<_>>>()?;

from_plan(plan, new_exprs.as_slice(), new_inputs.as_slice())
}

fn name_for_alias(expr: &Expr) -> Result<String> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I plan to make this easier on the eyes as a follow on PR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

follow on #3727

match expr {
Expr::Sort { expr, .. } => name_for_alias(expr),
expr => expr.name(),
}
}

fn add_alias_if_changed(original_name: &str, expr: Expr) -> Result<Expr> {
let new_name = name_for_alias(&expr)?;

if new_name == original_name {
return Ok(expr);
}

Ok(match expr {
Expr::Sort {
expr,
asc,
nulls_first,
} => {
let expr = add_alias_if_changed(original_name, *expr)?;
Expr::Sort {
expr: Box::new(expr),
asc,
nulls_first,
}
}
expr => expr.alias(original_name),
})
}

struct UnwrapCastExprRewriter {
schema: DFSchemaRef,
}
Expand Down
21 changes: 19 additions & 2 deletions datafusion/optimizer/tests/integration-test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,19 @@ use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;

#[cfg(test)]
#[ctor::ctor]
fn init() {
let _ = env_logger::try_init();
}

#[test]
fn case_when() -> Result<()> {
let sql = "SELECT CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END FROM test";
let plan = test_sql(sql)?;
let expected =
"Projection: CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END\
\n TableScan: test projection=[col_int32]";
"Projection: CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END AS CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END\
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI @andygrove the alias was added to this as well

\n TableScan: test projection=[col_int32]";
assert_eq!(expected, format!("{:?}", plan));

let sql = "SELECT CASE WHEN col_uint32 > 0 THEN 1 ELSE 0 END FROM test";
Expand All @@ -46,6 +52,17 @@ fn case_when() -> Result<()> {
Ok(())
}

#[test]
fn case_when_aggregate() -> Result<()> {
let sql = "SELECT col_utf8, SUM(CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END) AS n FROM test GROUP BY col_utf8";
let plan = test_sql(sql)?;
let expected = "Projection: test.col_utf8, SUM(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END) AS n\
\n Aggregate: groupBy=[[test.col_utf8]], aggr=[[SUM(CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END) AS SUM(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END)]]\
\n TableScan: test projection=[col_int32, col_utf8]";
assert_eq!(expected, format!("{:?}", plan));
Ok(())
}

#[test]
fn unsigned_target_type() -> Result<()> {
let sql = "SELECT * FROM test WHERE col_uint32 > 0";
Expand Down