diff --git a/datafusion/src/optimizer/projection_push_down.rs b/datafusion/src/optimizer/projection_push_down.rs index a9e571f3d00b..2cd7384e2443 100644 --- a/datafusion/src/optimizer/projection_push_down.rs +++ b/datafusion/src/optimizer/projection_push_down.rs @@ -29,7 +29,10 @@ use crate::optimizer::utils; use crate::sql::utils::find_sort_exprs; use arrow::datatypes::{Field, Schema}; use arrow::error::Result as ArrowResult; -use std::{collections::HashSet, sync::Arc}; +use std::{ + collections::{BTreeSet, HashSet}, + sync::Arc, +}; use utils::optimize_explain; /// Optimizer that removes unused projections and aggregations from plans @@ -75,9 +78,12 @@ fn get_projected_schema( // // we discard non-existing columns because some column names are not part of the schema, // e.g. when the column derives from an aggregation - let mut projection: Vec = required_columns + // + // Use BTreeSet to remove potential duplicates (e.g. union) as + // well as to sort the projection to ensure deterministic behavior + let mut projection: BTreeSet = required_columns .iter() - .filter(|c| c.relation.as_ref() == table_name) + .filter(|c| c.relation.is_none() || c.relation.as_ref() == table_name) .map(|c| schema.index_of(&c.name)) .filter_map(ArrowResult::ok) .collect(); @@ -87,7 +93,7 @@ fn get_projected_schema( // Ensure that we are reading at least one column from the table in case the query // does not reference any columns directly such as "SELECT COUNT(1) FROM table", // except when the table is empty (no column) - projection.push(0); + projection.insert(0); } else { // for table scan without projection, we default to return all columns projection = schema @@ -95,13 +101,10 @@ fn get_projected_schema( .iter() .enumerate() .map(|(i, _)| i) - .collect::>(); + .collect::>(); } } - // sort the projection otherwise we get non-deterministic behavior - projection.sort_unstable(); - // create the projected schema let mut projected_fields: Vec = Vec::with_capacity(projection.len()); match table_name { @@ -120,6 +123,7 @@ fn get_projected_schema( } } + let projection = projection.into_iter().collect::>(); Ok((projection, projected_fields.to_dfschema_ref()?)) } @@ -438,7 +442,9 @@ fn optimize_plan( mod tests { use super::*; - use crate::logical_plan::{col, lit, max, min, Expr, JoinType, LogicalPlanBuilder}; + use crate::logical_plan::{ + col, exprlist_to_fields, lit, max, min, Expr, JoinType, LogicalPlanBuilder, + }; use crate::test::*; use arrow::datatypes::DataType; @@ -568,6 +574,35 @@ mod tests { Ok(()) } + #[test] + fn table_scan_projected_schema_non_qualified_relation() -> Result<()> { + let table_scan = test_table_scan()?; + let input_schema = table_scan.schema(); + assert_eq!(3, input_schema.fields().len()); + assert_fields_eq(&table_scan, vec!["a", "b", "c"]); + + // Build the LogicalPlan directly (don't use PlanBuilder), so + // that the Column references are unqualified (e.g. their + // relation is `None`). PlanBuilder resolves the expressions + let expr = vec![col("a"), col("b")]; + let projected_fields = exprlist_to_fields(&expr, input_schema).unwrap(); + let projected_schema = DFSchema::new(projected_fields).unwrap(); + let plan = LogicalPlan::Projection { + expr, + input: Arc::new(table_scan), + schema: Arc::new(projected_schema), + }; + + assert_fields_eq(&plan, vec!["a", "b"]); + + let expected = "Projection: #a, #b\ + \n TableScan: test projection=Some([0, 1])"; + + assert_optimized_plan_eq(&plan, expected); + + Ok(()) + } + #[test] fn table_limit() -> Result<()> { let table_scan = test_table_scan()?;