diff --git a/datafusion/core/tests/sql/group_by.rs b/datafusion/core/tests/sql/group_by.rs index e3da1b02195a..2e1007be81c9 100644 --- a/datafusion/core/tests/sql/group_by.rs +++ b/datafusion/core/tests/sql/group_by.rs @@ -681,3 +681,75 @@ async fn group_by_dictionary() { run_test_case::().await; run_test_case::().await; } + +#[tokio::test] +async fn csv_query_group_by_order_by_substr() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT substr(c1, 1, 1), avg(c12) \ + FROM aggregate_test_100 \ + GROUP BY substr(c1, 1, 1) \ + ORDER BY substr(c1, 1, 1)"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-------------------------------------------------+-----------------------------+", + "| substr(aggregate_test_100.c1,Int64(1),Int64(1)) | AVG(aggregate_test_100.c12) |", + "+-------------------------------------------------+-----------------------------+", + "| a | 0.48754517466109415 |", + "| b | 0.41040709263815384 |", + "| c | 0.6600456536439784 |", + "| d | 0.48855379387549824 |", + "| e | 0.48600669271341534 |", + "+-------------------------------------------------+-----------------------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_group_by_order_by_substr_aliased_projection() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT substr(c1, 1, 1) as name, avg(c12) as average \ + FROM aggregate_test_100 \ + GROUP BY substr(c1, 1, 1) \ + ORDER BY substr(c1, 1, 1)"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+------+---------------------+", + "| name | average |", + "+------+---------------------+", + "| a | 0.48754517466109415 |", + "| b | 0.41040709263815384 |", + "| c | 0.6600456536439784 |", + "| d | 0.48855379387549824 |", + "| e | 0.48600669271341534 |", + "+------+---------------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_group_by_order_by_avg_group_by_substr() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT substr(c1, 1, 1) as name, avg(c12) as average \ + FROM aggregate_test_100 \ + GROUP BY substr(c1, 1, 1) \ + ORDER BY avg(c12)"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+------+---------------------+", + "| name | average |", + "+------+---------------------+", + "| b | 0.41040709263815384 |", + "| e | 0.48600669271341534 |", + "| a | 0.48754517466109415 |", + "| d | 0.48855379387549824 |", + "| c | 0.6600456536439784 |", + "+------+---------------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} diff --git a/datafusion/expr/src/expr_rewriter.rs b/datafusion/expr/src/expr_rewriter.rs index e8cf049dde6f..9e8fa8a7ec73 100644 --- a/datafusion/expr/src/expr_rewriter.rs +++ b/datafusion/expr/src/expr_rewriter.rs @@ -19,6 +19,7 @@ use crate::expr::GroupingSet; use crate::logical_plan::Aggregate; +use crate::utils::grouping_set_to_exprlist; use crate::{Expr, ExprSchemable, LogicalPlan}; use datafusion_common::Result; use datafusion_common::{Column, DFSchema}; @@ -325,12 +326,16 @@ pub fn rewrite_sort_cols_by_aggs( fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result { match plan { LogicalPlan::Aggregate(Aggregate { - input, aggr_expr, .. + input, + aggr_expr, + group_expr, + .. }) => { struct Rewriter<'a> { plan: &'a LogicalPlan, input: &'a LogicalPlan, aggr_expr: &'a Vec, + distinct_group_exprs: &'a Vec, } impl<'a> ExprRewriter for Rewriter<'a> { @@ -341,8 +346,11 @@ fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result { return Ok(expr); } let normalized_expr = normalized_expr.unwrap(); - if let Some(found_agg) = - self.aggr_expr.iter().find(|a| (**a) == normalized_expr) + if let Some(found_agg) = self + .aggr_expr + .iter() + .chain(self.distinct_group_exprs) + .find(|a| (**a) == normalized_expr) { let agg = normalize_col(found_agg.clone(), self.plan)?; let col = Expr::Column( @@ -356,10 +364,12 @@ fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result { } } + let distinct_group_exprs = grouping_set_to_exprlist(group_expr.as_slice())?; expr.rewrite(&mut Rewriter { plan, input, aggr_expr, + distinct_group_exprs: &distinct_group_exprs, }) } LogicalPlan::Projection(_) => rewrite_sort_col_by_aggs(expr, plan.inputs()[0]), diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 9eb379142ea6..2946a74afd70 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -335,16 +335,17 @@ impl LogicalPlanBuilder { .iter() .all(|c| input.schema().field_from_column(c).is_ok()) => { - let missing_exprs = missing_cols + let mut missing_exprs = missing_cols .iter() .map(|c| normalize_col(Expr::Column(c.clone()), &input)) .collect::>>()?; + // Do not let duplicate columns to be added, some of the + // missing_cols may be already present but without the new + // projected alias. + missing_exprs.retain(|e| !expr.contains(e)); expr.extend(missing_exprs); - - Ok(LogicalPlan::Projection(Projection::try_new( - expr, input, alias, - )?)) + Ok(project_with_alias((*input).clone(), expr, alias)?) } _ => { let new_inputs = curr_plan