From 00f0abc94f705b192f0a310364564859b49e7d54 Mon Sep 17 00:00:00 2001 From: "Heres, Daniel" Date: Wed, 5 May 2021 12:41:36 +0200 Subject: [PATCH 1/6] Fix wrong projection 'optimization' --- datafusion/src/sql/planner.rs | 18 +----------------- datafusion/tests/sql.rs | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 48900f56aad5..99a5633451aa 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -652,26 +652,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } /// Wrap a plan in a projection - /// - /// The projection is applied only when necessary, - /// i.e., when the input fields are different than the - /// projection. Note that if the input fields are the same, but out of - /// order, the projection will be applied. fn project(&self, input: &LogicalPlan, expr: Vec<Expr>) -> Result<LogicalPlan> { self.validate_schema_satisfies_exprs(&input.schema(), &expr)?; - let plan = LogicalPlanBuilder::from(input).project(expr)?.build()?; - - let project = match input { - LogicalPlan::TableScan { .. 
} => true, - _ => plan.schema().fields() != input.schema().fields(), - }; - - if project { - Ok(plan) - } else { - Ok(input.clone()) - } + LogicalPlanBuilder::from(input).project(expr)?.build() } fn aggregate( diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index bf28525ad437..db23376c68b2 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -447,6 +447,23 @@ async fn select_distinct_simple() -> Result<()> { Ok(()) } +#[tokio::test] +async fn projection_same_fields() -> Result<()> { + let mut ctx = ExecutionContext::new(); + + let sql = "select (1+1) as a from (select 1 as a);"; + let actual = execute(&mut ctx, sql).await; + + let expected = vec![ + vec!["2"], + ]; + assert_eq!(actual, expected); + + Ok(()) +} + + + #[tokio::test] async fn csv_query_group_by_float64() -> Result<()> { let mut ctx = ExecutionContext::new(); From e47230840b43f5ccb8386c8caf2e698b5f183864 Mon Sep 17 00:00:00 2001 From: "Heres, Daniel" Date: Wed, 5 May 2021 12:45:57 +0200 Subject: [PATCH 2/6] Fmt --- datafusion/tests/sql.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index db23376c68b2..fb3f6265eefc 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -454,16 +454,12 @@ async fn projection_same_fields() -> Result<()> { let sql = "select (1+1) as a from (select 1 as a);"; let actual = execute(&mut ctx, sql).await; - let expected = vec![ - vec!["2"], - ]; + let expected = vec![vec!["2"]]; assert_eq!(actual, expected); Ok(()) } - - #[tokio::test] async fn csv_query_group_by_float64() -> Result<()> { let mut ctx = ExecutionContext::new(); From d0d547bfc97eb145d6c555c8fad6ab4c3f680999 Mon Sep 17 00:00:00 2001 From: "Heres, Daniel" Date: Wed, 5 May 2021 20:18:39 +0200 Subject: [PATCH 3/6] test --- datafusion/src/execution/dataframe_impl.rs | 36 +++++++++++++--------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git 
a/datafusion/src/execution/dataframe_impl.rs b/datafusion/src/execution/dataframe_impl.rs index 2a0c39aa48eb..fdc75f92f2e7 100644 --- a/datafusion/src/execution/dataframe_impl.rs +++ b/datafusion/src/execution/dataframe_impl.rs @@ -177,9 +177,11 @@ impl DataFrame for DataFrameImpl { #[cfg(test)] mod tests { + use std::vec; + use super::*; - use crate::execution::context::ExecutionContext; use crate::logical_plan::*; + use crate::{assert_batches_sorted_eq, execution::context::ExecutionContext}; use crate::{datasource::csv::CsvReadOptions, physical_plan::ColumnarValue}; use crate::{physical_plan::functions::ScalarFunctionImplementation, test}; use arrow::datatypes::DataType; @@ -216,8 +218,8 @@ mod tests { Ok(()) } - #[test] - fn aggregate() -> Result<()> { + #[tokio::test] + async fn aggregate() -> Result<()> { // build plan using DataFrame API let df = test_table()?; let group_expr = vec![col("c1")]; @@ -230,18 +232,22 @@ mod tests { count_distinct(col("c12")), ]; - let df = df.aggregate(group_expr, aggr_expr)?; - - let plan = df.to_logical_plan(); - - // build same plan using SQL API - let sql = "SELECT c1, MIN(c12), MAX(c12), AVG(c12), SUM(c12), COUNT(c12), COUNT(DISTINCT c12) \ - FROM aggregate_test_100 \ - GROUP BY c1"; - let sql_plan = create_plan(sql)?; - - // the two plans should be identical - assert_same_plan(&plan, &sql_plan); + let df: Vec<RecordBatch> = df.aggregate(group_expr, aggr_expr)?.collect().await?; + + assert_batches_sorted_eq!( + vec![ + "+----+----------------------+--------------------+---------------------+--------------------+------------+---------------------+", + "| c1 | MIN(c12) | MAX(c12) | AVG(c12) | SUM(c12) | COUNT(c12) | COUNT(DISTINCT c12) |", + "+----+----------------------+--------------------+---------------------+--------------------+------------+---------------------+", + "| a | 0.02182578039211991 | 0.9800193410444061 | 0.48754517466109415 | 10.238448667882977 | 21 | 21 |", + "| b | 0.04893135681998029 | 0.9185813970744787 | 
0.41040709263815384 | 7.797734760124923 | 19 | 19 |", + "| c | 0.0494924465469434 | 0.991517828651004 | 0.6600456536439784 | 13.860958726523545 | 21 | 21 |", + "| d | 0.061029375346466685 | 0.9748360509016578 | 0.48855379387549824 | 8.793968289758968 | 18 | 18 |", + "| e | 0.01479305307777301 | 0.9965400387585364 | 0.48600669271341534 | 10.206140546981722 | 21 | 21 |", + "+----+----------------------+--------------------+---------------------+--------------------+------------+---------------------+", + ], + &df + ); Ok(()) } From 4023f5a9f29c46e9740f4e14ebb31cd798bf4caa Mon Sep 17 00:00:00 2001 From: "Heres, Daniel" Date: Wed, 5 May 2021 20:30:10 +0200 Subject: [PATCH 4/6] Update some more --- datafusion/src/sql/planner.rs | 71 ++++++++++++++++++++--------------- 1 file changed, 41 insertions(+), 30 deletions(-) diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 99a5633451aa..55741ef86fd3 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -1781,9 +1781,10 @@ mod tests { let sql = "SELECT MAX(age) FROM person HAVING MAX(age) < 30"; - let expected = "Filter: #MAX(age) Lt Int64(30)\ - \n Aggregate: groupBy=[[]], aggr=[[MAX(#age)]]\ - \n TableScan: person projection=None"; + let expected = "Projection: #MAX(age)\ + \n Filter: #MAX(age) Lt Int64(30)\ + \n Aggregate: groupBy=[[]], aggr=[[MAX(#age)]]\ + \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1841,9 +1842,10 @@ mod tests { FROM person GROUP BY first_name HAVING first_name = 'M'"; - let expected = "Filter: #first_name Eq Utf8(\"M\")\ - \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\ - \n TableScan: person projection=None"; + let expected = "Projection: #first_name, #MAX(age)\ + \n Filter: #first_name Eq Utf8(\"M\")\ + \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\ + \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1854,10 +1856,11 @@ mod tests { WHERE id > 5 GROUP BY first_name HAVING 
MAX(age) < 100"; - let expected = "Filter: #MAX(age) Lt Int64(100)\ - \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\ - \n Filter: #id Gt Int64(5)\ - \n TableScan: person projection=None"; + let expected = "Projection: #first_name, #MAX(age)\ + \n Filter: #MAX(age) Lt Int64(100)\ + \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\ + \n Filter: #id Gt Int64(5)\ + \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1869,10 +1872,11 @@ mod tests { WHERE id > 5 AND age > 18 GROUP BY first_name HAVING MAX(age) < 100"; - let expected = "Filter: #MAX(age) Lt Int64(100)\ - \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\ - \n Filter: #id Gt Int64(5) And #age Gt Int64(18)\ - \n TableScan: person projection=None"; + let expected = "Projection: #first_name, #MAX(age)\ + \n Filter: #MAX(age) Lt Int64(100)\ + \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\ + \n Filter: #id Gt Int64(5) And #age Gt Int64(18)\ + \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1909,9 +1913,10 @@ mod tests { FROM person GROUP BY first_name HAVING MAX(age) > 100"; - let expected = "Filter: #MAX(age) Gt Int64(100)\ - \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\ - \n TableScan: person projection=None"; + let expected = "Projection: #first_name, #MAX(age)\ + \n Filter: #MAX(age) Gt Int64(100)\ + \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\ + \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1934,9 +1939,10 @@ mod tests { FROM person GROUP BY first_name HAVING MAX(age) > 100 AND MAX(age) < 200"; - let expected = "Filter: #MAX(age) Gt Int64(100) And #MAX(age) Lt Int64(200)\ - \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\ - \n TableScan: person projection=None"; + let expected = "Projection: #first_name, #MAX(age)\ + \n Filter: #MAX(age) Gt Int64(100) And #MAX(age) Lt Int64(200)\ + \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\ + \n TableScan: 
person projection=None"; quick_test(sql, expected); } @@ -2044,8 +2050,9 @@ mod tests { fn select_simple_aggregate() { quick_test( "SELECT MIN(age) FROM person", - "Aggregate: groupBy=[[]], aggr=[[MIN(#age)]]\ - \n TableScan: person projection=None", + "Projection: #MIN(age)\ + \n Aggregate: groupBy=[[]], aggr=[[MIN(#age)]]\ + \n TableScan: person projection=None", ); } @@ -2053,8 +2060,9 @@ mod tests { fn test_sum_aggregate() { quick_test( "SELECT SUM(age) from person", - "Aggregate: groupBy=[[]], aggr=[[SUM(#age)]]\ - \n TableScan: person projection=None", + "Projection: #SUM(age)\ + \n Aggregate: groupBy=[[]], aggr=[[SUM(#age)]]\ + \n TableScan: person projection=None", ); } @@ -2333,8 +2341,9 @@ mod tests { fn select_aggregate_with_non_column_inner_expression_with_groupby() { quick_test( "SELECT state, MIN(age + 1) FROM person GROUP BY state", - "Aggregate: groupBy=[[#state]], aggr=[[MIN(#age Plus Int64(1))]]\ - \n TableScan: person projection=None", + "Projection: #state, #MIN(age Plus Int64(1))\ + \n Aggregate: groupBy=[[#state]], aggr=[[MIN(#age Plus Int64(1))]]\ + \n TableScan: person projection=None", ); } @@ -2350,16 +2359,18 @@ mod tests { #[test] fn select_count_one() { let sql = "SELECT COUNT(1) FROM person"; - let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ - \n TableScan: person projection=None"; + let expected = "Projection: #COUNT(UInt8(1))\ + \n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ + \n TableScan: person projection=None"; quick_test(sql, expected); } #[test] fn select_count_column() { let sql = "SELECT COUNT(id) FROM person"; - let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(#id)]]\ - \n TableScan: person projection=None"; + let expected = "Projection: #COUNT(id)\ + \n Aggregate: groupBy=[[]], aggr=[[COUNT(#id)]]\ + \n TableScan: person projection=None"; quick_test(sql, expected); } From 411a7a8173c3e32734012fdd88c52033ef635994 Mon Sep 17 00:00:00 2001 From: "Heres, Daniel" Date: Wed, 5 May 2021 20:36:34 
+0200 Subject: [PATCH 5/6] Update tests --- datafusion/src/sql/planner.rs | 41 ++++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 55741ef86fd3..ed7dd377c835 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -1721,10 +1721,11 @@ mod tests { ) WHERE fn1 = 'X' AND age < 30"; - let expected = "Filter: #fn1 Eq Utf8(\"X\") And #age Lt Int64(30)\ - \n Projection: #first_name AS fn1, #age\ - \n Filter: #age Gt Int64(20)\ - \n TableScan: person projection=None"; + let expected = "Projection: #fn1, #age\ + \n Filter: #fn1 Eq Utf8(\"X\") And #age Lt Int64(30)\ + \n Projection: #first_name AS fn1, #age\ + \n Filter: #age Gt Int64(20)\ + \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -2035,14 +2036,16 @@ mod tests { fn select_wildcard_with_groupby() { quick_test( "SELECT * FROM person GROUP BY id, first_name, last_name, age, state, salary, birth_date", - "Aggregate: groupBy=[[#id, #first_name, #last_name, #age, #state, #salary, #birth_date]], aggr=[[]]\ - \n TableScan: person projection=None", + "Projection: #id, #first_name, #last_name, #age, #state, #salary, #birth_date\ + \n Aggregate: groupBy=[[#id, #first_name, #last_name, #age, #state, #salary, #birth_date]], aggr=[[]]\ + \n TableScan: person projection=None", ); quick_test( "SELECT * FROM (SELECT first_name, last_name FROM person) GROUP BY first_name, last_name", - "Aggregate: groupBy=[[#first_name, #last_name]], aggr=[[]]\ - \n Projection: #first_name, #last_name\ - \n TableScan: person projection=None", + "Projection: #first_name, #last_name\ + \n Aggregate: groupBy=[[#first_name, #last_name]], aggr=[[]]\ + \n Projection: #first_name, #last_name\ + \n TableScan: person projection=None", ); } @@ -2123,8 +2126,9 @@ mod tests { fn select_simple_aggregate_with_groupby() { quick_test( "SELECT state, MIN(age), MAX(age) FROM person GROUP BY state", - 
"Aggregate: groupBy=[[#state]], aggr=[[MIN(#age), MAX(#age)]]\ - \n TableScan: person projection=None", + "Projection: #state, #MIN(age), #MAX(age)\ + \n Aggregate: groupBy=[[#state]], aggr=[[MIN(#age), MAX(#age)]]\ + \n TableScan: person projection=None", ); } @@ -2261,8 +2265,9 @@ mod tests { ) { quick_test( "SELECT age + 1, MIN(first_name) FROM person GROUP BY age + 1", - "Aggregate: groupBy=[[#age Plus Int64(1)]], aggr=[[MIN(#first_name)]]\ - \n TableScan: person projection=None", + "Projection: #age Plus Int64(1), #MIN(first_name)\ + \n Aggregate: groupBy=[[#age Plus Int64(1)]], aggr=[[MIN(#first_name)]]\ + \n TableScan: person projection=None", ); quick_test( "SELECT MIN(first_name), age + 1 FROM person GROUP BY age + 1", @@ -2456,8 +2461,9 @@ mod tests { #[test] fn select_group_by() { let sql = "SELECT state FROM person GROUP BY state"; - let expected = "Aggregate: groupBy=[[#state]], aggr=[[]]\ - \n TableScan: person projection=None"; + let expected = "Projection: #state\ + \n Aggregate: groupBy=[[#state]], aggr=[[]]\ + \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -2475,8 +2481,9 @@ mod tests { #[test] fn select_group_by_count_star() { let sql = "SELECT state, COUNT(*) FROM person GROUP BY state"; - let expected = "Aggregate: groupBy=[[#state]], aggr=[[COUNT(UInt8(1))]]\ - \n TableScan: person projection=None"; + let expected = "Projection: #state, #COUNT(UInt8(1))\ + \n Aggregate: groupBy=[[#state]], aggr=[[COUNT(UInt8(1))]]\ + \n TableScan: person projection=None"; quick_test(sql, expected); } From 89463eac17e5d743315af427cbe3f748a6c82695 Mon Sep 17 00:00:00 2001 From: "Heres, Daniel" Date: Wed, 5 May 2021 20:41:26 +0200 Subject: [PATCH 6/6] Fix test --- datafusion/tests/sql.rs | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index fb3f6265eefc..5c90f8ac162b 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -81,15 +81,18 
@@ async fn nyc() -> Result<()> { let optimized_plan = ctx.optimize(&logical_plan)?; match &optimized_plan { - LogicalPlan::Aggregate { input, .. } => match input.as_ref() { - LogicalPlan::TableScan { - ref projected_schema, - .. - } => { - assert_eq!(2, projected_schema.fields().len()); - assert_eq!(projected_schema.field(0).name(), "passenger_count"); - assert_eq!(projected_schema.field(1).name(), "fare_amount"); - } + LogicalPlan::Projection { input, .. } => match input.as_ref() { + LogicalPlan::Aggregate { input, .. } => match input.as_ref() { + LogicalPlan::TableScan { + ref projected_schema, + .. + } => { + assert_eq!(2, projected_schema.fields().len()); + assert_eq!(projected_schema.field(0).name(), "passenger_count"); + assert_eq!(projected_schema.field(1).name(), "fare_amount"); + } + _ => unreachable!(), + }, _ => unreachable!(), }, _ => unreachable!(false),