Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions datafusion/core/tests/sql/group_by.rs
Original file line number Diff line number Diff line change
Expand Up @@ -681,3 +681,75 @@ async fn group_by_dictionary() {
run_test_case::<UInt32Type>().await;
run_test_case::<UInt64Type>().await;
}

#[tokio::test]
async fn csv_query_group_by_order_by_substr() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
let sql = "SELECT substr(c1, 1, 1), avg(c12) \
FROM aggregate_test_100 \
GROUP BY substr(c1, 1, 1) \
ORDER BY substr(c1, 1, 1)";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+-------------------------------------------------+-----------------------------+",
"| substr(aggregate_test_100.c1,Int64(1),Int64(1)) | AVG(aggregate_test_100.c12) |",
"+-------------------------------------------------+-----------------------------+",
"| a | 0.48754517466109415 |",
"| b | 0.41040709263815384 |",
"| c | 0.6600456536439784 |",
"| d | 0.48855379387549824 |",
"| e | 0.48600669271341534 |",
"+-------------------------------------------------+-----------------------------+",
];
assert_batches_sorted_eq!(expected, &actual);
Ok(())
}

#[tokio::test]
async fn csv_query_group_by_order_by_substr_aliased_projection() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
let sql = "SELECT substr(c1, 1, 1) as name, avg(c12) as average \
FROM aggregate_test_100 \
GROUP BY substr(c1, 1, 1) \
ORDER BY substr(c1, 1, 1)";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+------+---------------------+",
"| name | average |",
"+------+---------------------+",
"| a | 0.48754517466109415 |",
"| b | 0.41040709263815384 |",
"| c | 0.6600456536439784 |",
"| d | 0.48855379387549824 |",
"| e | 0.48600669271341534 |",
"+------+---------------------+",
];
assert_batches_sorted_eq!(expected, &actual);
Ok(())
}

#[tokio::test]
async fn csv_query_group_by_order_by_avg_group_by_substr() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
let sql = "SELECT substr(c1, 1, 1) as name, avg(c12) as average \
FROM aggregate_test_100 \
GROUP BY substr(c1, 1, 1) \
ORDER BY avg(c12)";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+------+---------------------+",
"| name | average |",
"+------+---------------------+",
"| b | 0.41040709263815384 |",
"| e | 0.48600669271341534 |",
"| a | 0.48754517466109415 |",
"| d | 0.48855379387549824 |",
"| c | 0.6600456536439784 |",
"+------+---------------------+",
];
assert_batches_sorted_eq!(expected, &actual);
Ok(())
}
16 changes: 13 additions & 3 deletions datafusion/expr/src/expr_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

use crate::expr::GroupingSet;
use crate::logical_plan::Aggregate;
use crate::utils::grouping_set_to_exprlist;
use crate::{Expr, ExprSchemable, LogicalPlan};
use datafusion_common::Result;
use datafusion_common::{Column, DFSchema};
Expand Down Expand Up @@ -325,12 +326,16 @@ pub fn rewrite_sort_cols_by_aggs(
fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
match plan {
LogicalPlan::Aggregate(Aggregate {
input, aggr_expr, ..
input,
aggr_expr,
group_expr,
..
}) => {
struct Rewriter<'a> {
plan: &'a LogicalPlan,
input: &'a LogicalPlan,
aggr_expr: &'a Vec<Expr>,
distinct_group_exprs: &'a Vec<Expr>,
}

impl<'a> ExprRewriter for Rewriter<'a> {
Expand All @@ -341,8 +346,11 @@ fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
return Ok(expr);
}
let normalized_expr = normalized_expr.unwrap();
if let Some(found_agg) =
self.aggr_expr.iter().find(|a| (**a) == normalized_expr)
if let Some(found_agg) = self
.aggr_expr
.iter()
.chain(self.distinct_group_exprs)
.find(|a| (**a) == normalized_expr)
{
let agg = normalize_col(found_agg.clone(), self.plan)?;
let col = Expr::Column(
Expand All @@ -356,10 +364,12 @@ fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
}
}

let distinct_group_exprs = grouping_set_to_exprlist(group_expr.as_slice())?;
expr.rewrite(&mut Rewriter {
plan,
input,
aggr_expr,
distinct_group_exprs: &distinct_group_exprs,
})
}
LogicalPlan::Projection(_) => rewrite_sort_col_by_aggs(expr, plan.inputs()[0]),
Expand Down
11 changes: 6 additions & 5 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,16 +335,17 @@ impl LogicalPlanBuilder {
.iter()
.all(|c| input.schema().field_from_column(c).is_ok()) =>
{
let missing_exprs = missing_cols
let mut missing_exprs = missing_cols
.iter()
.map(|c| normalize_col(Expr::Column(c.clone()), &input))
.collect::<Result<Vec<_>>>()?;

// Do not let duplicate columns to be added, some of the
// missing_cols may be already present but without the new
// projected alias.
missing_exprs.retain(|e| !expr.contains(e));
expr.extend(missing_exprs);

Ok(LogicalPlan::Projection(Projection::try_new(
expr, input, alias,
)?))
Ok(project_with_alias((*input).clone(), expr, alias)?)
}
_ => {
let new_inputs = curr_plan
Expand Down