Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use datafusion::prelude::*;
use datafusion::scalar::ScalarValue;
use datafusion_catalog::Session;
use datafusion_common::cast::as_primitive_array;
use datafusion_common::{internal_err, not_impl_err};
use datafusion_common::{internal_err, not_impl_err, DataFusionError};
use datafusion_expr::expr::{BinaryExpr, Cast};
use datafusion_functions_aggregate::expr_fn::count;
use datafusion_physical_expr::EquivalenceProperties;
Expand Down Expand Up @@ -134,9 +134,25 @@ impl ExecutionPlan for CustomPlan {
_partition: usize,
_context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
if self.batches.is_empty() {
return Ok(Box::pin(RecordBatchStreamAdapter::new(
self.schema(),
futures::stream::empty(),
)));
}
let schema_captured = self.schema().clone();
Ok(Box::pin(RecordBatchStreamAdapter::new(
self.schema(),
futures::stream::iter(self.batches.clone().into_iter().map(Ok)),
futures::stream::iter(self.batches.clone().into_iter().map(move |batch| {
let projection: Vec<usize> = schema_captured
.fields()
.iter()
.filter_map(|field| batch.schema().index_of(field.name()).ok())
.collect();
batch
.project(&projection)
.map_err(|e| DataFusionError::ArrowError(Box::new(e), None))
})),
)))
}

Expand Down
83 changes: 69 additions & 14 deletions datafusion/physical-plan/src/repartition/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
use super::{
DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream,
};
use crate::coalesce::{LimitedBatchCoalescer, PushBatchStatus};
use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType};
use crate::hash_utils::create_hashes;
use crate::metrics::{BaselineMetrics, SpillMetrics};
Expand Down Expand Up @@ -932,6 +933,7 @@ impl ExecutionPlan for RepartitionExec {
spill_stream,
1, // Each receiver handles one input partition
BaselineMetrics::new(&metrics, partition),
context.session_config().batch_size(),
)) as SendableRecordBatchStream
})
.collect::<Vec<_>>();
Expand Down Expand Up @@ -970,6 +972,7 @@ impl ExecutionPlan for RepartitionExec {
spill_stream,
num_input_partitions,
BaselineMetrics::new(&metrics, partition),
context.session_config().batch_size(),
)) as SendableRecordBatchStream)
}
})
Expand Down Expand Up @@ -1427,9 +1430,12 @@ struct PerPartitionStream {

/// Execution metrics
baseline_metrics: BaselineMetrics,

coalescer: LimitedBatchCoalescer,
}

impl PerPartitionStream {
#[allow(clippy::too_many_arguments)]
fn new(
schema: SchemaRef,
receiver: DistributionReceiver<MaybeBatch>,
Expand All @@ -1438,16 +1444,18 @@ impl PerPartitionStream {
spill_stream: SendableRecordBatchStream,
num_input_partitions: usize,
baseline_metrics: BaselineMetrics,
batch_size: usize,
) -> Self {
Self {
schema,
schema: Arc::clone(&schema),
receiver,
_drop_helper: drop_helper,
reservation,
spill_stream,
state: StreamState::ReadingMemory,
remaining_partitions: num_input_partitions,
baseline_metrics,
coalescer: LimitedBatchCoalescer::new(schema, batch_size, None),
}
}

Expand Down Expand Up @@ -1540,7 +1548,49 @@ impl Stream for PerPartitionStream {
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
let poll = self.poll_next_inner(cx);
let cloned_time = self.baseline_metrics.elapsed_compute().clone();
let mut completed = false;

let poll;
loop {
if let Some(batch) = self.coalescer.next_completed_batch() {
poll = Poll::Ready(Some(Ok(batch)));
break;
}
if completed {
poll = Poll::Ready(None);
break;
}
let inner_poll = self.poll_next_inner(cx);
let _timer = cloned_time.timer();

match inner_poll {
Poll::Pending => {
poll = Poll::Pending;
break;
}
Poll::Ready(None) => {
completed = true;
self.coalescer.finish()?;
}
Poll::Ready(Some(Ok(batch))) => {
match self.coalescer.push_batch(batch)? {
PushBatchStatus::Continue => {
// Keep pushing more batches
}
PushBatchStatus::LimitReached => {
// limit was reached, so stop early
completed = true;
self.coalescer.finish()?;
}
}
}
Poll::Ready(Some(err)) => {
poll = Poll::Ready(Some(err));
break;
}
}
}
self.baseline_metrics.record_poll(poll)
}
}
Expand Down Expand Up @@ -1575,9 +1625,9 @@ mod tests {
use datafusion_common::exec_err;
use datafusion_common::test_util::batches_to_sort_string;
use datafusion_common_runtime::JoinSet;
use datafusion_execution::config::SessionConfig;
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
use insta::assert_snapshot;
use itertools::Itertools;

#[tokio::test]
async fn one_to_many_round_robin() -> Result<()> {
Expand Down Expand Up @@ -1671,7 +1721,10 @@ mod tests {
input_partitions: Vec<Vec<RecordBatch>>,
partitioning: Partitioning,
) -> Result<Vec<Vec<RecordBatch>>> {
let task_ctx = Arc::new(TaskContext::default());
let task_ctx = Arc::new(
TaskContext::default()
.with_session_config(SessionConfig::new().with_batch_size(8)),
);
// create physical plan
let exec =
TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(schema), None)?;
Expand Down Expand Up @@ -1950,14 +2003,13 @@ mod tests {
});
let batches_with_drop = crate::common::collect(output_stream1).await.unwrap();

fn sort(batch: Vec<RecordBatch>) -> Vec<RecordBatch> {
batch
.into_iter()
.sorted_by_key(|b| format!("{b:?}"))
.collect()
}

assert_eq!(sort(batches_without_drop), sort(batches_with_drop));
let items_vec_with_drop = str_batches_to_vec(&batches_with_drop);
let items_set_with_drop: HashSet<&str> =
items_vec_with_drop.iter().copied().collect();
assert_eq!(
items_set_with_drop.symmetric_difference(&items_set).count(),
0
);
}

fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {
Expand Down Expand Up @@ -2396,6 +2448,7 @@ mod test {
use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::assert_batches_eq;
use datafusion_execution::config::SessionConfig;

use super::*;
use crate::test::TestMemoryExec;
Expand Down Expand Up @@ -2507,8 +2560,10 @@ mod test {
let runtime = RuntimeEnvBuilder::default()
.with_memory_limit(64, 1.0)
.build_arc()?;

let task_ctx = TaskContext::default().with_runtime(runtime);
let session_config = SessionConfig::new().with_batch_size(4);
let task_ctx = TaskContext::default()
.with_runtime(runtime)
.with_session_config(session_config);
let task_ctx = Arc::new(task_ctx);

// Create physical plan with order preservation
Expand Down