diff --git a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs index fba6b2c2db2e2..0af4015ff7239 100644 --- a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs +++ b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs @@ -699,33 +699,49 @@ impl fmt::Debug for SharedBuildAccumulator { } } +#[cfg(test)] +pub(super) fn make_partitioned_accumulator_for_test( + num_partitions: usize, +) -> SharedBuildAccumulator { + let probe_schema = Arc::new(Schema::new(vec![Field::new( + "probe_key", + DataType::Int32, + false, + )])); + let dynamic_filter = Arc::new(DynamicFilterPhysicalExpr::new(vec![], lit(true))); + SharedBuildAccumulator { + inner: Mutex::new(AccumulatorState { + data: AccumulatedBuildData::Partitioned { + partitions: vec![PartitionStatus::Pending; num_partitions], + completed_partitions: 0, + }, + completion: CompletionState::Pending, + }), + completion_notify: Notify::new(), + dynamic_filter, + on_right: vec![], + repartition_random_state: SeededRandomState::with_seed(1), + probe_schema, + } +} + +#[cfg(test)] +pub(super) fn completed_partitions_for_test(acc: &SharedBuildAccumulator) -> usize { + let guard = acc.inner.lock(); + let AccumulatedBuildData::Partitioned { + completed_partitions, + .. 
+ } = &guard.data + else { + panic!("expected partitioned accumulator"); + }; + *completed_partitions +} + #[cfg(test)] mod tests { use super::*; - fn make_partitioned_accumulator(num_partitions: usize) -> SharedBuildAccumulator { - let probe_schema = Arc::new(Schema::new(vec![Field::new( - "probe_key", - DataType::Int32, - false, - )])); - let dynamic_filter = Arc::new(DynamicFilterPhysicalExpr::new(vec![], lit(true))); - SharedBuildAccumulator { - inner: Mutex::new(AccumulatorState { - data: AccumulatedBuildData::Partitioned { - partitions: vec![PartitionStatus::Pending; num_partitions], - completed_partitions: 0, - }, - completion: CompletionState::Pending, - }), - completion_notify: Notify::new(), - dynamic_filter, - on_right: vec![], - repartition_random_state: SeededRandomState::with_seed(1), - probe_schema, - } - } - fn partitioned_state(acc: &SharedBuildAccumulator) -> (Vec, usize) { let guard = acc.inner.lock(); let AccumulatedBuildData::Partitioned { @@ -748,7 +764,7 @@ mod tests { // `Reported`. This test pins that invariant. #[test] fn report_canceled_partition_is_noop_after_report() { - let acc = make_partitioned_accumulator(2); + let acc = make_partitioned_accumulator_for_test(2); { let mut guard = acc.inner.lock(); @@ -780,7 +796,7 @@ mod tests { // which is what unblocks sibling partitions waiting on the coordinator. 
#[test] fn report_canceled_partition_marks_pending_partition_canceled() { - let acc = make_partitioned_accumulator(2); + let acc = make_partitioned_accumulator_for_test(2); acc.report_canceled_partition(0); let (partitions, completed) = partitioned_state(&acc); diff --git a/datafusion/physical-plan/src/joins/hash_join/stream.rs b/datafusion/physical-plan/src/joins/hash_join/stream.rs index 040470c9be12b..d403fa43cda4b 100644 --- a/datafusion/physical-plan/src/joins/hash_join/stream.rs +++ b/datafusion/physical-plan/src/joins/hash_join/stream.rs @@ -173,15 +173,109 @@ impl ProcessProbeBatchState { /// Lifecycle of this partition's build-data report to the shared coordinator. /// -/// `ReportScheduled` means the reporting `OnceFut` has been constructed but is -/// lazy: the coordinator has not yet observed the report. Only `ReportDelivered` -/// guarantees the coordinator saw it, so `Drop` must still cancel the partition -/// when the state is `ReportScheduled` — otherwise sibling partitions wait -/// forever for a report that never runs. +/// `Scheduled` means the reporting `OnceFut` has been constructed but is lazy: +/// the coordinator has not necessarily observed the report. Only `Delivered` +/// guarantees the coordinator saw it, so `Drop` must still cancel a `Scheduled` +/// partition — otherwise sibling partitions can wait forever for a report that +/// never runs. `cancel_pending` is a no-op once `Delivered`, `Canceled`, or `Finalized`. +#[derive(Debug, PartialEq, Eq)] enum BuildReportState { NotReported, - ReportScheduled, - ReportDelivered, + Scheduled, + Delivered, + Canceled, + Finalized, +} + +/// Owns the stream-side lifecycle for one partition's build-data report; dropping it calls `cancel_pending`. 
+struct BuildReportHandle { + partition: usize, + mode: PartitionMode, + build_accumulator: Option>, + waiter: Option>, + state: BuildReportState, +} + +impl BuildReportHandle { + fn new( + partition: usize, + mode: PartitionMode, + build_accumulator: Option>, + ) -> Self { + Self { + partition, + mode, + build_accumulator, + waiter: None, + state: BuildReportState::NotReported, + } + } + + fn has_accumulator(&self) -> bool { + self.build_accumulator.is_some() + } + + fn schedule(&mut self, build_data: PartitionBuildData) { + let Some(build_accumulator) = &self.build_accumulator else { + // Defensive no-op terminal state; current callers avoid scheduling + // unless an accumulator is present. + self.finalize(); + return; + }; + + debug_assert!(matches!(self.state, BuildReportState::NotReported)); + let acc = Arc::clone(build_accumulator); + self.waiter = Some(OnceFut::new(async move { + acc.report_build_data(build_data).await + })); + self.state = BuildReportState::Scheduled; + } + + fn poll_delivery(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + if let Some(ref mut fut) = self.waiter { + ready!(fut.get_shared(cx))?; + if !matches!(self.state, BuildReportState::Delivered) { + debug_assert!(matches!(self.state, BuildReportState::Scheduled)); + self.state = BuildReportState::Delivered; + } + } + Poll::Ready(Ok(())) + } + + fn cancel_pending(&mut self) { + if matches!( + self.state, + BuildReportState::Delivered + | BuildReportState::Canceled + | BuildReportState::Finalized + ) { + return; + } + + if self.mode == PartitionMode::Partitioned + && let Some(build_accumulator) = &self.build_accumulator + { + build_accumulator.report_canceled_partition(self.partition); + self.state = BuildReportState::Canceled; + } else { + self.finalize(); + } + } + + fn finalize(&mut self) { + self.state = BuildReportState::Finalized; + } + + #[cfg(test)] + fn state(&self) -> &BuildReportState { + &self.state + } +} + +impl Drop for BuildReportHandle { + fn drop(&mut self) { + 
+ self.cancel_pending(); + } } /// [`Stream`] for [`super::HashJoinExec`] that does the actual join. @@ -228,13 +322,8 @@ pub(super) struct HashJoinStream { build_indices_buffer: Vec, /// Specifies whether the right side has an ordering to potentially preserve right_side_ordered: bool, - /// Shared build accumulator for coordinating dynamic filter updates (collects hash maps and/or bounds, optional) - build_accumulator: Option>, - /// Optional future to signal when build information has been reported by all partitions - /// and the dynamic filter has been updated - build_waiter: Option>, - /// Tracks where this partition is in the build-data reporting lifecycle. - build_report_state: BuildReportState, + /// Owns this partition's build-data report lifecycle. + build_report: BuildReportHandle, /// Partitioning mode to use mode: PartitionMode, /// Output buffer for coalescing small batches into larger ones with optional fetch limit. @@ -414,9 +503,7 @@ impl HashJoinStream { probe_indices_buffer: Vec::with_capacity(batch_size), build_indices_buffer: Vec::with_capacity(batch_size), right_side_ordered, - build_accumulator, - build_waiter: None, - build_report_state: BuildReportState::NotReported, + build_report: BuildReportHandle::new(partition, mode, build_accumulator), mode, output_buffer, null_aware, @@ -449,9 +536,9 @@ impl HashJoinStream { &mut self, left_data: &Arc, ) -> HashJoinStreamState { - let Some(build_accumulator) = self.build_accumulator.as_ref() else { + if !self.build_report.has_accumulator() { return Self::state_after_build_ready(self.join_type, left_data.as_ref()); - }; + } let pushdown = left_data.membership().clone(); let bounds = left_data @@ -473,11 +560,7 @@ impl HashJoinStream { ), }; - let acc = Arc::clone(build_accumulator); - self.build_waiter = Some(OnceFut::new(async move { - acc.report_build_data(build_data).await - })); - self.build_report_state = BuildReportState::ReportScheduled; + self.build_report.schedule(build_data); 
HashJoinStreamState::WaitPartitionBoundsReport } @@ -541,10 +624,7 @@ impl HashJoinStream { &mut self, cx: &mut std::task::Context<'_>, ) -> Poll>>> { - if let Some(ref mut fut) = self.build_waiter { - ready!(fut.get_shared(cx))?; - self.build_report_state = BuildReportState::ReportDelivered; - } + ready!(self.build_report.poll_delivery(cx))?; let build_side = self.build_side.try_as_ready()?; self.state = Self::state_after_build_ready(self.join_type, build_side.left_data.as_ref()); @@ -966,14 +1046,74 @@ impl Stream for HashJoinStream { } } -impl Drop for HashJoinStream { - fn drop(&mut self) { - if self.mode == PartitionMode::Partitioned - && !matches!(self.build_report_state, BuildReportState::ReportDelivered) - && let Some(build_accumulator) = &self.build_accumulator +#[cfg(test)] +mod tests { + use super::*; + use crate::joins::hash_join::shared_bounds::{ + PushdownStrategy, completed_partitions_for_test, + make_partitioned_accumulator_for_test, + }; + + fn empty_build_data(partition_id: usize) -> PartitionBuildData { + PartitionBuildData::Partitioned { + partition_id, + pushdown: PushdownStrategy::Empty, + bounds: PartitionBounds::new(vec![]), + } + } + + fn partitioned_handle(acc: &Arc) -> BuildReportHandle { + BuildReportHandle::new(0, PartitionMode::Partitioned, Some(Arc::clone(acc))) + } + + #[test] + fn build_report_handle_cancels_scheduled_partition_on_drop() { + let acc = Arc::new(make_partitioned_accumulator_for_test(2)); + { - build_accumulator.report_canceled_partition(self.partition); - self.build_report_state = BuildReportState::ReportDelivered; + let mut handle = partitioned_handle(&acc); + handle.schedule(empty_build_data(0)); + assert_eq!(handle.state(), &BuildReportState::Scheduled); } + + assert_eq!(completed_partitions_for_test(&acc), 1); + } + + #[test] + fn build_report_handle_does_not_cancel_delivered_partition_on_drop() { + let acc = Arc::new(make_partitioned_accumulator_for_test(1)); + + { + let mut handle = partitioned_handle(&acc); + 
+ handle.schedule(empty_build_data(0)); + let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref()); + assert!(matches!(handle.poll_delivery(&mut cx), Poll::Ready(Ok(())))); + assert_eq!(handle.state(), &BuildReportState::Delivered); + } + + assert_eq!(completed_partitions_for_test(&acc), 1); + } + + #[test] + fn build_report_handle_cancel_pending_is_idempotent() { + let acc = Arc::new(make_partitioned_accumulator_for_test(2)); + let mut handle = partitioned_handle(&acc); + handle.schedule(empty_build_data(0)); + + handle.cancel_pending(); + handle.cancel_pending(); + + assert_eq!(handle.state(), &BuildReportState::Canceled); + assert_eq!(completed_partitions_for_test(&acc), 1); + } + + #[test] + fn build_report_handle_no_accumulator_finalizes() { + let mut handle = BuildReportHandle::new(0, PartitionMode::Partitioned, None); + + handle.schedule(empty_build_data(0)); + handle.cancel_pending(); + + assert_eq!(handle.state(), &BuildReportState::Finalized); } }