Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 123 additions & 53 deletions datafusion/physical-plan/src/repartition/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ use parking_lot::Mutex;

mod distributor_channels;
use distributor_channels::{
DistributionReceiver, DistributionSender, channels, partition_aware_channels,
DistributionReceiver, DistributionSender, partition_aware_channels,
};

/// A batch in the repartition queue - either in memory or spilled to disk.
Expand Down Expand Up @@ -299,26 +299,17 @@ impl RepartitionExecState {

let spill_manager = Arc::new(spill_manager);

let (txs, rxs) = if preserve_order {
// Create partition-aware channels with one channel per (input, output) pair
// This provides backpressure while maintaining proper ordering
let (txs_all, rxs_all) =
partition_aware_channels(num_input_partitions, num_output_partitions);
// Take transpose of senders and receivers. `state.channels` keeps track of entries per output partition
let txs = transpose(txs_all);
let rxs = transpose(rxs_all);
(txs, rxs)
} else {
// Create one channel per *output* partition with backpressure
let (txs, rxs) = channels(num_output_partitions);
// Clone sender for each input partitions
let txs = txs
.into_iter()
.map(|item| vec![item; num_input_partitions])
.collect::<Vec<_>>();
let rxs = rxs.into_iter().map(|item| vec![item]).collect::<Vec<_>>();
(txs, rxs)
};
// Create one channel per (input, output) pair regardless of mode. For
// non-preserve-order this removes the previous MPSC shared-sender
// pattern (and its per-output channel state mutex contention) — each
// input task now has its own SPSC channel to each output, and the
// non-preserve-order consumer merges the N streams with `select_all`
// below.
let (txs_all, rxs_all) =
partition_aware_channels(num_input_partitions, num_output_partitions);
// Transpose so the outer Vec is indexed by output partition.
let txs = transpose(txs_all);
let rxs = transpose(rxs_all);

let mut channels = HashMap::with_capacity(txs.len());
for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() {
Expand All @@ -328,22 +319,15 @@ impl RepartitionExecState {
.register(context.memory_pool()),
));

// Create spill channels based on mode:
// - preserve_order: one spill channel per (input, output) pair for proper FIFO ordering
// - non-preserve-order: one shared spill channel per output partition since all inputs
// share the same receiver
// One spill channel per (input, output) pair for proper FIFO
// ordering within each input stream.
let max_file_size = context
.session_config()
.options()
.execution
.max_spill_file_size_bytes;
let num_spill_channels = if preserve_order {
num_input_partitions
} else {
1
};
let (spill_writers, spill_readers): (Vec<_>, Vec<_>) = (0
..num_spill_channels)
..num_input_partitions)
.map(|_| spill_pool::channel(max_file_size, Arc::clone(&spill_manager)))
.unzip();

Expand All @@ -367,16 +351,15 @@ impl RepartitionExecState {
let txs: HashMap<_, _> = channels
.iter()
.map(|(partition, channels)| {
// In preserve_order mode: each input gets its own spill writer (index i)
// In non-preserve-order mode: all inputs share spill writer 0 via clone
let spill_writer_idx = if preserve_order { i } else { 0 };
// Each input gets its own spill writer (indexed by input
// partition) matching the per-(input, output) channel
// layout.
(
*partition,
OutputChannel {
sender: channels.tx[i].clone(),
reservation: Arc::clone(&channels.reservation),
spill_writer: channels.spill_writers[spill_writer_idx]
.clone(),
spill_writer: channels.spill_writers[i].clone(),
},
)
})
Expand All @@ -393,7 +376,11 @@ impl RepartitionExecState {
txs,
partitioning.clone(),
metrics,
// preserve_order depends on partition index to start from 0
// preserve_order depends on partition index to start from 0.
// For non-preserve-order we keep passing `i`, as before: although each
// input now writes to its own per-(input, output) channel, the
// round-robin partitioner should still be offset by the input index
// to spread work across outputs.
if preserve_order { 0 } else { i },
num_input_partitions,
));
Expand Down Expand Up @@ -997,8 +984,6 @@ impl ExecutionPlan for RepartitionExec {
)?;
}

let num_input_partitions = input.output_partitioning().partition_count();

let stream = futures::stream::once(async move {
// lock scope
let (rx, reservation, spill_readers, abort_helper) = {
Expand Down Expand Up @@ -1077,23 +1062,41 @@ impl ExecutionPlan for RepartitionExec {
.with_spill_manager(spill_manager)
.build()
} else {
// Non-preserve-order case: single input stream, so use the first spill reader
let spill_stream = spill_readers
// Non-preserve-order case: each input has its own channel and
// spill reader. Build one `PerPartitionStream` per input and
// merge them with `select_all` (unordered first-ready merge).
// Coalescing is done on the merged output so batches from
// different inputs are combined into properly-sized batches.
let input_streams = rx
.into_iter()
.next()
.expect("at least one spill reader should exist");
.zip(spill_readers)
.map(|(receiver, spill_stream)| {
Box::pin(PerPartitionStream::new(
Arc::clone(&schema_captured),
receiver,
Arc::clone(&abort_helper),
Arc::clone(&reservation),
spill_stream,
// Each stream now corresponds to one input
// partition, so it expects a single completion
// signal.
1,
BaselineMetrics::new(&metrics, partition),
// Coalescing happens on the merged stream below.
None,
)) as SendableRecordBatchStream
})
.collect::<Vec<_>>();

let merged = Box::pin(RecordBatchStreamAdapter::new(
Arc::clone(&schema_captured),
futures::stream::select_all(input_streams),
)) as SendableRecordBatchStream;

Ok(Box::pin(PerPartitionStream::new(
Ok(Box::pin(CoalescingOutputStream::new(
merged,
schema_captured,
rx.into_iter()
.next()
.expect("at least one receiver should exist"),
abort_helper,
reservation,
spill_stream,
num_input_partitions,
BaselineMetrics::new(&metrics, partition),
Some(context.session_config().batch_size()),
context.session_config().batch_size(),
)) as SendableRecordBatchStream)
}
})
Expand Down Expand Up @@ -1758,6 +1761,73 @@ impl RecordBatchStream for PerPartitionStream {
}
}

/// Wraps a `SendableRecordBatchStream` with a [`LimitedBatchCoalescer`] so the
/// output batches reach the configured target batch size. Used to coalesce the
/// unordered merge of per-input streams in the non-preserve-order case.
///
/// Batches pulled from `input` are pushed into `coalescer`; the stream yields
/// whatever batches the coalescer reports as completed. Once `input` is
/// exhausted the coalescer is finished, any remaining completed batches are
/// yielded, and the stream then ends.
struct CoalescingOutputStream {
/// Upstream stream whose batches are fed into the coalescer.
input: SendableRecordBatchStream,
/// Accumulates incoming batches and hands back completed ones.
coalescer: LimitedBatchCoalescer,
/// Schema reported by [`RecordBatchStream::schema`] for this stream.
schema: SchemaRef,
/// Set to `true` once `input` has returned `None`; after that only the
/// coalescer's remaining completed batches are emitted.
completed: bool,
}

impl CoalescingOutputStream {
    /// Builds a coalescing wrapper around `input`.
    ///
    /// * `input` - the stream whose batches will be coalesced.
    /// * `schema` - schema handed to the coalescer and reported by this stream.
    /// * `target_batch_size` - batch size passed to the coalescer.
    fn new(
        input: SendableRecordBatchStream,
        schema: SchemaRef,
        target_batch_size: usize,
    ) -> Self {
        // NOTE(review): the trailing `None` is presumably the coalescer's
        // optional row limit (i.e. "no limit") - confirm against
        // `LimitedBatchCoalescer::new`.
        let coalescer =
            LimitedBatchCoalescer::new(Arc::clone(&schema), target_batch_size, None);

        Self {
            input,
            coalescer,
            schema,
            // The input has not been polled yet, so it cannot be exhausted.
            completed: false,
        }
    }
}

impl Stream for CoalescingOutputStream {
    type Item = Result<RecordBatch>;

    /// Drives the wrapped stream and yields coalesced batches.
    ///
    /// Each iteration first hands out any batch the coalescer has already
    /// completed. Otherwise it polls the input: a new batch is pushed into
    /// the coalescer, an input error is forwarded as-is, and end-of-input
    /// finishes the coalescer so remaining batches are emitted before the
    /// stream reports `None`.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let this = &mut *self;
        loop {
            // Completed batches always take priority over polling the input.
            if let Some(coalesced) = this.coalescer.next_completed_batch() {
                return Poll::Ready(Some(Ok(coalesced)));
            }

            // Input exhausted and nothing left buffered: the stream is done.
            if this.completed {
                return Poll::Ready(None);
            }

            // Collapse every failure path into a single optional error so
            // there is exactly one place that surfaces it to the caller.
            let failure = match this.input.poll_next_unpin(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(Some(Ok(incoming))) => {
                    this.coalescer.push_batch(incoming).err()
                }
                Poll::Ready(Some(Err(e))) => Some(e),
                Poll::Ready(None) => {
                    // Mark completion before finishing so that a failing
                    // `finish` still leaves the stream in its terminal state.
                    this.completed = true;
                    this.coalescer.finish().err()
                }
            };

            if let Some(e) = failure {
                return Poll::Ready(Some(Err(e)));
            }
        }
    }
}

impl RecordBatchStream for CoalescingOutputStream {
    /// Returns a new reference-counted handle to the schema supplied at
    /// construction time.
    fn schema(&self) -> SchemaRef {
        let schema = &self.schema;
        Arc::clone(schema)
    }
}

#[cfg(test)]
mod tests {
use std::collections::HashSet;
Expand Down
Loading