Skip to content

Commit

Permalink
Fix aggregate type coercion bug (#3710)
Browse files Browse the repository at this point in the history
* Do not change output expr name in `UnwrapCastInComparison`

* Update

* Update test

* Fix regression

* Update tests

* clippy
  • Loading branch information
alamb committed Oct 5, 2022
1 parent 965133c commit 64669e9
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 9 deletions.
16 changes: 10 additions & 6 deletions datafusion/optimizer/src/optimizer.rs
Expand Up @@ -178,16 +178,15 @@ impl Optimizer {
F: FnMut(&LogicalPlan, &dyn OptimizerRule),
{
let mut new_plan = plan.clone();
debug!("Input logical plan:\n{}\n", plan.display_indent());
trace!("Full input logical plan:\n{:?}", plan);
log_plan("Optimizer input", plan);

for rule in &self.rules {
let result = rule.optimize(&new_plan, optimizer_config);
match result {
Ok(plan) => {
new_plan = plan;
observer(&new_plan, rule.as_ref());
debug!("After apply {} rule:\n", rule.name());
debug!("Optimized logical plan:\n{}\n", new_plan.display_indent());
log_plan(rule.name(), &new_plan);
}
Err(ref e) => {
if optimizer_config.skip_failing_rules {
Expand All @@ -209,12 +208,17 @@ impl Optimizer {
}
}
}
debug!("Optimized logical plan:\n{}\n", new_plan.display_indent());
trace!("Full Optimized logical plan:\n {:?}", new_plan);
log_plan("Optimized plan", &new_plan);
Ok(new_plan)
}
}

/// Emit `plan` to the log, labelled with `description`.
///
/// At debug level the plan is rendered with `display_indent()`; at trace
/// level it is rendered with `display_indent_schema()`.
fn log_plan(description: &str, plan: &LogicalPlan) {
    debug!("{}:\n{}\n", description, plan.display_indent());
    trace!("{}::\n{}\n", description, plan.display_indent_schema());
}

#[cfg(test)]
mod tests {
use crate::optimizer::Optimizer;
Expand Down
37 changes: 36 additions & 1 deletion datafusion/optimizer/src/unwrap_cast_in_comparison.rs
Expand Up @@ -97,12 +97,47 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
let new_exprs = plan
.expressions()
.into_iter()
.map(|expr| expr.rewrite(&mut expr_rewriter))
.map(|expr| {
let original_name = name_for_alias(&expr)?;
let expr = expr.rewrite(&mut expr_rewriter)?;
add_alias_if_changed(&original_name, expr)
})
.collect::<Result<Vec<_>>>()?;

from_plan(plan, new_exprs.as_slice(), new_inputs.as_slice())
}

/// Return the name used to decide whether a rewrite changed an expression.
///
/// `Expr::Sort` wrappers are looked through (however deeply nested), so the
/// name reported is that of the innermost non-sort expression.
fn name_for_alias(expr: &Expr) -> Result<String> {
    let mut current = expr;
    // Peel off every enclosing sort; only the wrapped expression is named.
    while let Expr::Sort { expr: inner, .. } = current {
        current = &**inner;
    }
    current.name()
}

/// Alias `expr` back to `original_name` when its name no longer matches.
///
/// If `name_for_alias(&expr)` already equals `original_name`, the expression
/// is returned untouched. For an `Expr::Sort`, the alias is applied to the
/// wrapped expression and the sort is rebuilt with the same `asc` and
/// `nulls_first` settings; any other expression is aliased directly.
///
/// Errors from computing the expression name are propagated.
fn add_alias_if_changed(original_name: &str, expr: Expr) -> Result<Expr> {
    // Nothing to do when the rewrite left the name intact.
    if name_for_alias(&expr)? == original_name {
        return Ok(expr);
    }

    let renamed = match expr {
        Expr::Sort {
            expr: inner,
            asc,
            nulls_first,
        } => Expr::Sort {
            // The alias belongs on the sorted expression, not on the sort itself.
            expr: Box::new(add_alias_if_changed(original_name, *inner)?),
            asc,
            nulls_first,
        },
        other => other.alias(original_name),
    };

    Ok(renamed)
}

struct UnwrapCastExprRewriter {
schema: DFSchemaRef,
}
Expand Down
21 changes: 19 additions & 2 deletions datafusion/optimizer/tests/integration-test.rs
Expand Up @@ -29,13 +29,19 @@ use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;

// Test-only logger setup so `debug!`/`trace!` output can be seen while
// running the integration tests.
// NOTE(review): presumably `#[ctor::ctor]` runs this once at binary load,
// before any test executes — confirm against the `ctor` crate docs.
#[cfg(test)]
#[ctor::ctor]
fn init() {
// The result of `try_init` is deliberately discarded, so a failed
// initialization does not abort the tests.
let _ = env_logger::try_init();
}

#[test]
fn case_when() -> Result<()> {
let sql = "SELECT CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END FROM test";
let plan = test_sql(sql)?;
let expected =
"Projection: CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END\
\n TableScan: test projection=[col_int32]";
"Projection: CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END AS CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END\
\n TableScan: test projection=[col_int32]";
assert_eq!(expected, format!("{:?}", plan));

let sql = "SELECT CASE WHEN col_uint32 > 0 THEN 1 ELSE 0 END FROM test";
Expand All @@ -46,6 +52,17 @@ fn case_when() -> Result<()> {
Ok(())
}

/// Checks the optimized plan for a `SUM` over a `CASE` expression.
///
/// The aggregate expression is expected to compare against `Int32(0)` while
/// being aliased to its original name (which mentions `Int64(0)`), so the
/// projection above it can still refer to the aggregate by that name.
#[test]
fn case_when_aggregate() -> Result<()> {
    let query = "SELECT col_utf8, SUM(CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END) AS n FROM test GROUP BY col_utf8";
    let optimized = test_sql(query)?;
    let actual = format!("{:?}", optimized);

    let expected = "Projection: test.col_utf8, SUM(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END) AS n\
        \n Aggregate: groupBy=[[test.col_utf8]], aggr=[[SUM(CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END) AS SUM(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END)]]\
        \n TableScan: test projection=[col_int32, col_utf8]";

    assert_eq!(expected, actual);
    Ok(())
}

#[test]
fn unsigned_target_type() -> Result<()> {
let sql = "SELECT * FROM test WHERE col_uint32 > 0";
Expand Down

0 comments on commit 64669e9

Please sign in to comment.