diff --git a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs index c80c0b4bf54b..4c528b65c9b2 100644 --- a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs +++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs @@ -35,7 +35,7 @@ use datafusion::prelude::*; use datafusion::scalar::ScalarValue; use datafusion_catalog::Session; use datafusion_common::cast::as_primitive_array; -use datafusion_common::{internal_err, not_impl_err}; +use datafusion_common::{internal_err, not_impl_err, DataFusionError}; use datafusion_expr::expr::{BinaryExpr, Cast}; use datafusion_functions_aggregate::expr_fn::count; use datafusion_physical_expr::EquivalenceProperties; @@ -134,9 +134,25 @@ impl ExecutionPlan for CustomPlan { _partition: usize, _context: Arc, ) -> Result { + if self.batches.is_empty() { + return Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + futures::stream::empty(), + ))); + } + let schema_captured = self.schema().clone(); Ok(Box::pin(RecordBatchStreamAdapter::new( self.schema(), - futures::stream::iter(self.batches.clone().into_iter().map(Ok)), + futures::stream::iter(self.batches.clone().into_iter().map(move |batch| { + let projection: Vec = schema_captured + .fields() + .iter() + .filter_map(|field| batch.schema().index_of(field.name()).ok()) + .collect(); + batch + .project(&projection) + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None)) + })), ))) } diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 843d975c7d76..f53e158a1e68 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -30,6 +30,7 @@ use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; use super::{ DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream, }; +use crate::coalesce::{LimitedBatchCoalescer, PushBatchStatus}; use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType}; use crate::hash_utils::create_hashes; use crate::metrics::{BaselineMetrics, SpillMetrics}; @@ -932,6 +933,7 @@ impl ExecutionPlan for RepartitionExec { spill_stream, 1, // Each receiver handles one input partition BaselineMetrics::new(&metrics, partition), + context.session_config().batch_size(), )) as SendableRecordBatchStream }) .collect::>(); @@ -970,6 +972,7 @@ impl ExecutionPlan for RepartitionExec { spill_stream, num_input_partitions, BaselineMetrics::new(&metrics, partition), + context.session_config().batch_size(), )) as SendableRecordBatchStream) } }) @@ -1427,9 +1430,12 @@ struct PerPartitionStream { /// Execution metrics baseline_metrics: BaselineMetrics, + + coalescer: LimitedBatchCoalescer, } impl PerPartitionStream { + #[allow(clippy::too_many_arguments)] fn new( schema: SchemaRef, receiver: DistributionReceiver, @@ -1438,9 +1444,10 @@ impl PerPartitionStream { spill_stream: SendableRecordBatchStream, num_input_partitions: usize, baseline_metrics: BaselineMetrics, + batch_size: usize, ) -> Self { Self { - schema, + schema: Arc::clone(&schema), receiver, _drop_helper: drop_helper, reservation, @@ -1448,6 +1455,7 @@ impl PerPartitionStream { state: StreamState::ReadingMemory, remaining_partitions: num_input_partitions, baseline_metrics, + coalescer: LimitedBatchCoalescer::new(schema, batch_size, None), } } @@ -1540,7 +1548,49 @@ impl Stream for PerPartitionStream { mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - let poll = self.poll_next_inner(cx); + let cloned_time = self.baseline_metrics.elapsed_compute().clone(); + let mut completed = false; + + let poll; + loop { + if let Some(batch) = self.coalescer.next_completed_batch() { + poll = Poll::Ready(Some(Ok(batch))); + break; + } + if completed { + poll = Poll::Ready(None); + break; + } + let inner_poll = self.poll_next_inner(cx); + let _timer = cloned_time.timer(); + + match inner_poll { + Poll::Pending => { + poll = Poll::Pending; + break; + } + Poll::Ready(None) => { + completed = true; + self.coalescer.finish()?; + } + Poll::Ready(Some(Ok(batch))) => { + match self.coalescer.push_batch(batch)? { + PushBatchStatus::Continue => { + // Keep pushing more batches + } + PushBatchStatus::LimitReached => { + // limit was reached, so stop early + completed = true; + self.coalescer.finish()?; + } + } + } + Poll::Ready(Some(err)) => { + poll = Poll::Ready(Some(err)); + break; + } + } + } self.baseline_metrics.record_poll(poll) } } @@ -1575,9 +1625,9 @@ mod tests { use datafusion_common::exec_err; use datafusion_common::test_util::batches_to_sort_string; use datafusion_common_runtime::JoinSet; + use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use insta::assert_snapshot; - use itertools::Itertools; #[tokio::test] async fn one_to_many_round_robin() -> Result<()> { @@ -1671,7 +1721,10 @@ mod tests { input_partitions: Vec>, partitioning: Partitioning, ) -> Result>> { - let task_ctx = Arc::new(TaskContext::default()); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(8)), + ); // create physical plan let exec = TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(schema), None)?; @@ -1950,14 +2003,13 @@ mod tests { }); let batches_with_drop = crate::common::collect(output_stream1).await.unwrap(); - fn sort(batch: Vec) -> Vec { - batch - .into_iter() - .sorted_by_key(|b| format!("{b:?}")) - .collect() - } - - assert_eq!(sort(batches_without_drop), sort(batches_with_drop)); + let items_vec_with_drop = str_batches_to_vec(&batches_with_drop); + let items_set_with_drop: HashSet<&str> = + items_vec_with_drop.iter().copied().collect(); + assert_eq!( + items_set_with_drop.symmetric_difference(&items_set).count(), + 0 + ); } fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> { @@ -2396,6 +2448,7 @@ mod test { use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::assert_batches_eq; + use datafusion_execution::config::SessionConfig; use super::*; use crate::test::TestMemoryExec; @@ -2507,8 +2560,10 @@ mod test { let runtime = RuntimeEnvBuilder::default() .with_memory_limit(64, 1.0) .build_arc()?; - - let task_ctx = TaskContext::default().with_runtime(runtime); + let session_config = SessionConfig::new().with_batch_size(4); + let task_ctx = TaskContext::default() + .with_runtime(runtime) + .with_session_config(session_config); let task_ctx = Arc::new(task_ctx); // Create physical plan with order preservation