From 630fd72ecb0f6ea14aa06090373b4e02e779b479 Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Wed, 24 Aug 2022 17:50:58 +0300 Subject: [PATCH 1/3] Add the test for mix of order by/group by on a complex expr --- datafusion/core/tests/sql/group_by.rs | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/datafusion/core/tests/sql/group_by.rs b/datafusion/core/tests/sql/group_by.rs index e3da1b02195a..91732e8da83c 100644 --- a/datafusion/core/tests/sql/group_by.rs +++ b/datafusion/core/tests/sql/group_by.rs @@ -681,3 +681,27 @@ async fn group_by_dictionary() { run_test_case::().await; run_test_case::().await; } + +#[tokio::test] +async fn csv_query_group_by_order_by_substr() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT substr(c1, 0, 1), avg(c12) \ + FROM aggregate_test_100 \ + GROUP BY substr(c1, 0, 1) \ + ORDER BY substr(c1, 0, 1)"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----+-----------------------------+", + "| c1 | AVG(aggregate_test_100.c12) |", + "+----+-----------------------------+", + "| a | 0.48754517466109415 |", + "| b | 0.41040709263815384 |", + "| c | 0.6600456536439784 |", + "| d | 0.48855379387549824 |", + "| e | 0.48600669271341534 |", + "+----+-----------------------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} From a7975e4a0e1e2280e3f8e92b80fa83feb08f2e18 Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Sat, 27 Aug 2022 10:32:48 +0300 Subject: [PATCH 2/3] Allow sorting by aggregated groups --- datafusion/core/tests/sql/group_by.rs | 72 ++++++++++++++++++++++----- datafusion/expr/src/expr_rewriter.rs | 16 ++++-- 2 files changed, 73 insertions(+), 15 deletions(-) diff --git a/datafusion/core/tests/sql/group_by.rs b/datafusion/core/tests/sql/group_by.rs index 91732e8da83c..2e1007be81c9 100644 --- a/datafusion/core/tests/sql/group_by.rs +++ b/datafusion/core/tests/sql/group_by.rs @@ -686,21 +686,69 @@ async fn group_by_dictionary() { async fn csv_query_group_by_order_by_substr() -> Result<()> { let ctx = SessionContext::new(); register_aggregate_csv(&ctx).await?; - let sql = "SELECT substr(c1, 0, 1), avg(c12) \ + let sql = "SELECT substr(c1, 1, 1), avg(c12) \ FROM aggregate_test_100 \ - GROUP BY substr(c1, 0, 1) \ - ORDER BY substr(c1, 0, 1)"; + GROUP BY substr(c1, 1, 1) \ + ORDER BY substr(c1, 1, 1)"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+----+-----------------------------+", - "| c1 | AVG(aggregate_test_100.c12) |", - "+----+-----------------------------+", - "| a | 0.48754517466109415 |", - "| b | 0.41040709263815384 |", - "| c | 0.6600456536439784 |", - "| d | 0.48855379387549824 |", - "| e | 0.48600669271341534 |", - "+----+-----------------------------+", + "+-------------------------------------------------+-----------------------------+", + "| substr(aggregate_test_100.c1,Int64(1),Int64(1)) | AVG(aggregate_test_100.c12) |", + "+-------------------------------------------------+-----------------------------+", + "| a | 0.48754517466109415 |", + "| b | 0.41040709263815384 |", + "| c | 0.6600456536439784 |", + "| d | 0.48855379387549824 |", + "| e | 0.48600669271341534 |", + "+-------------------------------------------------+-----------------------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_group_by_order_by_substr_aliased_projection() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT substr(c1, 1, 1) as name, avg(c12) as average \ + FROM aggregate_test_100 \ + GROUP BY substr(c1, 1, 1) \ + ORDER BY substr(c1, 1, 1)"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+------+---------------------+", + "| name | average |", + "+------+---------------------+", + "| a | 0.48754517466109415 |", + "| b | 0.41040709263815384 |", + "| c | 0.6600456536439784 |", + "| d | 0.48855379387549824 |", + "| e | 0.48600669271341534 |", + "+------+---------------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_group_by_order_by_avg_group_by_substr() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT substr(c1, 1, 1) as name, avg(c12) as average \ + FROM aggregate_test_100 \ + GROUP BY substr(c1, 1, 1) \ + ORDER BY avg(c12)"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+------+---------------------+", + "| name | average |", + "+------+---------------------+", + "| b | 0.41040709263815384 |", + "| e | 0.48600669271341534 |", + "| a | 0.48754517466109415 |", + "| d | 0.48855379387549824 |", + "| c | 0.6600456536439784 |", + "+------+---------------------+", ]; assert_batches_sorted_eq!(expected, &actual); Ok(()) diff --git a/datafusion/expr/src/expr_rewriter.rs b/datafusion/expr/src/expr_rewriter.rs index e8cf049dde6f..9e8fa8a7ec73 100644 --- a/datafusion/expr/src/expr_rewriter.rs +++ b/datafusion/expr/src/expr_rewriter.rs @@ -19,6 +19,7 @@ use crate::expr::GroupingSet; use crate::logical_plan::Aggregate; +use crate::utils::grouping_set_to_exprlist; use crate::{Expr, ExprSchemable, LogicalPlan}; use datafusion_common::Result; use datafusion_common::{Column, DFSchema}; @@ -325,12 +326,16 @@ pub fn rewrite_sort_cols_by_aggs( fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result { match plan { LogicalPlan::Aggregate(Aggregate { - input, aggr_expr, .. + input, + aggr_expr, + group_expr, + .. }) => { struct Rewriter<'a> { plan: &'a LogicalPlan, input: &'a LogicalPlan, aggr_expr: &'a Vec, + distinct_group_exprs: &'a Vec, } impl<'a> ExprRewriter for Rewriter<'a> { @@ -341,8 +346,11 @@ fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result { return Ok(expr); } let normalized_expr = normalized_expr.unwrap(); - if let Some(found_agg) = - self.aggr_expr.iter().find(|a| (**a) == normalized_expr) + if let Some(found_agg) = self + .aggr_expr + .iter() + .chain(self.distinct_group_exprs) + .find(|a| (**a) == normalized_expr) { let agg = normalize_col(found_agg.clone(), self.plan)?; let col = Expr::Column( @@ -356,10 +364,12 @@ fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result { } } + let distinct_group_exprs = grouping_set_to_exprlist(group_expr.as_slice())?; expr.rewrite(&mut Rewriter { plan, input, aggr_expr, + distinct_group_exprs: &distinct_group_exprs, }) } LogicalPlan::Projection(_) => rewrite_sort_col_by_aggs(expr, plan.inputs()[0]), From 6e6b6745563ef4b1a8a994f3af7521557cac0ced Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Sat, 27 Aug 2022 14:15:01 +0300 Subject: [PATCH 3/3] Prevent duplicate sort expressions with mismatched alias to be included --- datafusion/expr/src/logical_plan/builder.rs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 9eb379142ea6..2946a74afd70 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -335,16 +335,17 @@ impl LogicalPlanBuilder { .iter() .all(|c| input.schema().field_from_column(c).is_ok()) => { - let missing_exprs = missing_cols + let mut missing_exprs = missing_cols .iter() .map(|c| normalize_col(Expr::Column(c.clone()), &input)) .collect::>>()?; + // Do not let duplicate columns to be added, some of the + // missing_cols may be already present but without the new + // projected alias. + missing_exprs.retain(|e| !expr.contains(e)); expr.extend(missing_exprs); - - Ok(LogicalPlan::Projection(Projection::try_new( - expr, input, alias, - )?)) + Ok(project_with_alias((*input).clone(), expr, alias)?) } _ => { let new_inputs = curr_plan