diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 768c42978936..7158277f4e35 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -773,12 +773,19 @@ impl DefaultPhysicalPlanner {
                     )?;
                     Ok(Arc::new(FilterExec::try_new(runtime_expr, physical_input)?))
                 }
-                LogicalPlan::Union(Union { inputs, .. }) => {
+                LogicalPlan::Union(Union { inputs, schema }) => {
                     let physical_plans = futures::stream::iter(inputs)
                         .then(|lp| self.create_initial_plan(lp, session_state))
                         .try_collect::<Vec<_>>()
                         .await?;
-                    Ok(Arc::new(UnionExec::new(physical_plans)))
+                    if schema.fields().len() < physical_plans[0].schema().fields().len() {
+                        // `schema` could be a subset of the child schema. For example
+                        // for query "select count(*) from (select a from t union all select a from t)"
+                        // `schema` is empty but child schema contains one field `a`.
+                        Ok(Arc::new(UnionExec::try_new_with_schema(physical_plans, schema.clone())?))
+                    } else {
+                        Ok(Arc::new(UnionExec::new(physical_plans)))
+                    }
                 }
                 LogicalPlan::Repartition(Repartition {
                     input,
diff --git a/datafusion/core/src/physical_plan/union.rs b/datafusion/core/src/physical_plan/union.rs
index af57c9ef9cc2..8d17b14bdf1c 100644
--- a/datafusion/core/src/physical_plan/union.rs
+++ b/datafusion/core/src/physical_plan/union.rs
@@ -30,6 +30,7 @@ use arrow::{
     datatypes::{Field, Schema, SchemaRef},
     record_batch::RecordBatch,
 };
+use datafusion_common::{DFSchemaRef, DataFusionError};
 use futures::{Stream, StreamExt};
 use itertools::Itertools;
 use log::debug;
@@ -63,6 +64,38 @@ pub struct UnionExec {
 }
 
 impl UnionExec {
+    /// Create a new UnionExec with specified schema.
+    /// The `schema` should always be a subset of the schema of `inputs`,
+    /// otherwise, an error will be returned.
+    pub fn try_new_with_schema(
+        inputs: Vec<Arc<dyn ExecutionPlan>>,
+        schema: DFSchemaRef,
+    ) -> Result<Self> {
+        let mut exec = Self::new(inputs);
+        let exec_schema = exec.schema();
+        let fields = schema
+            .fields()
+            .iter()
+            .map(|dff| {
+                exec_schema
+                    .field_with_name(dff.name())
+                    .cloned()
+                    .map_err(|_| {
+                        DataFusionError::Internal(format!(
+                            "Cannot find the field {:?} in child schema",
+                            dff.name()
+                        ))
+                    })
+            })
+            .collect::<Result<Vec<Field>>>()?;
+        let schema = Arc::new(Schema::new_with_metadata(
+            fields,
+            exec.schema().metadata().clone(),
+        ));
+        exec.schema = schema;
+        Ok(exec)
+    }
+
     /// Create a new UnionExec
     pub fn new(inputs: Vec<Arc<dyn ExecutionPlan>>) -> Self {
         let fields: Vec<Field> = (0..inputs[0].schema().fields().len())
diff --git a/datafusion/core/tests/sql/union.rs b/datafusion/core/tests/sql/union.rs
index 29856a37b1a9..ac0e39f4d479 100644
--- a/datafusion/core/tests/sql/union.rs
+++ b/datafusion/core/tests/sql/union.rs
@@ -80,6 +80,23 @@ async fn union_all_with_aggregate() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn union_all_with_count() -> Result<()> {
+    let ctx = SessionContext::new();
+    execute_to_batches(&ctx, "CREATE table t as SELECT 1 as a").await;
+    let sql = "SELECT COUNT(*) FROM (SELECT a from t UNION ALL SELECT a from t)";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+-----------------+",
+        "| COUNT(UInt8(1)) |",
+        "+-----------------+",
+        "| 2               |",
+        "+-----------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
 #[tokio::test]
 async fn union_schemas() -> Result<()> {
     let ctx =