diff --git a/datafusion/src/physical_plan/sort.rs b/datafusion/src/physical_plan/sort.rs
index 7747030d8a93..437519a7d2a2 100644
--- a/datafusion/src/physical_plan/sort.rs
+++ b/datafusion/src/physical_plan/sort.rs
@@ -241,7 +241,6 @@ impl SortStream {
         sort_time: Arc<SQLMetric>,
     ) -> Self {
         let (tx, rx) = futures::channel::oneshot::channel();
-        let schema = input.schema();
 
         tokio::spawn(async move {
             let schema = input.schema();
diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs
index 659d2183819d..7eb14943facf 100644
--- a/datafusion/src/physical_plan/windows.rs
+++ b/datafusion/src/physical_plan/windows.rs
@@ -19,7 +19,7 @@
 
 use crate::error::{DataFusionError, Result};
 use crate::physical_plan::{
-    aggregates,
+    aggregates, common,
     expressions::{Literal, NthValue, RowNumber},
     type_coercion::coerce,
     window_functions::signature_for_built_in,
@@ -29,20 +29,18 @@ use crate::physical_plan::{
     RecordBatchStream, SendableRecordBatchStream, WindowAccumulator, WindowExpr,
 };
 use crate::scalar::ScalarValue;
-use arrow::compute::concat;
 use arrow::{
-    array::{Array, ArrayRef},
+    array::ArrayRef,
     datatypes::{Field, Schema, SchemaRef},
     error::{ArrowError, Result as ArrowResult},
     record_batch::RecordBatch,
 };
 use async_trait::async_trait;
-use futures::stream::{Stream, StreamExt};
+use futures::stream::Stream;
 use futures::Future;
 use pin_project_lite::pin_project;
 use std::any::Any;
 use std::convert::TryInto;
-use std::iter;
 use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll};
@@ -339,22 +337,15 @@ fn window_aggregate_batch(
     window_accumulators: &mut [WindowAccumulatorItem],
     expressions: &[Vec<Arc<dyn PhysicalExpr>>],
 ) -> Result<Vec<Option<ArrayRef>>> {
-    // 1.1 iterate accumulators and respective expressions together
-    // 1.2 evaluate expressions
-    // 1.3 update / merge window accumulators with the expressions' values
-
-    // 1.1
     window_accumulators
         .iter_mut()
         .zip(expressions)
         .map(|(window_acc, expr)| {
-            // 1.2
             let values = &expr
                 .iter()
-                .map(|e| e.evaluate(batch))
+                .map(|e| e.evaluate(&batch))
                 .map(|r| r.map(|v| v.into_array(batch.num_rows())))
                 .collect::<Result<Vec<_>>>()?;
-
             window_acc.scan_batch(batch.num_rows(), values)
         })
         .into_iter()
@@ -380,60 +371,50 @@ fn create_window_accumulators(
         .collect::<Result<Vec<_>>>()
 }
 
-async fn compute_window_aggregate(
-    schema: SchemaRef,
+/// Compute the window aggregate columns
+///
+/// 1. get a list of window accumulators
+/// 2. evaluate the args
+/// 3. scan args with window functions
+/// 4. concat with final aggregations
+///
+/// FIXME so far this fn does not support:
+/// 1. partition by
+/// 2. order by
+/// 3. window frame
+///
+/// which will require further work:
+/// 1. inter-partition order by using vec partition-point (https://github.com/apache/arrow-datafusion/issues/360)
+/// 2. inter-partition parallelism using one-shot channel (https://github.com/apache/arrow-datafusion/issues/299)
+/// 3. convert aggregation-based window functions to be self-contained so that: (https://github.com/apache/arrow-datafusion/issues/361)
+///    a. some can be grow-only window-accumulating
+///    b. some can be grow-and-shrink window-accumulating
+///    c. some can be based on segment tree
+fn compute_window_aggregates(
     window_expr: Vec<Arc<dyn WindowExpr>>,
-    mut input: SendableRecordBatchStream,
-) -> ArrowResult<RecordBatch> {
-    let mut window_accumulators = create_window_accumulators(&window_expr)
-        .map_err(DataFusionError::into_arrow_external_error)?;
-
-    let expressions = window_expressions(&window_expr)
-        .map_err(DataFusionError::into_arrow_external_error)?;
-
-    let expressions = Arc::new(expressions);
-
-    // TODO each element shall have some size hint
-    let mut accumulator: Vec<Vec<ArrayRef>> =
-        iter::repeat(vec![]).take(window_expr.len()).collect();
-
-    let mut original_batches: Vec<RecordBatch> = vec![];
-
-    let mut total_num_rows = 0;
-
-    while let Some(batch) = input.next().await {
-        let batch = batch?;
-        total_num_rows += batch.num_rows();
-        original_batches.push(batch.clone());
-
-        let batch_aggregated =
-            window_aggregate_batch(&batch, &mut window_accumulators, &expressions)
-                .map_err(DataFusionError::into_arrow_external_error)?;
-        accumulator.iter_mut().zip(batch_aggregated).for_each(
-            |(acc_for_window, window_batch)| {
-                if let Some(data) = window_batch {
-                    acc_for_window.push(data);
-                }
-            },
-        );
+    batch: &RecordBatch,
+) -> Result<Vec<ArrayRef>> {
+    let mut window_accumulators = create_window_accumulators(&window_expr)?;
+    let expressions = Arc::new(window_expressions(&window_expr)?);
+    let num_rows = batch.num_rows();
+    let window_aggregates =
+        window_aggregate_batch(batch, &mut window_accumulators, &expressions)?;
+    let final_aggregates = finalize_window_aggregation(&window_accumulators)?;
+
+    // both must equal window_expr.len()
+    if window_aggregates.len() != final_aggregates.len() {
+        return Err(DataFusionError::Internal(
+            "Impossibly got len mismatch".to_owned(),
+        ));
     }
-
-    let aggregated_mapped = finalize_window_aggregation(&window_accumulators)
-        .map_err(DataFusionError::into_arrow_external_error)?;
-
-    let mut columns: Vec<ArrayRef> = accumulator
+    window_aggregates
         .iter()
-        .zip(aggregated_mapped)
-        .map(|(acc, agg)| {
-            Ok(match (acc, agg) {
-                (acc, Some(scalar_value)) if acc.is_empty() => {
-                    scalar_value.to_array_of_size(total_num_rows)
-                }
-                (acc, None) if !acc.is_empty() => {
-                    let vec_array: Vec<&dyn Array> =
-                        acc.iter().map(|arc| arc.as_ref()).collect();
-                    concat(&vec_array)?
-                }
+        .zip(final_aggregates)
+        .map(|(wa, fa)| {
+            Ok(match (wa, fa) {
+                (None, Some(fa)) => fa.to_array_of_size(num_rows),
+                (Some(wa), None) if wa.len() == num_rows => wa.clone(),
                 _ => {
                     return Err(DataFusionError::Execution(
                         "Invalid window function behavior".to_owned(),
@@ -441,20 +422,7 @@ async fn compute_window_aggregate(
                 }
             })
         })
-        .collect::<Result<Vec<ArrayRef>>>()
-        .map_err(DataFusionError::into_arrow_external_error)?;
-
-    for i in 0..(schema.fields().len() - window_expr.len()) {
-        let col = concat(
-            &original_batches
-                .iter()
-                .map(|batch| batch.column(i).as_ref())
-                .collect::<Vec<_>>(),
-        )?;
-        columns.push(col);
-    }
-
-    RecordBatch::try_new(schema.clone(), columns)
+        .collect()
 }
 
 impl WindowAggStream {
@@ -467,7 +435,8 @@ impl WindowAggStream {
         let (tx, rx) = futures::channel::oneshot::channel();
         let schema_clone = schema.clone();
         tokio::spawn(async move {
-            let result = compute_window_aggregate(schema_clone, window_expr, input).await;
+            let schema = schema_clone.clone();
+            let result = WindowAggStream::process(input, window_expr, schema).await;
             tx.send(result)
         });
 
@@ -477,6 +446,30 @@ impl WindowAggStream {
             schema,
         }
     }
+
+    async fn process(
+        input: SendableRecordBatchStream,
+        window_expr: Vec<Arc<dyn WindowExpr>>,
+        schema: SchemaRef,
+    ) -> ArrowResult<RecordBatch> {
+        let input_schema = input.schema();
+        let batches = common::collect(input)
+            .await
+            .map_err(DataFusionError::into_arrow_external_error)?;
+        let batch = common::combine_batches(&batches, input_schema.clone())?;
+        if let Some(batch) = batch {
+            // calculate window cols
+            let mut columns = compute_window_aggregates(window_expr, &batch)
+                .map_err(DataFusionError::into_arrow_external_error)?;
+            // combine with the original cols
+            // note the setup of window aggregates is that the newly calculated window
+            // expressions are always prepended to the columns
+            columns.extend_from_slice(batch.columns());
+            RecordBatch::try_new(schema, columns)
+        } else {
+            Ok(RecordBatch::new_empty(schema))
+        }
+    }
 }
 
 impl Stream for WindowAggStream {
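The heart of the new `compute_window_aggregates` is the pairing of per-row scan output with finalized scalar output: each window expression must produce exactly one of the two, and a scalar is broadcast to every row while a per-row column must already cover the whole batch. Below is a minimal standalone sketch of that logic, using plain `Vec<i64>` and `String` errors as stand-ins for `ArrayRef` and `DataFusionError`; the `combine` function and its names are illustrative, not part of the DataFusion API.

```rust
/// `scanned[i]` is the per-row output of window expression `i` (if any),
/// `finals[i]` is its finalized scalar output (if any); exactly one of the
/// two must be present, mirroring the `(wa, fa)` match in the diff above.
fn combine(
    scanned: Vec<Option<Vec<i64>>>,
    finals: Vec<Option<i64>>,
    num_rows: usize,
) -> Result<Vec<Vec<i64>>, String> {
    // both must equal the number of window expressions
    if scanned.len() != finals.len() {
        return Err("length mismatch".to_owned());
    }
    scanned
        .into_iter()
        .zip(finals)
        .map(|(wa, fa)| match (wa, fa) {
            // a finalized scalar is broadcast to a full column,
            // as ScalarValue::to_array_of_size(num_rows) does
            (None, Some(fa)) => Ok(vec![fa; num_rows]),
            // a per-row result is used as-is, but must cover every row
            (Some(wa), None) if wa.len() == num_rows => Ok(wa),
            // anything else is a misbehaving window function
            _ => Err("invalid window function behavior".to_owned()),
        })
        .collect()
}

fn main() {
    // one broadcast aggregate and one per-row window column over 3 rows
    let cols =
        combine(vec![None, Some(vec![1, 2, 3])], vec![Some(42), None], 3).unwrap();
    assert_eq!(cols, vec![vec![42, 42, 42], vec![1, 2, 3]]);
}
```

This pairing only works because `process` first funnels the whole input through `common::collect` and `common::combine_batches` into a single `RecordBatch`: with one batch, `num_rows` is the total row count, so the old per-batch accumulation, `concat`, and the trailing column-reordering loop become unnecessary, and the window columns are simply prepended ahead of `batch.columns()`.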