diff --git a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs index c80c0b4bf54b..f4e905e1eda0 100644 --- a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs +++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs @@ -35,7 +35,7 @@ use datafusion::prelude::*; use datafusion::scalar::ScalarValue; use datafusion_catalog::Session; use datafusion_common::cast::as_primitive_array; -use datafusion_common::{internal_err, not_impl_err}; +use datafusion_common::{internal_err, not_impl_err, DataFusionError}; use datafusion_expr::expr::{BinaryExpr, Cast}; use datafusion_functions_aggregate::expr_fn::count; use datafusion_physical_expr::EquivalenceProperties; @@ -134,9 +134,19 @@ impl ExecutionPlan for CustomPlan { _partition: usize, _context: Arc, ) -> Result { + let schema_captured = self.schema().clone(); Ok(Box::pin(RecordBatchStreamAdapter::new( self.schema(), - futures::stream::iter(self.batches.clone().into_iter().map(Ok)), + futures::stream::iter(self.batches.clone().into_iter().map(move |batch| { + let projection: Vec = schema_captured + .fields() + .iter() + .filter_map(|field| batch.schema().index_of(field.name()).ok()) + .collect(); + batch + .project(&projection) + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None)) + })), ))) } diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 843d975c7d76..7b10660cec2b 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -30,6 +30,7 @@ use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; use super::{ DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream, }; +use crate::coalesce::LimitedBatchCoalescer; use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType}; use crate::hash_utils::create_hashes; use crate::metrics::{BaselineMetrics, SpillMetrics}; @@ -932,6 +933,7 @@ impl ExecutionPlan for RepartitionExec { spill_stream, 1, // Each receiver handles one input partition BaselineMetrics::new(&metrics, partition), + None, // subsequent merge sort already does batching https://github.com/apache/datafusion/blob/e4dcf0c85611ad0bd291f03a8e03fe56d773eb16/datafusion/physical-plan/src/sorts/merge.rs#L286 )) as SendableRecordBatchStream }) .collect::>(); @@ -970,6 +972,7 @@ impl ExecutionPlan for RepartitionExec { spill_stream, num_input_partitions, BaselineMetrics::new(&metrics, partition), + Some(context.session_config().batch_size()), )) as SendableRecordBatchStream) } }) @@ -1427,9 +1430,12 @@ struct PerPartitionStream { /// Execution metrics baseline_metrics: BaselineMetrics, + + batch_coalescer: Option, } impl PerPartitionStream { + #[allow(clippy::too_many_arguments)] fn new( schema: SchemaRef, receiver: DistributionReceiver, @@ -1438,7 +1444,10 @@ impl PerPartitionStream { spill_stream: SendableRecordBatchStream, num_input_partitions: usize, baseline_metrics: BaselineMetrics, + batch_size: Option, ) -> Self { + let batch_coalescer = + batch_size.map(|s| LimitedBatchCoalescer::new(Arc::clone(&schema), s, None)); Self { schema, receiver, @@ -1448,6 +1457,7 @@ impl PerPartitionStream { state: StreamState::ReadingMemory, remaining_partitions: num_input_partitions, baseline_metrics, + batch_coalescer, } } @@ -1540,7 +1550,46 @@ impl Stream for PerPartitionStream { mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - let poll = self.poll_next_inner(cx); + let poll = match self.batch_coalescer.take() { + Some(mut coalescer) => { + let cloned_time = self.baseline_metrics.elapsed_compute().clone(); + let mut completed = false; + let poll; + loop { + if let Some(batch) = coalescer.next_completed_batch() { + poll = Poll::Ready(Some(Ok(batch))); + break; + } + if completed { + poll = Poll::Ready(None); + break; + } + let inner_poll = self.poll_next_inner(cx); + let _timer = cloned_time.timer(); + + match inner_poll { + Poll::Pending => { + poll = Poll::Pending; + break; + } + Poll::Ready(None) => { + completed = true; + coalescer.finish()?; + } + Poll::Ready(Some(Ok(batch))) => { + coalescer.push_batch(batch)?; + } + Poll::Ready(Some(err)) => { + poll = Poll::Ready(Some(err)); + break; + } + } + } + self.batch_coalescer = Some(coalescer); + poll + } + None => self.poll_next_inner(cx), + }; self.baseline_metrics.record_poll(poll) } } @@ -1577,7 +1626,6 @@ mod tests { use datafusion_common_runtime::JoinSet; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use insta::assert_snapshot; - use itertools::Itertools; #[tokio::test] async fn one_to_many_round_robin() -> Result<()> { @@ -1591,10 +1639,13 @@ mod tests { repartition(&schema, partitions, Partitioning::RoundRobinBatch(4)).await?; assert_eq!(4, output_partitions.len()); - assert_eq!(13, output_partitions[0].len()); - assert_eq!(13, output_partitions[1].len()); - assert_eq!(12, output_partitions[2].len()); - assert_eq!(12, output_partitions[3].len()); + for partition in &output_partitions { + assert_eq!(1, partition.len()); + } + assert_eq!(13 * 8, output_partitions[0][0].num_rows()); + assert_eq!(13 * 8, output_partitions[1][0].num_rows()); + assert_eq!(12 * 8, output_partitions[2][0].num_rows()); + assert_eq!(12 * 8, output_partitions[3][0].num_rows()); Ok(()) } @@ -1611,7 +1662,7 @@ mod tests { repartition(&schema, partitions, Partitioning::RoundRobinBatch(1)).await?; assert_eq!(1, output_partitions.len()); - assert_eq!(150, output_partitions[0].len()); + assert_eq!(150 * 8, output_partitions[0][0].num_rows()); Ok(()) } @@ -1627,12 +1678,12 @@ mod tests { let output_partitions = repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await?; + let total_rows_per_partition = 8 * 50 * 3 / 5; assert_eq!(5, output_partitions.len()); - assert_eq!(30, output_partitions[0].len()); - assert_eq!(30, output_partitions[1].len()); - assert_eq!(30, output_partitions[2].len()); - assert_eq!(30, output_partitions[3].len()); - assert_eq!(30, output_partitions[4].len()); + for partition in output_partitions { + assert_eq!(1, partition.len()); + assert_eq!(total_rows_per_partition, partition[0].num_rows()); + } Ok(()) } @@ -1707,12 +1758,12 @@ mod tests { let output_partitions = handle.join().await.unwrap().unwrap(); + let total_rows_per_partition = 8 * 50 * 3 / 5; assert_eq!(5, output_partitions.len()); - assert_eq!(30, output_partitions[0].len()); - assert_eq!(30, output_partitions[1].len()); - assert_eq!(30, output_partitions[2].len()); - assert_eq!(30, output_partitions[3].len()); - assert_eq!(30, output_partitions[4].len()); + for partition in output_partitions { + assert_eq!(1, partition.len()); + assert_eq!(total_rows_per_partition, partition[0].num_rows()); + } Ok(()) } @@ -1950,14 +2001,13 @@ mod tests { }); let batches_with_drop = crate::common::collect(output_stream1).await.unwrap(); - fn sort(batch: Vec) -> Vec { - batch - .into_iter() - .sorted_by_key(|b| format!("{b:?}")) - .collect() - } - - assert_eq!(sort(batches_without_drop), sort(batches_with_drop)); + let items_vec_with_drop = str_batches_to_vec(&batches_with_drop); + let items_set_with_drop: HashSet<&str> = + items_vec_with_drop.iter().copied().collect(); + assert_eq!( + items_set_with_drop.symmetric_difference(&items_set).count(), + 0 + ); } fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {