Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 41 additions & 25 deletions datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -699,33 +699,49 @@ impl fmt::Debug for SharedBuildAccumulator {
}
}

/// Test fixture: builds a [`SharedBuildAccumulator`] holding partitioned build data.
///
/// All `num_partitions` slots start as `PartitionStatus::Pending`, the completed
/// counter starts at zero, and the completion state is `CompletionState::Pending`.
/// The probe schema is a single non-nullable `Int32` column named `probe_key`,
/// and the dynamic filter is created with no children and a literal `true`.
#[cfg(test)]
pub(super) fn make_partitioned_accumulator_for_test(
    num_partitions: usize,
) -> SharedBuildAccumulator {
    let key_field = Field::new("probe_key", DataType::Int32, false);
    let initial_state = AccumulatorState {
        data: AccumulatedBuildData::Partitioned {
            partitions: vec![PartitionStatus::Pending; num_partitions],
            completed_partitions: 0,
        },
        completion: CompletionState::Pending,
    };

    SharedBuildAccumulator {
        inner: Mutex::new(initial_state),
        completion_notify: Notify::new(),
        dynamic_filter: Arc::new(DynamicFilterPhysicalExpr::new(vec![], lit(true))),
        on_right: vec![],
        repartition_random_state: SeededRandomState::with_seed(1),
        probe_schema: Arc::new(Schema::new(vec![key_field])),
    }
}

/// Test helper: reads the `completed_partitions` counter of a partitioned accumulator.
///
/// # Panics
///
/// Panics if the accumulator's data is not the `Partitioned` variant.
#[cfg(test)]
pub(super) fn completed_partitions_for_test(acc: &SharedBuildAccumulator) -> usize {
    let state = acc.inner.lock();
    match &state.data {
        AccumulatedBuildData::Partitioned {
            completed_partitions,
            ..
        } => *completed_partitions,
        _ => panic!("expected partitioned accumulator"),
    }
}

#[cfg(test)]
mod tests {
use super::*;

/// Test fixture: builds a `SharedBuildAccumulator` with partitioned build data
/// in which all `num_partitions` slots are `Pending` and none are completed.
fn make_partitioned_accumulator(num_partitions: usize) -> SharedBuildAccumulator {
// Single non-nullable Int32 probe column.
let probe_schema = Arc::new(Schema::new(vec![Field::new(
"probe_key",
DataType::Int32,
false,
)]));
// Dynamic filter with no children, initialised to a literal `true`.
let dynamic_filter = Arc::new(DynamicFilterPhysicalExpr::new(vec![], lit(true)));
SharedBuildAccumulator {
inner: Mutex::new(AccumulatorState {
data: AccumulatedBuildData::Partitioned {
partitions: vec![PartitionStatus::Pending; num_partitions],
completed_partitions: 0,
},
completion: CompletionState::Pending,
}),
completion_notify: Notify::new(),
dynamic_filter,
on_right: vec![],
// Fixed seed (1) for the repartition random state.
repartition_random_state: SeededRandomState::with_seed(1),
probe_schema,
}
}

fn partitioned_state(acc: &SharedBuildAccumulator) -> (Vec<PartitionStatus>, usize) {
let guard = acc.inner.lock();
let AccumulatedBuildData::Partitioned {
Expand All @@ -748,7 +764,7 @@ mod tests {
// `Reported`. This test pins that invariant.
#[test]
fn report_canceled_partition_is_noop_after_report() {
let acc = make_partitioned_accumulator(2);
let acc = make_partitioned_accumulator_for_test(2);

{
let mut guard = acc.inner.lock();
Expand Down Expand Up @@ -780,7 +796,7 @@ mod tests {
// which is what unblocks sibling partitions waiting on the coordinator.
#[test]
fn report_canceled_partition_marks_pending_partition_canceled() {
let acc = make_partitioned_accumulator(2);
let acc = make_partitioned_accumulator_for_test(2);

acc.report_canceled_partition(0);
let (partitions, completed) = partitioned_state(&acc);
Expand Down
210 changes: 175 additions & 35 deletions datafusion/physical-plan/src/joins/hash_join/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,15 +173,109 @@ impl ProcessProbeBatchState {

/// Lifecycle of this partition's build-data report to the shared coordinator.
///
/// `ReportScheduled` means the reporting `OnceFut` has been constructed but is
/// lazy: the coordinator has not yet observed the report. Only `ReportDelivered`
/// guarantees the coordinator saw it, so `Drop` must still cancel the partition
/// when the state is `ReportScheduled` — otherwise sibling partitions wait
/// forever for a report that never runs.
/// `Scheduled` means the reporting `OnceFut` has been constructed but is lazy:
/// the coordinator has not necessarily observed the report. Only `Delivered`
/// guarantees the coordinator saw it, so `Drop` must still cancel a `Scheduled`
/// partition — otherwise sibling partitions can wait forever for a report that
/// never runs.
#[derive(Debug, PartialEq, Eq)]
enum BuildReportState {
NotReported,
ReportScheduled,
ReportDelivered,
Scheduled,
Delivered,
Canceled,
Finalized,
}

/// Owns the stream-side lifecycle for one partition's build-data report.
///
/// Dropping the handle runs `cancel_pending`, so a report that never reached
/// `Delivered` is either reported as canceled to the shared accumulator
/// (partitioned mode with an accumulator) or simply marked `Finalized`.
struct BuildReportHandle {
/// Partition index passed to `report_canceled_partition` on cancellation.
partition: usize,
/// Join partition mode; only `PartitionMode::Partitioned` produces a
/// cancellation report.
mode: PartitionMode,
/// Shared accumulator that receives the report. When `None`, `schedule`
/// only marks the handle as finalized.
build_accumulator: Option<Arc<SharedBuildAccumulator>>,
/// Reporting future created by `schedule`; `None` until a report is scheduled.
waiter: Option<OnceFut<()>>,
/// Current position in the report lifecycle.
state: BuildReportState,
}

impl BuildReportHandle {
    /// Creates a handle in the `NotReported` state with no scheduled report.
    fn new(
        partition: usize,
        mode: PartitionMode,
        build_accumulator: Option<Arc<SharedBuildAccumulator>>,
    ) -> Self {
        Self {
            state: BuildReportState::NotReported,
            waiter: None,
            build_accumulator,
            mode,
            partition,
        }
    }

    /// Returns `true` when a shared accumulator is attached to this handle.
    fn has_accumulator(&self) -> bool {
        matches!(self.build_accumulator, Some(_))
    }

    /// Wraps the report of `build_data` in a `OnceFut` and moves to `Scheduled`.
    ///
    /// The report is only constructed here; it is driven by `poll_delivery`.
    /// Without an accumulator the handle is finalized and nothing is scheduled.
    fn schedule(&mut self, build_data: PartitionBuildData) {
        let shared = match &self.build_accumulator {
            Some(shared) => shared,
            None => {
                // Defensive terminal state: callers are expected to check
                // `has_accumulator` before scheduling.
                self.finalize();
                return;
            }
        };

        debug_assert!(matches!(self.state, BuildReportState::NotReported));
        let acc = Arc::clone(shared);
        let report = async move { acc.report_build_data(build_data).await };
        self.waiter = Some(OnceFut::new(report));
        self.state = BuildReportState::Scheduled;
    }

    /// Polls the scheduled report, recording `Delivered` once it completes.
    ///
    /// Returns `Poll::Ready(Ok(()))` immediately when nothing was scheduled.
    /// A pending or failed report leaves the state unchanged.
    fn poll_delivery(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<()>> {
        let Some(fut) = self.waiter.as_mut() else {
            return Poll::Ready(Ok(()));
        };

        ready!(fut.get_shared(cx))?;

        let already_delivered = matches!(self.state, BuildReportState::Delivered);
        if !already_delivered {
            debug_assert!(matches!(self.state, BuildReportState::Scheduled));
            self.state = BuildReportState::Delivered;
        }
        Poll::Ready(Ok(()))
    }

    /// Cancels a report that has not reached a terminal state.
    ///
    /// Does nothing for `Delivered`, `Canceled`, or `Finalized` handles. In
    /// partitioned mode with an accumulator, the partition is reported as
    /// canceled and the state becomes `Canceled`; otherwise the handle is
    /// finalized.
    fn cancel_pending(&mut self) {
        let settled = matches!(
            self.state,
            BuildReportState::Delivered
                | BuildReportState::Canceled
                | BuildReportState::Finalized
        );
        if settled {
            return;
        }

        let is_partitioned = self.mode == PartitionMode::Partitioned;
        let target = self.build_accumulator.as_ref().filter(|_| is_partitioned);
        match target {
            Some(acc) => {
                acc.report_canceled_partition(self.partition);
                self.state = BuildReportState::Canceled;
            }
            None => self.finalize(),
        }
    }

    /// Marks the handle as `Finalized`.
    fn finalize(&mut self) {
        self.state = BuildReportState::Finalized;
    }

    /// Exposes the current lifecycle state to tests.
    #[cfg(test)]
    fn state(&self) -> &BuildReportState {
        &self.state
    }
}

/// Runs `cancel_pending` when the handle goes out of scope.
impl Drop for BuildReportHandle {
fn drop(&mut self) {
// `cancel_pending` returns early for delivered, canceled, or finalized
// handles, so only reports still outstanding are affected here.
self.cancel_pending();
}
}

/// [`Stream`] for [`super::HashJoinExec`] that does the actual join.
Expand Down Expand Up @@ -228,13 +322,8 @@ pub(super) struct HashJoinStream {
build_indices_buffer: Vec<u64>,
/// Specifies whether the right side has an ordering to potentially preserve
right_side_ordered: bool,
/// Shared build accumulator for coordinating dynamic filter updates (collects hash maps and/or bounds, optional)
build_accumulator: Option<Arc<SharedBuildAccumulator>>,
/// Optional future to signal when build information has been reported by all partitions
/// and the dynamic filter has been updated
build_waiter: Option<OnceFut<()>>,
/// Tracks where this partition is in the build-data reporting lifecycle.
build_report_state: BuildReportState,
/// Owns this partition's build-data report lifecycle.
build_report: BuildReportHandle,
/// Partitioning mode to use
mode: PartitionMode,
/// Output buffer for coalescing small batches into larger ones with optional fetch limit.
Expand Down Expand Up @@ -414,9 +503,7 @@ impl HashJoinStream {
probe_indices_buffer: Vec::with_capacity(batch_size),
build_indices_buffer: Vec::with_capacity(batch_size),
right_side_ordered,
build_accumulator,
build_waiter: None,
build_report_state: BuildReportState::NotReported,
build_report: BuildReportHandle::new(partition, mode, build_accumulator),
mode,
output_buffer,
null_aware,
Expand Down Expand Up @@ -449,9 +536,9 @@ impl HashJoinStream {
&mut self,
left_data: &Arc<JoinLeftData>,
) -> HashJoinStreamState {
let Some(build_accumulator) = self.build_accumulator.as_ref() else {
if !self.build_report.has_accumulator() {
return Self::state_after_build_ready(self.join_type, left_data.as_ref());
};
}

let pushdown = left_data.membership().clone();
let bounds = left_data
Expand All @@ -473,11 +560,7 @@ impl HashJoinStream {
),
};

let acc = Arc::clone(build_accumulator);
self.build_waiter = Some(OnceFut::new(async move {
acc.report_build_data(build_data).await
}));
self.build_report_state = BuildReportState::ReportScheduled;
self.build_report.schedule(build_data);
HashJoinStreamState::WaitPartitionBoundsReport
}

Expand Down Expand Up @@ -541,10 +624,7 @@ impl HashJoinStream {
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
if let Some(ref mut fut) = self.build_waiter {
ready!(fut.get_shared(cx))?;
self.build_report_state = BuildReportState::ReportDelivered;
}
ready!(self.build_report.poll_delivery(cx))?;
let build_side = self.build_side.try_as_ready()?;
self.state =
Self::state_after_build_ready(self.join_type, build_side.left_data.as_ref());
Expand Down Expand Up @@ -966,14 +1046,74 @@ impl Stream for HashJoinStream {
}
}

impl Drop for HashJoinStream {
fn drop(&mut self) {
if self.mode == PartitionMode::Partitioned
&& !matches!(self.build_report_state, BuildReportState::ReportDelivered)
&& let Some(build_accumulator) = &self.build_accumulator
#[cfg(test)]
mod tests {
use super::*;
use crate::joins::hash_join::shared_bounds::{
PushdownStrategy, completed_partitions_for_test,
make_partitioned_accumulator_for_test,
};

/// Builds partitioned build data for `partition_id` with an empty pushdown
/// strategy and no bounds.
fn empty_build_data(partition_id: usize) -> PartitionBuildData {
PartitionBuildData::Partitioned {
partition_id,
pushdown: PushdownStrategy::Empty,
bounds: PartitionBounds::new(vec![]),
}
}

/// Creates a handle for partition 0 in partitioned mode that shares `acc`.
fn partitioned_handle(acc: &Arc<SharedBuildAccumulator>) -> BuildReportHandle {
BuildReportHandle::new(0, PartitionMode::Partitioned, Some(Arc::clone(acc)))
}

#[test]
fn build_report_handle_cancels_scheduled_partition_on_drop() {
let acc = Arc::new(make_partitioned_accumulator_for_test(2));

{
build_accumulator.report_canceled_partition(self.partition);
self.build_report_state = BuildReportState::ReportDelivered;
let mut handle = partitioned_handle(&acc);
handle.schedule(empty_build_data(0));
assert_eq!(handle.state(), &BuildReportState::Scheduled);
}

assert_eq!(completed_partitions_for_test(&acc), 1);
}

// A report polled to completion must leave the handle in `Delivered`; the
// completed-partition count is then checked after the handle is dropped.
#[test]
fn build_report_handle_does_not_cancel_delivered_partition_on_drop() {
// Single-partition accumulator: the test expects the first poll to be ready.
let acc = Arc::new(make_partitioned_accumulator_for_test(1));

{
let mut handle = partitioned_handle(&acc);
handle.schedule(empty_build_data(0));
// A no-op waker is sufficient because the poll is expected to complete at once.
let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref());
assert!(matches!(handle.poll_delivery(&mut cx), Poll::Ready(Ok(()))));
assert_eq!(handle.state(), &BuildReportState::Delivered);
}

// The handle has been dropped at this point.
assert_eq!(completed_partitions_for_test(&acc), 1);
}

// Calling `cancel_pending` twice on a scheduled handle must leave the state at
// `Canceled` and the completed-partition count at 1.
#[test]
fn build_report_handle_cancel_pending_is_idempotent() {
let acc = Arc::new(make_partitioned_accumulator_for_test(2));
let mut handle = partitioned_handle(&acc);
handle.schedule(empty_build_data(0));

// The second call is expected to return early on the `Canceled` state.
handle.cancel_pending();
handle.cancel_pending();

assert_eq!(handle.state(), &BuildReportState::Canceled);
assert_eq!(completed_partitions_for_test(&acc), 1);
}

// Without an accumulator, `schedule` finalizes the handle and a subsequent
// `cancel_pending` leaves it in `Finalized`.
#[test]
fn build_report_handle_no_accumulator_finalizes() {
let mut handle = BuildReportHandle::new(0, PartitionMode::Partitioned, None);

handle.schedule(empty_build_data(0));
handle.cancel_pending();

assert_eq!(handle.state(), &BuildReportState::Finalized);
}
}
Loading