diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs
index 784aa2aff232..5398bb0903ca 100644
--- a/datafusion/core/src/datasource/memory.rs
+++ b/datafusion/core/src/datasource/memory.rs
@@ -29,12 +29,12 @@ use async_trait::async_trait;
 use datafusion_common::SchemaExt;
 use datafusion_execution::TaskContext;
 use tokio::sync::RwLock;
+use tokio::task::JoinSet;
 
 use crate::datasource::{TableProvider, TableType};
 use crate::error::{DataFusionError, Result};
 use crate::execution::context::SessionState;
 use crate::logical_expr::Expr;
-use crate::physical_plan::common::AbortOnDropSingle;
 use crate::physical_plan::insert::{DataSink, InsertExec};
 use crate::physical_plan::memory::MemoryExec;
 use crate::physical_plan::{common, SendableRecordBatchStream};
@@ -89,26 +89,31 @@ impl MemTable {
         let exec = t.scan(state, None, &[], None).await?;
         let partition_count = exec.output_partitioning().partition_count();
 
-        let tasks = (0..partition_count)
-            .map(|part_i| {
-                let task = state.task_ctx();
-                let exec = exec.clone();
-                let task = tokio::spawn(async move {
-                    let stream = exec.execute(part_i, task)?;
-                    common::collect(stream).await
-                });
-
-                AbortOnDropSingle::new(task)
-            })
-            // this collect *is needed* so that the join below can
-            // switch between tasks
-            .collect::<Vec<_>>();
+        let mut join_set = JoinSet::new();
+
+        for part_idx in 0..partition_count {
+            let task = state.task_ctx();
+            let exec = exec.clone();
+            join_set.spawn(async move {
+                let stream = exec.execute(part_idx, task)?;
+                common::collect(stream).await
+            });
+        }
 
         let mut data: Vec<Vec<RecordBatch>> =
             Vec::with_capacity(exec.output_partitioning().partition_count());
-        for result in futures::future::join_all(tasks).await {
-            data.push(result.map_err(|e| DataFusionError::External(Box::new(e)))??)
+        while let Some(result) = join_set.join_next().await {
+            match result {
+                Ok(res) => data.push(res?),
+                Err(e) => {
+                    if e.is_panic() {
+                        std::panic::resume_unwind(e.into_panic());
+                    } else {
+                        unreachable!();
+                    }
+                }
+            }
         }
 
         let exec = MemoryExec::try_new(&data, schema.clone(), None)?;
diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs
index 027bd1945be6..eba51615cddf 100644
--- a/datafusion/core/src/datasource/physical_plan/csv.rs
+++ b/datafusion/core/src/datasource/physical_plan/csv.rs
@@ -23,7 +23,6 @@ use crate::datasource::physical_plan::file_stream::{
 };
 use crate::datasource::physical_plan::FileMeta;
 use crate::error::{DataFusionError, Result};
-use crate::physical_plan::common::AbortOnDropSingle;
 use crate::physical_plan::expressions::PhysicalSortExpr;
 use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
 use crate::physical_plan::{
@@ -46,7 +45,7 @@ use std::fs;
 use std::path::Path;
 use std::sync::Arc;
 use std::task::Poll;
-use tokio::task::{self, JoinHandle};
+use tokio::task::JoinSet;
 
 /// Execution plan for scanning a CSV file
 #[derive(Debug, Clone)]
@@ -331,7 +330,7 @@ pub async fn plan_to_csv(
         )));
     }
 
-    let mut tasks = vec![];
+    let mut join_set = JoinSet::new();
     for i in 0..plan.output_partitioning().partition_count() {
         let plan = plan.clone();
         let filename = format!("part-{i}.csv");
@@ -340,22 +339,29 @@
         let mut writer = csv::Writer::new(file);
         let stream = plan.execute(i, task_ctx.clone())?;
-        let handle: JoinHandle<Result<()>> = task::spawn(async move {
-            stream
+        join_set.spawn(async move {
+            let result: Result<()> = stream
                 .map(|batch| writer.write(&batch?))
                 .try_collect()
                 .await
-                .map_err(DataFusionError::from)
+                .map_err(DataFusionError::from);
+            result
         });
-        tasks.push(AbortOnDropSingle::new(handle));
     }
 
-    futures::future::join_all(tasks)
-        .await
-        .into_iter()
-        .try_for_each(|result| {
-            result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
-        })?;
+    while let Some(result) = join_set.join_next().await {
+        match result {
+            Ok(res) => res?, // propagate DataFusion error
+            Err(e) => {
+                if e.is_panic() {
+                    std::panic::resume_unwind(e.into_panic());
+                } else {
+                    unreachable!();
+                }
+            }
+        }
+    }
+
     Ok(())
 }
diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs
index b736fd783999..64f70776606a 100644
--- a/datafusion/core/src/datasource/physical_plan/json.rs
+++ b/datafusion/core/src/datasource/physical_plan/json.rs
@@ -22,7 +22,6 @@ use crate::datasource::physical_plan::file_stream::{
 };
 use crate::datasource::physical_plan::FileMeta;
 use crate::error::{DataFusionError, Result};
-use crate::physical_plan::common::AbortOnDropSingle;
 use crate::physical_plan::expressions::PhysicalSortExpr;
 use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
 use crate::physical_plan::{
@@ -44,7 +43,7 @@ use std::io::BufReader;
 use std::path::Path;
 use std::sync::Arc;
 use std::task::Poll;
-use tokio::task::{self, JoinHandle};
+use tokio::task::JoinSet;
 
 use super::FileScanConfig;
 
@@ -266,7 +265,7 @@ pub async fn plan_to_json(
         )));
     }
 
-    let mut tasks = vec![];
+    let mut join_set = JoinSet::new();
     for i in 0..plan.output_partitioning().partition_count() {
         let plan = plan.clone();
         let filename = format!("part-{i}.json");
@@ -274,22 +273,29 @@
         let file = fs::File::create(path)?;
         let mut writer = json::LineDelimitedWriter::new(file);
         let stream = plan.execute(i, task_ctx.clone())?;
-        let handle: JoinHandle<Result<()>> = task::spawn(async move {
-            stream
+        join_set.spawn(async move {
+            let result: Result<()> = stream
                 .map(|batch| writer.write(&batch?))
                 .try_collect()
                 .await
-                .map_err(DataFusionError::from)
+                .map_err(DataFusionError::from);
+            result
         });
-        tasks.push(AbortOnDropSingle::new(handle));
     }
 
-    futures::future::join_all(tasks)
-        .await
-        .into_iter()
-        .try_for_each(|result| {
-            result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
-        })?;
+    while let Some(result) = join_set.join_next().await {
+        match result {
+            Ok(res) => res?, // propagate DataFusion error
+            Err(e) => {
+                if e.is_panic() {
+                    std::panic::resume_unwind(e.into_panic());
+                } else {
+                    unreachable!();
+                }
+            }
+        }
+    }
+
     Ok(())
 }
diff --git a/datafusion/core/src/datasource/physical_plan/parquet.rs b/datafusion/core/src/datasource/physical_plan/parquet.rs
index f538255bc20d..96e5ce9fa0fd 100644
--- a/datafusion/core/src/datasource/physical_plan/parquet.rs
+++ b/datafusion/core/src/datasource/physical_plan/parquet.rs
@@ -31,7 +31,6 @@ use crate::{
     execution::context::TaskContext,
     physical_optimizer::pruning::PruningPredicate,
     physical_plan::{
-        common::AbortOnDropSingle,
         metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet},
         ordering_equivalence_properties_helper, DisplayFormatType, ExecutionPlan,
         Partitioning, SendableRecordBatchStream, Statistics,
@@ -64,6 +63,7 @@ use parquet::arrow::{ArrowWriter, ParquetRecordBatchStreamBuilder, ProjectionMas
 use parquet::basic::{ConvertedType, LogicalType};
 use parquet::file::{metadata::ParquetMetaData, properties::WriterProperties};
 use parquet::schema::types::ColumnDescriptor;
+use tokio::task::JoinSet;
 
 mod metrics;
 pub mod page_filter;
@@ -701,7 +701,7 @@ pub async fn plan_to_parquet(
         )));
     }
 
-    let mut tasks = vec![];
+    let mut join_set = JoinSet::new();
     for i in 0..plan.output_partitioning().partition_count() {
         let plan = plan.clone();
         let filename = format!("part-{i}.parquet");
@@ -710,27 +710,30 @@ pub async fn plan_to_parquet(
         let mut writer = ArrowWriter::try_new(file, plan.schema(), writer_properties.clone())?;
         let stream = plan.execute(i, task_ctx.clone())?;
-        let handle: tokio::task::JoinHandle<Result<()>> =
-            tokio::task::spawn(async move {
-                stream
-                    .map(|batch| {
-                        writer.write(&batch?).map_err(DataFusionError::ParquetError)
-                    })
-                    .try_collect()
-                    .await
-                    .map_err(DataFusionError::from)?;
+        join_set.spawn(async move {
+            stream
+                .map(|batch| writer.write(&batch?).map_err(DataFusionError::ParquetError))
+                .try_collect()
+                .await
+                .map_err(DataFusionError::from)?;
+
+            writer.close().map_err(DataFusionError::from).map(|_| ())
+        });
+    }
 
-                writer.close().map_err(DataFusionError::from).map(|_| ())
-            });
-        tasks.push(AbortOnDropSingle::new(handle));
+    while let Some(result) = join_set.join_next().await {
+        match result {
+            Ok(res) => res?,
+            Err(e) => {
+                if e.is_panic() {
+                    std::panic::resume_unwind(e.into_panic());
+                } else {
+                    unreachable!();
+                }
+            }
+        }
     }
 
-    futures::future::join_all(tasks)
-        .await
-        .into_iter()
-        .try_for_each(|result| {
-            result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
-        })?;
     Ok(())
 }
diff --git a/datafusion/core/src/physical_plan/mod.rs b/datafusion/core/src/physical_plan/mod.rs
index 5abecf6b167c..7efd5a19eeac 100644
--- a/datafusion/core/src/physical_plan/mod.rs
+++ b/datafusion/core/src/physical_plan/mod.rs
@@ -38,6 +38,7 @@ pub use display::{DefaultDisplay, DisplayAs, DisplayFormatType, VerboseDisplay};
 use futures::stream::{Stream, TryStreamExt};
 use std::fmt;
 use std::fmt::Debug;
+use tokio::task::JoinSet;
 
 use datafusion_common::tree_node::Transformed;
 use datafusion_common::DataFusionError;
@@ -445,20 +446,37 @@ pub async fn collect_partitioned(
 ) -> Result<Vec<Vec<RecordBatch>>> {
     let streams = execute_stream_partitioned(plan, context)?;
 
+    let mut join_set = JoinSet::new();
     // Execute the plan and collect the results into batches.
-    let handles = streams
-        .into_iter()
-        .enumerate()
-        .map(|(idx, stream)| async move {
-            let handle = tokio::task::spawn(stream.try_collect());
-            AbortOnDropSingle::new(handle).await.map_err(|e| {
-                DataFusionError::Execution(format!(
-                    "collect_partitioned partition {idx} panicked: {e}"
-                ))
-            })?
+    streams.into_iter().enumerate().for_each(|(idx, stream)| {
+        join_set.spawn(async move {
+            let result: Result<Vec<RecordBatch>> = stream.try_collect().await;
+            (idx, result)
         });
+    });
+
+    let mut batches = vec![];
+    // Note that currently this doesn't identify the thread that panicked
+    //
+    // TODO: Replace with [join_next_with_id](https://docs.rs/tokio/latest/tokio/task/struct.JoinSet.html#method.join_next_with_id)
+    // once it is stable
+    while let Some(result) = join_set.join_next().await {
+        match result {
+            Ok((idx, res)) => batches.push((idx, res?)),
+            Err(e) => {
+                if e.is_panic() {
+                    std::panic::resume_unwind(e.into_panic());
+                } else {
+                    unreachable!();
+                }
+            }
+        }
+    }
+
+    batches.sort_by_key(|(idx, _)| *idx);
+    let batches = batches.into_iter().map(|(_, batch)| batch).collect();
 
-    futures::future::try_join_all(handles).await
+    Ok(batches)
 }
 
 /// Execute the [ExecutionPlan] and return a vec with one stream per output partition
@@ -713,7 +731,6 @@ pub mod unnest;
 pub mod values;
 pub mod windows;
 
-use crate::physical_plan::common::AbortOnDropSingle;
 use crate::physical_plan::repartition::RepartitionExec;
 use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
 use datafusion_execution::TaskContext;
diff --git a/datafusion/core/src/physical_plan/repartition/mod.rs b/datafusion/core/src/physical_plan/repartition/mod.rs
index 72ff0c37135b..82f71ceade2d 100644
--- a/datafusion/core/src/physical_plan/repartition/mod.rs
+++ b/datafusion/core/src/physical_plan/repartition/mod.rs
@@ -263,7 +263,7 @@ struct RepartitionMetrics {
     /// Time in nanos to execute child operator and fetch batches
     fetch_time: metrics::Time,
     /// Time in nanos to perform repartitioning
-    repart_time: metrics::Time,
+    repartition_time: metrics::Time,
     /// Time in nanos for sending resulting batches to channels
     send_time: metrics::Time,
 }
@@ -293,7 +293,7 @@ impl RepartitionMetrics {
         Self {
             fetch_time,
-            repart_time,
+            repartition_time: repart_time,
             send_time,
         }
     }
@@ -407,7 +407,7 @@ impl ExecutionPlan for RepartitionExec {
         // note we use a custom channel that ensures there is always data for each receiver
         // but limits the amount of buffering if required.
         let (txs, rxs) = channels(num_output_partitions);
-        // Clone sender for ech input partitions
+        // Clone sender for each input partition
         let txs = txs
             .into_iter()
             .map(|item| vec![item; num_input_partitions])
             .collect::<Vec<_>>();
@@ -564,34 +564,31 @@ impl RepartitionExec {
     /// Pulls data from the specified input plan, feeding it to the
     /// output partitions based on the desired partitioning
     ///
-    /// i is the input partition index
-    ///
     /// txs hold the output sending channels for each output partition
     async fn pull_from_input(
         input: Arc<dyn ExecutionPlan>,
-        i: usize,
-        mut txs: HashMap<
+        partition: usize,
+        mut output_channels: HashMap<
             usize,
             (DistributionSender<MaybeBatch>, SharedMemoryReservation),
         >,
         partitioning: Partitioning,
-        r_metrics: RepartitionMetrics,
+        metrics: RepartitionMetrics,
         context: Arc<TaskContext>,
     ) -> Result<()> {
         let mut partitioner =
-            BatchPartitioner::try_new(partitioning, r_metrics.repart_time.clone())?;
+            BatchPartitioner::try_new(partitioning, metrics.repartition_time.clone())?;
 
         // execute the child operator
-        let timer = r_metrics.fetch_time.timer();
-        let mut stream = input.execute(i, context)?;
+        let timer = metrics.fetch_time.timer();
+        let mut stream = input.execute(partition, context)?;
         timer.done();
 
-        // While there are still outputs to send to, keep
-        // pulling inputs
+        // While there are still outputs to send to, keep pulling inputs
         let mut batches_until_yield = partitioner.num_partitions();
-        while !txs.is_empty() {
+        while !output_channels.is_empty() {
             // fetch the next batch
-            let timer = r_metrics.fetch_time.timer();
+            let timer = metrics.fetch_time.timer();
             let result = stream.next().await;
             timer.done();
@@ -605,15 +602,15 @@ impl RepartitionExec {
             let (partition, batch) = res?;
             let size = batch.get_array_memory_size();
 
-            let timer = r_metrics.send_time.timer();
+            let timer = metrics.send_time.timer();
             // if there is still a receiver, send to it
-            if let Some((tx, reservation)) = txs.get_mut(&partition) {
+            if let Some((tx, reservation)) = output_channels.get_mut(&partition) {
                 reservation.lock().try_grow(size)?;
                 if tx.send(Some(Ok(batch))).await.is_err() {
                     // If the other end has hung up, it was an early shutdown (e.g. LIMIT)
                     reservation.lock().shrink(size);
-                    txs.remove(&partition);
+                    output_channels.remove(&partition);
                 }
             }
             timer.done();
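
For reference, the following is a minimal, self-contained sketch (not part of the patch) of the `JoinSet` pattern that every hunk above applies: spawn one task per partition, drain the set with `join_next`, propagate task errors with `?`, re-raise panics with `resume_unwind`, and treat the only other `JoinError` variant (cancellation) as unreachable because the set is never aborted while it is still being polled. The `partition_work` helper, the `String` error type, and the partition-index bookkeeping are illustrative assumptions; it assumes tokio with the `rt` and `macros` features enabled.

```rust
use tokio::task::JoinSet;

// Hypothetical stand-in for executing one partition of a plan.
async fn partition_work(partition: usize) -> Result<Vec<u64>, String> {
    Ok(vec![partition as u64; 3])
}

async fn run_partitions(partition_count: usize) -> Result<Vec<Vec<u64>>, String> {
    let mut join_set = JoinSet::new();
    for partition in 0..partition_count {
        // Each spawned future owns what it needs (`async move`), like the
        // per-partition tasks in MemTable::load and plan_to_csv/json/parquet.
        join_set.spawn(async move {
            let result = partition_work(partition).await;
            // Return the partition index so results can be re-ordered later,
            // mirroring the collect_partitioned change.
            (partition, result)
        });
    }

    let mut results = Vec::with_capacity(partition_count);
    // Dropping `join_set` aborts anything still running, which is the
    // cancel-on-drop behavior AbortOnDropSingle previously provided per task.
    while let Some(joined) = join_set.join_next().await {
        match joined {
            Ok((partition, result)) => results.push((partition, result?)),
            // A JoinError is either a panic or a cancellation; re-raise panics
            // and treat cancellation as unreachable since the set is not aborted.
            Err(e) if e.is_panic() => std::panic::resume_unwind(e.into_panic()),
            Err(_) => unreachable!(),
        }
    }

    // join_next yields tasks in completion order, so restore partition order.
    results.sort_by_key(|(partition, _)| *partition);
    Ok(results.into_iter().map(|(_, batches)| batches).collect())
}

#[tokio::main]
async fn main() -> Result<(), String> {
    let per_partition = run_partitions(4).await?;
    assert_eq!(per_partition.len(), 4);
    assert_eq!(per_partition[2], vec![2, 2, 2]);
    Ok(())
}
```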