Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 193 additions & 12 deletions datafusion/src/physical_plan/repartition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,13 @@ impl ExecutionPlan for RepartitionExec {
let fetch_time = self.fetch_time_nanos.clone();
let repart_time = self.repart_time_nanos.clone();
let send_time = self.send_time_nanos.clone();
let mut txs: HashMap<_, _> = channels
let txs: HashMap<_, _> = channels
.iter()
.map(|(partition, (tx, _rx))| (*partition, tx.clone()))
.collect();
let partitioning = self.partitioning.clone();
let _: JoinHandle<Result<()>> = tokio::spawn(async move {
let mut txs_captured = txs.clone();
let input_task: JoinHandle<Result<()>> = tokio::spawn(async move {
// execute the child operator
let now = Instant::now();
let mut stream = input.execute(i).await?;
Expand All @@ -170,13 +171,13 @@ impl ExecutionPlan for RepartitionExec {
if result.is_none() {
break;
}
let result = result.unwrap();
let result: ArrowResult<RecordBatch> = result.unwrap();

match &partitioning {
Partitioning::RoundRobinBatch(_) => {
let now = Instant::now();
let output_partition = counter % num_output_partitions;
let tx = txs.get_mut(&output_partition).unwrap();
let tx = txs_captured.get_mut(&output_partition).unwrap();
tx.send(Some(result)).map_err(|e| {
DataFusionError::Execution(e.to_string())
})?;
Expand Down Expand Up @@ -230,7 +231,9 @@ impl ExecutionPlan for RepartitionExec {
);
repart_time.add(now.elapsed().as_nanos() as usize);
let now = Instant::now();
let tx = txs.get_mut(&num_output_partition).unwrap();
let tx = txs_captured
.get_mut(&num_output_partition)
.unwrap();
tx.send(Some(output_batch)).map_err(|e| {
DataFusionError::Execution(e.to_string())
})?;
Expand All @@ -249,13 +252,12 @@ impl ExecutionPlan for RepartitionExec {
counter += 1;
}

// notify each output partition that this input partition has no more data
for (_, tx) in txs {
tx.send(None)
.map_err(|e| DataFusionError::Execution(e.to_string()))?;
}
Ok(())
});

// In a separate task, wait for each input to be done
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the actual code change (checking the input task's return value in another task). The rest of this PR is tests.

// (and pass along any errors)
tokio::spawn(async move { Self::wait_for_task(input_task, txs).await });
}
}

Expand Down Expand Up @@ -308,6 +310,45 @@ impl RepartitionExec {
send_time_nanos: SQLMetric::time_nanos(),
})
}

/// Waits for `input_task` which is consuming one of the inputs to
Copy link
Contributor

@tustvold tustvold Jun 7, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it might be slightly clearer to push the body of the main task into a fallible function, and then handle propagating any error it returns within the spawned task? That is, rather than propagating the error through the JoinHandle, make the task that is spawned onto tokio infallible and handle its errors internally.

Edit: I guess the advantage with this approach would be that you could propagate panics as well...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree the approach you describe would be clearer (and avoid needing a separate task) 👍

The reason I did not pull the main body out into its own function was mostly "trying to keep the diff small" (or perhaps my own laziness in wanting to avoid figuring out the types of all the captured arguments).

Perhaps that would be a good follow-on PR (there is a lot of messiness / duplication in updating the counters, which I would also like to fix).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in #538

/// complete. Upon each successful completion, sends a `None` to
/// each of the output tx channels to signal one of the inputs is
/// complete. Upon error, propagates the errors to all output tx
/// channels.
async fn wait_for_task(
    input_task: JoinHandle<Result<()>>,
    txs: HashMap<usize, UnboundedSender<Option<ArrowResult<RecordBatch>>>>,
) {
    // Reduce the three possible outcomes of the input task to an optional
    // failure message:
    //   * the task could not be joined      -> "Join Error: ..."
    //   * the task ran but returned an error -> that error's text
    //   * the task finished successfully     -> no message
    let failure: Option<String> = match input_task.await {
        Err(join_error) => Some(format!("Join Error: {}", join_error)),
        Ok(Err(task_error)) => Some(task_error.to_string()),
        Ok(Ok(())) => None,
    };

    // Every output channel receives either its own copy of the error (the
    // message is re-wrapped per channel because one error value has to reach
    // all output partitions) or `None`, which signals that this input
    // partition has no more data.
    for (_, tx) in txs {
        let message = failure.as_ref().map(|text| {
            Err(DataFusionError::Execution(text.clone()).into_arrow_external_error())
        });
        // A failed send means the receiver has already shut down, so the
        // result is deliberately discarded.
        tx.send(message).ok();
    }
}
}

struct RepartitionStream {
Expand Down Expand Up @@ -356,10 +397,17 @@ impl RecordBatchStream for RepartitionStream {
#[cfg(test)]
mod tests {
use super::*;
use crate::physical_plan::memory::MemoryExec;
use arrow::array::UInt32Array;
use crate::{
assert_batches_sorted_eq,
physical_plan::memory::MemoryExec,
test::exec::{ErrorExec, MockExec},
};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use arrow::{
array::{ArrayRef, StringArray, UInt32Array},
error::ArrowError,
};

#[tokio::test]
async fn one_to_many_round_robin() -> Result<()> {
Expand Down Expand Up @@ -517,4 +565,137 @@ mod tests {

Ok(())
}

#[tokio::test]
async fn unsupported_partitioning() {
    // The error is only provoked once a batch is actually pushed through
    // the operator, so the mock input must yield at least one batch.
    let column = Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef;
    let batch = RecordBatch::try_from_iter(vec![("my_awesome_field", column)]).unwrap();
    let schema = batch.schema();
    let input = Arc::new(MockExec::new(vec![Ok(batch)], schema));

    // Constructing the plan and opening the output stream both succeed;
    // the "partitioning type not supported" failure only surfaces when the
    // plan runs, and it must come back as an error with no results.
    let exec =
        RepartitionExec::try_new(input, Partitioning::UnknownPartitioning(1)).unwrap();
    let output_stream = exec.execute(0).await.unwrap();

    // Draining the stream is expected to fail.
    let collected = crate::physical_plan::common::collect(output_stream).await;
    let result_string = collected.unwrap_err().to_string();

    let expected_fragment = "Unsupported repartitioning scheme UnknownPartitioning(1)";
    assert!(
        result_string.contains(expected_fragment),
        "actual: {}",
        result_string
    );
}

#[tokio::test]
async fn error_for_input_exec() {
    // The input fails as soon as `execute` is called on it. That failure
    // should be handed back to the consumer and no results produced.
    let failing_input = Arc::new(ErrorExec::new());
    let exec =
        RepartitionExec::try_new(failing_input, Partitioning::RoundRobinBatch(1)).unwrap();

    // Creating the output stream is expected to work; the input's error is
    // only passed back once the stream is consumed.
    let output_stream = exec.execute(0).await.unwrap();

    // Draining the stream is expected to fail.
    let collected = crate::physical_plan::common::collect(output_stream).await;
    let result_string = collected.unwrap_err().to_string();

    let expected_fragment = "ErrorExec, unsurprisingly, errored in partition 0";
    assert!(
        result_string.contains(expected_fragment),
        "actual: {}",
        result_string
    );
}

#[tokio::test]
async fn repartition_with_error_in_stream() {
let batch = RecordBatch::try_from_iter(vec![(
"my_awesome_field",
Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
)])
.unwrap();

// input stream returns one good batch and then one error. The
// error should be returned.
let err = Err(ArrowError::ComputeError("bad data error".to_string()));

let schema = batch.schema();
let input = MockExec::new(vec![Ok(batch), err], schema);
let partitioning = Partitioning::RoundRobinBatch(1);
let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();

// Note: this should pass (the stream can be created) but the
// error when the input is executed should get passed back
let output_stream = exec.execute(0).await.unwrap();

// Expect that an error is returned
let result_string = crate::physical_plan::common::collect(output_stream)
.await
.unwrap_err()
.to_string();
assert!(
result_string.contains("bad data error"),
"actual: {}",
result_string
);
}

#[tokio::test]
async fn repartition_with_delayed_stream() {
    // Two batches over the same single string column.
    let batch1 = RecordBatch::try_from_iter(vec![(
        "my_awesome_field",
        Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
    )])
    .unwrap();

    let batch2 = RecordBatch::try_from_iter(vec![(
        "my_awesome_field",
        Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
    )])
    .unwrap();

    // The mock exec doesn't return immediately (instead it
    // requires the input to wait at least once)
    let schema = batch1.schema();
    let expected_batches = vec![batch1.clone(), batch2.clone()];
    let input = MockExec::new(vec![Ok(batch1), Ok(batch2)], schema);
    let partitioning = Partitioning::RoundRobinBatch(1);

    let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();

    // Every data row is padded to the width of the widest cell (the
    // 16-character header), exactly as the table formatter renders it;
    // unpadded rows could never compare equal to the formatted output.
    let expected = vec![
        "+------------------+",
        "| my_awesome_field |",
        "+------------------+",
        "| foo              |",
        "| bar              |",
        "| frob             |",
        "| baz              |",
        "+------------------+",
    ];

    // Sanity check: the expected table matches the batches fed into the mock.
    assert_batches_sorted_eq!(&expected, &expected_batches);

    let output_stream = exec.execute(0).await.unwrap();
    let batches = crate::physical_plan::common::collect(output_stream)
        .await
        .unwrap();

    // The repartitioned output must contain the same rows as the input.
    assert_batches_sorted_eq!(&expected, &batches);
}
}
Loading