@@ -134,10 +134,38 @@ impl ExecutionPlan for CustomPlan {
_partition: usize,
_context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
Ok(Box::pin(RecordBatchStreamAdapter::new(
self.schema(),
futures::stream::iter(self.batches.clone().into_iter().map(Ok)),
)))
if self.batches.is_empty() {
Ok(Box::pin(RecordBatchStreamAdapter::new(
self.schema(),
futures::stream::empty(),
)))
} else {
let batch_schema = self.batches[0].schema();
let projection: Vec<usize> = self
.schema()
.fields()
.iter()
.filter_map(|field| batch_schema.index_of(field.name()).ok())
.collect();

Ok(Box::pin(RecordBatchStreamAdapter::new(
self.schema(),
futures::stream::iter(self.batches.clone().into_iter().map(
move |batch| {
let res = batch.project(&projection);
match res {
Ok(b) => Ok(b),
Err(e) => {
Err(datafusion_common::DataFusionError::ArrowError(
Box::new(e),
None,
))
}
}
},
)),
)))
}
}

fn statistics(&self) -> Result<Statistics> {
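
Aside: the column-projection idea added above can be exercised in isolation. The sketch below uses only stock arrow-rs APIs (Schema::index_of and RecordBatch::project); the schemas and values are made up for illustration and are not part of this PR.

    use std::sync::Arc;

    use arrow::array::{ArrayRef, Int32Array, RecordBatch};
    use arrow::datatypes::{DataType, Field, Schema};

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        // A batch whose columns are a superset of (and ordered differently
        // from) the schema that the plan reports.
        let batch_schema = Arc::new(Schema::new(vec![
            Field::new("b", DataType::Int32, false),
            Field::new("a", DataType::Int32, false),
        ]));
        let batch = RecordBatch::try_new(
            Arc::clone(&batch_schema),
            vec![
                Arc::new(Int32Array::from(vec![10, 20])) as ArrayRef,
                Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef,
            ],
        )?;

        // The plan schema only exposes "a"; build the index mapping by field
        // name, silently skipping fields the batch does not carry.
        let plan_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let projection: Vec<usize> = plan_schema
            .fields()
            .iter()
            .filter_map(|field| batch_schema.index_of(field.name()).ok())
            .collect();
        assert_eq!(projection, vec![1]);

        // project() subsets/reorders columns and rewrites the batch schema.
        let projected = batch.project(&projection)?;
        assert_eq!(projected.num_columns(), 1);
        Ok(())
    }
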
datafusion/physical-plan/src/repartition/mod.rs: 192 changes (115 additions, 77 deletions)
@@ -30,6 +30,7 @@ use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
use super::{
DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream,
};
use crate::coalesce::LimitedBatchCoalescer;
use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType};
use crate::hash_utils::create_hashes;
use crate::metrics::{BaselineMetrics, SpillMetrics};
@@ -932,6 +933,7 @@ impl ExecutionPlan for RepartitionExec {
spill_stream,
1, // Each receiver handles one input partition
BaselineMetrics::new(&metrics, partition),
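// Each per-input stream gets a proportional share of the configured batch size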
context.session_config().batch_size() / num_input_partitions,
)) as SendableRecordBatchStream
})
.collect::<Vec<_>>();
@@ -959,7 +961,6 @@
.into_iter()
.next()
.expect("at least one spill reader should exist");

Ok(Box::pin(PerPartitionStream::new(
schema_captured,
rx.into_iter()
@@ -970,6 +971,7 @@
spill_stream,
num_input_partitions,
BaselineMetrics::new(&metrics, partition),
context.session_config().batch_size(),
)) as SendableRecordBatchStream)
}
})
@@ -1427,9 +1429,12 @@ struct PerPartitionStream {

/// Execution metrics
baseline_metrics: BaselineMetrics,

batch_coalescer: LimitedBatchCoalescer,
}

impl PerPartitionStream {
#[allow(clippy::too_many_arguments)]
fn new(
schema: SchemaRef,
receiver: DistributionReceiver<MaybeBatch>,
@@ -1438,16 +1443,29 @@ impl PerPartitionStream {
spill_stream: SendableRecordBatchStream,
num_input_partitions: usize,
baseline_metrics: BaselineMetrics,
batch_size: usize,
) -> Self {
Self {
schema,
schema: Arc::clone(&schema),
receiver,
_drop_helper: drop_helper,
reservation,
spill_stream,
state: StreamState::ReadingMemory,
remaining_partitions: num_input_partitions,
baseline_metrics,
batch_coalescer: LimitedBatchCoalescer::new(schema, batch_size, None),
}
}

fn flush_remaining_batch(
&mut self,
) -> Poll<Option<std::result::Result<RecordBatch, DataFusionError>>> {
// Flush any remaining buffered batch
match self.batch_coalescer.finish() {
Ok(()) => Poll::Ready(self.batch_coalescer.next_completed_batch().map(Ok)),

Err(e) => Poll::Ready(Some(Err(e))),
}
}

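Aside: LimitedBatchCoalescer lives in a crate-internal module, so this diff only shows its call surface (new, push_batch, finish, next_completed_batch). As a rough mental model only, here is a toy stand-in with the same push/finish/drain shape, built on arrow's concat_batches. All names here are illustrative; this is not DataFusion's implementation, which also slices output to exactly the target size.

    use std::collections::VecDeque;
    use std::sync::Arc;

    use arrow::array::{ArrayRef, Int32Array, RecordBatch};
    use arrow::compute::concat_batches;
    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
    use arrow::error::ArrowError;

    /// Toy coalescer: buffers pushed batches and emits one concatenated
    /// batch whenever at least `target_rows` rows have accumulated.
    struct ToyCoalescer {
        schema: SchemaRef,
        target_rows: usize,
        buffered: Vec<RecordBatch>,
        buffered_rows: usize,
        completed: VecDeque<RecordBatch>,
    }

    impl ToyCoalescer {
        fn new(schema: SchemaRef, target_rows: usize) -> Self {
            Self {
                schema,
                target_rows,
                buffered: Vec::new(),
                buffered_rows: 0,
                completed: VecDeque::new(),
            }
        }

        fn push_batch(&mut self, batch: RecordBatch) -> Result<(), ArrowError> {
            self.buffered_rows += batch.num_rows();
            self.buffered.push(batch);
            if self.buffered_rows >= self.target_rows {
                self.flush()?;
            }
            Ok(())
        }

        /// Flush whatever is buffered, even if below the target size.
        fn finish(&mut self) -> Result<(), ArrowError> {
            if !self.buffered.is_empty() {
                self.flush()?;
            }
            Ok(())
        }

        fn next_completed_batch(&mut self) -> Option<RecordBatch> {
            self.completed.pop_front()
        }

        fn flush(&mut self) -> Result<(), ArrowError> {
            let out = concat_batches(&self.schema, self.buffered.iter())?;
            self.buffered.clear();
            self.buffered_rows = 0;
            self.completed.push_back(out);
            Ok(())
        }
    }

    fn main() -> Result<(), ArrowError> {
        let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, false)]));
        let mut coalescer = ToyCoalescer::new(Arc::clone(&schema), 4);
        for chunk in [vec![1, 2], vec![3], vec![4, 5]] {
            let col: ArrayRef = Arc::new(Int32Array::from(chunk));
            coalescer.push_batch(RecordBatch::try_new(Arc::clone(&schema), vec![col])?)?;
        }
        coalescer.finish()?;
        // With a target of 4 rows, the three small pushes coalesce into a
        // single 5-row batch, emitted at the third push.
        while let Some(batch) = coalescer.next_completed_batch() {
            println!("coalesced batch with {} rows", batch.num_rows());
        }
        Ok(())
    }
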
@@ -1460,75 +1478,82 @@ impl PerPartitionStream {
let _timer = cloned_time.timer();

loop {
match self.state {
StreamState::ReadingMemory => {
// Poll the memory channel for next message
let value = match self.receiver.recv().poll_unpin(cx) {
Poll::Ready(v) => v,
Poll::Pending => {
// Nothing from channel, wait
return Poll::Pending;
}
};

match value {
Some(Some(v)) => match v {
Ok(RepartitionBatch::Memory(batch)) => {
// Release memory and return batch
self.reservation
.lock()
.shrink(batch.get_array_memory_size());
return Poll::Ready(Some(Ok(batch)));
loop {
match self.state {
StreamState::ReadingMemory => {
// Poll the memory channel for next message
let value = match self.receiver.recv().poll_unpin(cx) {
Poll::Ready(v) => v,
Poll::Pending => {
// Nothing from channel, wait
return Poll::Pending;
}
Ok(RepartitionBatch::Spilled) => {
// Batch was spilled, transition to reading from spill stream
// We must block on spill stream until we get the batch
// to preserve ordering
self.state = StreamState::ReadingSpilled;
};

match value {
Some(Some(v)) => match v {
Ok(RepartitionBatch::Memory(batch)) => {
// Release memory and return batch
self.reservation
.lock()
.shrink(batch.get_array_memory_size());
self.batch_coalescer.push_batch(batch)?;
break;
}
Ok(RepartitionBatch::Spilled) => {
// Batch was spilled, transition to reading from spill stream
// We must block on spill stream until we get the batch
// to preserve ordering
self.state = StreamState::ReadingSpilled;
continue;
}
Err(e) => {
return Poll::Ready(Some(Err(e)));
}
},
Some(None) => {
// One input partition finished
self.remaining_partitions -= 1;
if self.remaining_partitions == 0 {
// All input partitions finished
return self.flush_remaining_batch();
}
// Continue to poll for more data from other partitions
continue;
}
Err(e) => {
return Poll::Ready(Some(Err(e)));
None => {
// Channel closed unexpectedly
return self.flush_remaining_batch();
}
},
Some(None) => {
// One input partition finished
self.remaining_partitions -= 1;
if self.remaining_partitions == 0 {
// All input partitions finished
return Poll::Ready(None);
}
// Continue to poll for more data from other partitions
continue;
}
None => {
// Channel closed unexpectedly
return Poll::Ready(None);
}
}
}
StreamState::ReadingSpilled => {
// Poll spill stream for the spilled batch
match self.spill_stream.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(batch))) => {
self.state = StreamState::ReadingMemory;
return Poll::Ready(Some(Ok(batch)));
}
Poll::Ready(Some(Err(e))) => {
return Poll::Ready(Some(Err(e)));
}
Poll::Ready(None) => {
// Spill stream ended, keep draining the memory channel
self.state = StreamState::ReadingMemory;
}
Poll::Pending => {
// Spilled batch not ready yet, must wait
// This preserves ordering by blocking until spill data arrives
return Poll::Pending;
StreamState::ReadingSpilled => {
// Poll spill stream for the spilled batch
match self.spill_stream.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(batch))) => {
self.state = StreamState::ReadingMemory;
self.batch_coalescer.push_batch(batch)?;
break;
}
Poll::Ready(Some(Err(e))) => {
return Poll::Ready(Some(Err(e)));
}
Poll::Ready(None) => {
// Spill stream ended, keep draining the memory channel
self.state = StreamState::ReadingMemory;
}
Poll::Pending => {
// Spilled batch not ready yet, must wait
// This preserves ordering by blocking until spill data arrives
return Poll::Pending;
}
}
}
}
}
if let Some(batch) = self.batch_coalescer.next_completed_batch() {
return Poll::Ready(Some(Ok(batch)));
}
}
}
}
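
Aside: `push_batch(batch)?` compiles inside this poll loop because the standard library implements the `?` operator for functions returning Poll<Option<Result<T, E>>>: an Err(e) early-returns as Poll::Ready(Some(Err(e))). A minimal standalone demonstration (fallible and poll_like are illustrative names):

    use std::task::Poll;

    fn fallible(ok: bool) -> Result<u32, String> {
        if ok {
            Ok(7)
        } else {
            Err("boom".to_string())
        }
    }

    // `?` on a Result inside a function returning Poll<Option<Result<..>>>
    // converts an Err(e) into Poll::Ready(Some(Err(e))) and returns early.
    fn poll_like(ok: bool) -> Poll<Option<Result<u32, String>>> {
        let v = fallible(ok)?;
        Poll::Ready(Some(Ok(v)))
    }

    fn main() {
        assert_eq!(poll_like(true), Poll::Ready(Some(Ok(7))));
        assert!(matches!(poll_like(false), Poll::Ready(Some(Err(_)))));
    }
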
Expand Down Expand Up @@ -1575,9 +1600,9 @@ mod tests {
use datafusion_common::exec_err;
use datafusion_common::test_util::batches_to_sort_string;
use datafusion_common_runtime::JoinSet;
use datafusion_execution::config::SessionConfig;
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
use insta::assert_snapshot;
use itertools::Itertools;

#[tokio::test]
async fn one_to_many_round_robin() -> Result<()> {
@@ -1588,7 +1613,7 @@

// repartition from 1 input to 4 output
let output_partitions =
repartition(&schema, partitions, Partitioning::RoundRobinBatch(4)).await?;
repartition(&schema, partitions, Partitioning::RoundRobinBatch(4), 8).await?;

assert_eq!(4, output_partitions.len());
assert_eq!(13, output_partitions[0].len());
@@ -1608,7 +1633,7 @@

// repartition from 3 input to 1 output
let output_partitions =
repartition(&schema, partitions, Partitioning::RoundRobinBatch(1)).await?;
repartition(&schema, partitions, Partitioning::RoundRobinBatch(1), 8).await?;

assert_eq!(1, output_partitions.len());
assert_eq!(150, output_partitions[0].len());
@@ -1625,7 +1650,7 @@

// repartition from 3 input to 5 output
let output_partitions =
repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await?;
repartition(&schema, partitions, Partitioning::RoundRobinBatch(5), 8).await?;

assert_eq!(5, output_partitions.len());
assert_eq!(30, output_partitions[0].len());
@@ -1648,6 +1673,7 @@
&schema,
partitions,
Partitioning::Hash(vec![col("c0", &schema)?], 8),
8,
)
.await?;

@@ -1670,8 +1696,11 @@
schema: &SchemaRef,
input_partitions: Vec<Vec<RecordBatch>>,
partitioning: Partitioning,
batch_size: usize,
) -> Result<Vec<Vec<RecordBatch>>> {
let task_ctx = Arc::new(TaskContext::default());
let session_config = SessionConfig::new().with_batch_size(batch_size);
let task_ctx =
Arc::new(TaskContext::default().with_session_config(session_config));
// create physical plan
let exec =
TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(schema), None)?;
@@ -1702,7 +1731,8 @@ mod tests {
vec![partition.clone(), partition.clone(), partition.clone()];

// repartition from 3 input to 5 output
repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await
repartition(&schema, partitions, Partitioning::RoundRobinBatch(5), 8)
.await
});

let output_partitions = handle.join().await.unwrap().unwrap();
@@ -1898,7 +1928,9 @@
// with different compilers, we will compare the same execution with
// and without dropping the output stream.
async fn hash_repartition_with_dropping_output_stream() {
let task_ctx = Arc::new(TaskContext::default());
let session_config = SessionConfig::new().with_batch_size(4);
let task_ctx =
Arc::new(TaskContext::default().with_session_config(session_config));
let partitioning = Partitioning::Hash(
vec![Arc::new(crate::expressions::Column::new(
"my_awesome_field",
@@ -1950,14 +1982,17 @@
});
let batches_with_drop = crate::common::collect(output_stream1).await.unwrap();

fn sort(batch: Vec<RecordBatch>) -> Vec<RecordBatch> {
batch
.into_iter()
.sorted_by_key(|b| format!("{b:?}"))
.collect()
}
let items_vec_with_drop = str_batches_to_vec(&batches_with_drop);
let items_set_with_drop: HashSet<&str> =
items_vec_with_drop.iter().copied().collect();

assert_eq!(sort(batches_without_drop), sort(batches_with_drop));
assert_eq!(
items_set_with_drop.symmetric_difference(&items_set).count(),
0,
"items with drop {:?} and without drop {:?} are different",
items_set_with_drop,
items_set
);
}

fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {
Expand Down Expand Up @@ -2396,6 +2431,7 @@ mod test {
use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::assert_batches_eq;
use datafusion_execution::config::SessionConfig;

use super::*;
use crate::test::TestMemoryExec;
@@ -2507,8 +2543,10 @@
let runtime = RuntimeEnvBuilder::default()
.with_memory_limit(64, 1.0)
.build_arc()?;

let task_ctx = TaskContext::default().with_runtime(runtime);
let session_config = SessionConfig::new().with_batch_size(4);
let task_ctx = TaskContext::default()
.with_runtime(runtime)
.with_session_config(session_config);
let task_ctx = Arc::new(task_ctx);

// Create physical plan with order preservation