Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use datafusion::prelude::*;
use datafusion::scalar::ScalarValue;
use datafusion_catalog::Session;
use datafusion_common::cast::as_primitive_array;
use datafusion_common::{internal_err, not_impl_err};
use datafusion_common::{internal_err, not_impl_err, DataFusionError};
use datafusion_expr::expr::{BinaryExpr, Cast};
use datafusion_functions_aggregate::expr_fn::count;
use datafusion_physical_expr::EquivalenceProperties;
Expand Down Expand Up @@ -134,9 +134,19 @@ impl ExecutionPlan for CustomPlan {
_partition: usize,
_context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
let schema_captured = self.schema().clone();
Ok(Box::pin(RecordBatchStreamAdapter::new(
self.schema(),
futures::stream::iter(self.batches.clone().into_iter().map(Ok)),
futures::stream::iter(self.batches.clone().into_iter().map(move |batch| {
Copy link
Contributor Author

@jizezhang jizezhang Dec 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please see #18782 (comment) for my thoughts/reason on updating this test.

let projection: Vec<usize> = schema_captured
.fields()
.iter()
.filter_map(|field| batch.schema().index_of(field.name()).ok())
.collect();
batch
.project(&projection)
.map_err(|e| DataFusionError::ArrowError(Box::new(e), None))
})),
)))
}

Expand Down
100 changes: 75 additions & 25 deletions datafusion/physical-plan/src/repartition/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
use super::{
DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream,
};
use crate::coalesce::LimitedBatchCoalescer;
use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType};
use crate::hash_utils::create_hashes;
use crate::metrics::{BaselineMetrics, SpillMetrics};
Expand Down Expand Up @@ -932,6 +933,7 @@ impl ExecutionPlan for RepartitionExec {
spill_stream,
1, // Each receiver handles one input partition
BaselineMetrics::new(&metrics, partition),
None, // subsequent merge sort already does batching https://github.com/apache/datafusion/blob/e4dcf0c85611ad0bd291f03a8e03fe56d773eb16/datafusion/physical-plan/src/sorts/merge.rs#L286
)) as SendableRecordBatchStream
})
.collect::<Vec<_>>();
Expand Down Expand Up @@ -970,6 +972,7 @@ impl ExecutionPlan for RepartitionExec {
spill_stream,
num_input_partitions,
BaselineMetrics::new(&metrics, partition),
Some(context.session_config().batch_size()),
)) as SendableRecordBatchStream)
}
})
Expand Down Expand Up @@ -1427,9 +1430,12 @@ struct PerPartitionStream {

/// Execution metrics
baseline_metrics: BaselineMetrics,

batch_coalescer: Option<LimitedBatchCoalescer>,
}

impl PerPartitionStream {
#[allow(clippy::too_many_arguments)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#[allow(clippy::too_many_arguments)]
#[expect(clippy::too_many_arguments)]

By using expect instead of allow the Clippy rule will fail too once it is no more needed and the developer will have to remove it. Otherwise it may become obsolete.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing out, will update in a revision later.

fn new(
schema: SchemaRef,
receiver: DistributionReceiver<MaybeBatch>,
Expand All @@ -1438,7 +1444,10 @@ impl PerPartitionStream {
spill_stream: SendableRecordBatchStream,
num_input_partitions: usize,
baseline_metrics: BaselineMetrics,
batch_size: Option<usize>,
) -> Self {
let batch_coalescer =
batch_size.map(|s| LimitedBatchCoalescer::new(Arc::clone(&schema), s, None));
Self {
schema,
receiver,
Expand All @@ -1448,6 +1457,7 @@ impl PerPartitionStream {
state: StreamState::ReadingMemory,
remaining_partitions: num_input_partitions,
baseline_metrics,
batch_coalescer,
}
}

Expand Down Expand Up @@ -1540,7 +1550,46 @@ impl Stream for PerPartitionStream {
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
let poll = self.poll_next_inner(cx);
let poll = match self.batch_coalescer.take() {
Some(mut coalescer) => {
let cloned_time = self.baseline_metrics.elapsed_compute().clone();
let mut completed = false;
let poll;
loop {
if let Some(batch) = coalescer.next_completed_batch() {
poll = Poll::Ready(Some(Ok(batch)));
break;
}
if completed {
poll = Poll::Ready(None);
break;
}
let inner_poll = self.poll_next_inner(cx);
let _timer = cloned_time.timer();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be before the poll_next_inner() call ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding was that poll_next_inner already tracks compute time, hence I put it here to only track compute time for coalescer methods.


match inner_poll {
Poll::Pending => {
poll = Poll::Pending;
break;
}
Poll::Ready(None) => {
completed = true;
coalescer.finish()?;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
coalescer.finish()?;
if let Err(e) = coalescer.finish() {
self.batch_coalescer = Some(coalescer);
return self.baseline_metrics.record_poll(Poll::Ready(Some(Err(e))));
}

Otherwise in case of an error the self.batch_coalescer won't be restored.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in this case, the loop would continue and then break on line 1561 or 1565. Then line 1588 is reached, which restores the self.batch_coalescer, but please let me know if that does not make sense.

}
Poll::Ready(Some(Ok(batch))) => {
coalescer.push_batch(batch)?;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
coalescer.push_batch(batch)?;
coalescer.?;
if let Err(e) = coalescer.push_batch(batch) {
self.batch_coalescer = Some(coalescer);
return self.baseline_metrics.record_poll(Poll::Ready(Some(Err(e))));
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to reply above, I think in this case, the loop would continue and then break on line 1561 or 1565. Then line 1588 is reached, which restores the self.batch_coalescer, but please let me know if that does not make sense.

}
Poll::Ready(Some(err)) => {
poll = Poll::Ready(Some(err));
break;
}
}
}
self.batch_coalescer = Some(coalescer);
poll
}
None => self.poll_next_inner(cx),
};
self.baseline_metrics.record_poll(poll)
}
}
Expand Down Expand Up @@ -1577,7 +1626,6 @@ mod tests {
use datafusion_common_runtime::JoinSet;
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
use insta::assert_snapshot;
use itertools::Itertools;

#[tokio::test]
async fn one_to_many_round_robin() -> Result<()> {
Expand All @@ -1591,10 +1639,13 @@ mod tests {
repartition(&schema, partitions, Partitioning::RoundRobinBatch(4)).await?;

assert_eq!(4, output_partitions.len());
assert_eq!(13, output_partitions[0].len());
assert_eq!(13, output_partitions[1].len());
assert_eq!(12, output_partitions[2].len());
assert_eq!(12, output_partitions[3].len());
for partition in &output_partitions {
assert_eq!(1, partition.len());
}
assert_eq!(13 * 8, output_partitions[0][0].num_rows());
assert_eq!(13 * 8, output_partitions[1][0].num_rows());
assert_eq!(12 * 8, output_partitions[2][0].num_rows());
assert_eq!(12 * 8, output_partitions[3][0].num_rows());

Ok(())
}
Expand All @@ -1611,7 +1662,7 @@ mod tests {
repartition(&schema, partitions, Partitioning::RoundRobinBatch(1)).await?;

assert_eq!(1, output_partitions.len());
assert_eq!(150, output_partitions[0].len());
assert_eq!(150 * 8, output_partitions[0][0].num_rows());

Ok(())
}
Expand All @@ -1627,12 +1678,12 @@ mod tests {
let output_partitions =
repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await?;

let total_rows_per_partition = 8 * 50 * 3 / 5;
assert_eq!(5, output_partitions.len());
assert_eq!(30, output_partitions[0].len());
assert_eq!(30, output_partitions[1].len());
assert_eq!(30, output_partitions[2].len());
assert_eq!(30, output_partitions[3].len());
assert_eq!(30, output_partitions[4].len());
for partition in output_partitions {
assert_eq!(1, partition.len());
assert_eq!(total_rows_per_partition, partition[0].num_rows());
}

Ok(())
}
Expand Down Expand Up @@ -1707,12 +1758,12 @@ mod tests {

let output_partitions = handle.join().await.unwrap().unwrap();

let total_rows_per_partition = 8 * 50 * 3 / 5;
assert_eq!(5, output_partitions.len());
assert_eq!(30, output_partitions[0].len());
assert_eq!(30, output_partitions[1].len());
assert_eq!(30, output_partitions[2].len());
assert_eq!(30, output_partitions[3].len());
assert_eq!(30, output_partitions[4].len());
for partition in output_partitions {
assert_eq!(1, partition.len());
assert_eq!(total_rows_per_partition, partition[0].num_rows());
}

Ok(())
}
Expand Down Expand Up @@ -1950,14 +2001,13 @@ mod tests {
});
let batches_with_drop = crate::common::collect(output_stream1).await.unwrap();

fn sort(batch: Vec<RecordBatch>) -> Vec<RecordBatch> {
batch
.into_iter()
.sorted_by_key(|b| format!("{b:?}"))
.collect()
}

assert_eq!(sort(batches_without_drop), sort(batches_with_drop));
let items_vec_with_drop = str_batches_to_vec(&batches_with_drop);
let items_set_with_drop: HashSet<&str> =
items_vec_with_drop.iter().copied().collect();
assert_eq!(
items_set_with_drop.symmetric_difference(&items_set).count(),
0
);
}

fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {
Expand Down