diff --git a/datafusion/core/src/scheduler/mod.rs b/datafusion/core/src/scheduler/mod.rs
index 3a5a6131c5ce..24c30dc3e495 100644
--- a/datafusion/core/src/scheduler/mod.rs
+++ b/datafusion/core/src/scheduler/mod.rs
@@ -323,7 +323,11 @@ mod tests {
     async fn test_simple() {
         init_logging();
 
-        let scheduler = Scheduler::new(4);
+        let scheduler = SchedulerBuilder::new(4)
+            .panic_handler(|panic| {
+                unreachable!("not expect panic: {:?}", panic);
+            })
+            .build();
         let config = SessionConfig::new().with_target_partitions(4);
         let context = SessionContext::with_config(config);
 
@@ -341,6 +345,8 @@ mod tests {
            "select id, b from (select id, b from table1 union all select id, b from table2 where a > 100 order by id) as t where b > 10 order by id, b",
            "select id, MIN(b), MAX(b), AVG(b) from table1 group by id order by id",
            "select count(*) from table1 where table1.a > 4",
+           "WITH gp AS (SELECT id FROM table1 GROUP BY id)
+               SELECT COUNT(CAST(CAST(gp.id || 'xx' AS TIMESTAMP) AS BIGINT)) FROM gp",
         ];
 
         for sql in queries {
@@ -353,8 +359,8 @@ mod tests {
             info!("Plan: {}", displayable(plan.as_ref()).indent());
 
             let stream = scheduler.schedule(plan, task).unwrap().stream();
-            let scheduled: Vec<_> = stream.try_collect().await.unwrap();
-            let expected = query.collect().await.unwrap();
+            let scheduled: Vec<_> = stream.try_collect().await.unwrap_or_default();
+            let expected = query.collect().await.unwrap_or_default();
 
             let total_expected = expected.iter().map(|x| x.num_rows()).sum::<usize>();
             let total_scheduled = scheduled.iter().map(|x| x.num_rows()).sum::<usize>();
diff --git a/datafusion/core/src/scheduler/task.rs b/datafusion/core/src/scheduler/task.rs
index b723a37ce7e8..9283810e787d 100644
--- a/datafusion/core/src/scheduler/task.rs
+++ b/datafusion/core/src/scheduler/task.rs
@@ -108,17 +108,24 @@ impl Task {
         routable: &RoutablePipeline,
         error: DataFusionError,
     ) {
-        self.context.send_query_output(partition, Err(error));
-        if let Some(link) = routable.output {
-            trace!(
-                "Closing pipeline: {:?}, partition: {}, due to error",
-                link,
-                self.waker.partition,
-            );
-
-            self.context.pipelines[link.pipeline]
-                .pipeline
-                .close(link.child, self.waker.partition);
+        match routable.output {
+            Some(link) => {
+                // The query output partitioning may not match the current pipeline's
+                // but the query output has at least one partition
+                // so send error to the first partition of the query output.
+                self.context.send_query_output(0, Err(error));
+
+                trace!(
+                    "Closing pipeline: {:?}, partition: {}, due to error",
+                    link,
+                    self.waker.partition,
+                );
+
+                self.context.pipelines[link.pipeline]
+                    .pipeline
+                    .close(link.child, self.waker.partition);
+            }
+            None => self.context.send_query_output(partition, Err(error)),
         }
     }
 
@@ -303,6 +310,10 @@ impl ExecutionContext {
 
     /// Sends `output` to this query's output stream
     fn send_query_output(&self, partition: usize, output: Result<RecordBatch>) {
+        debug_assert!(
+            self.output.len() > partition,
+            "the specified partition exceeds the total number of output partitions"
+        );
         let _ = self.output[partition].unbounded_send(Some(output));
     }
 