From 432f5c7fdea0ef0b04265bec6b121cd5aa90e876 Mon Sep 17 00:00:00 2001 From: SubhamSinghal Date: Mon, 11 May 2026 12:26:03 +0530 Subject: [PATCH 1/2] Track spill read-back memory in SMJ --- .../sort_merge_join/materializing_stream.rs | 67 ++++-- .../src/joins/sort_merge_join/tests.rs | 223 ++++++++++++++++++ 2 files changed, 275 insertions(+), 15 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs b/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs index 069e94d0a9fd6..c7df8ba8586ac 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs @@ -1526,7 +1526,7 @@ impl MaterializingSortMergeJoinStream { /// gathers columns across sources. A null-row sentinel at source index 0 /// handles null right indices (unmatched streamed rows). fn materialize_right_columns( - &self, + &mut self, matched_chunks: &[(usize, UInt64Array, UInt64Array)], total_matched_rows: usize, ) -> Result<Vec<ArrayRef>> { @@ -1541,11 +1541,33 @@ impl MaterializingSortMergeJoinStream { matched_chunks.iter().map(|c| &c.2 as &dyn Array).collect(); as_uint64_array(&compute::concat(&refs)?)?.clone() }; - return fetch_right_columns_by_idxs( + + let spill_read_mem = match &self.buffered_data.batches[first_batch_idx].batch + { + BufferedBatchState::Spilled(_) => { + self.buffered_data.batches[first_batch_idx].size_estimation + } + _ => 0, + }; + + if spill_read_mem > 0 { + self.reservation.grow(spill_read_mem); + self.join_metrics + .peak_mem_used() + .set_max(self.reservation.size()); + } + + let result = fetch_right_columns_by_idxs( &self.buffered_data, first_batch_idx, &combined_right_indices, ); + + if spill_read_mem > 0 { + self.reservation.shrink(spill_read_mem); + } + + return result; } // Multiple source batches: map each buffered_batch_idx to a @@ -1577,20 +1599,31 @@ impl MaterializingSortMergeJoinStream { let mut right_columns
= Vec::with_capacity(num_right_cols); // Read each source batch once (spilled batches require disk I/O). - let source_data: Vec<Option<RecordBatch>> = source_batches - .iter() - .map(|&idx| { - let bb = &self.buffered_data.batches[idx]; - match &bb.batch { - BufferedBatchState::InMemory(batch) => Some(batch.clone()), - BufferedBatchState::Spilled(spill_file) => { - let file = BufReader::new(File::open(spill_file.path()).ok()?); - let reader = StreamReader::try_new(file, None).ok()?; - reader.into_iter().next()?.ok() - } + // Track memory for each spilled batch at the point of deserialization + // so the pool reflects actual usage as it grows. + let mut spill_read_mem: usize = 0; + let mut source_data: Vec<Option<RecordBatch>> = + Vec::with_capacity(source_batches.len()); + for &idx in &source_batches { + let bb = &self.buffered_data.batches[idx]; + match &bb.batch { + BufferedBatchState::InMemory(batch) => { + source_data.push(Some(batch.clone())); } - }) - .collect(); + BufferedBatchState::Spilled(spill_file) => { + let batch_mem = bb.size_estimation; + self.reservation.grow(batch_mem); + self.join_metrics + .peak_mem_used() + .set_max(self.reservation.size()); + spill_read_mem += batch_mem; + + let file = BufReader::new(File::open(spill_file.path())?); + let reader = StreamReader::try_new(file, None)?; + source_data.push(reader.into_iter().next().transpose()?); + } + } + } for col_idx in 0..num_right_cols { let dtype = self.buffered_schema.field(col_idx).data_type(); @@ -1614,6 +1647,10 @@ impl MaterializingSortMergeJoinStream { right_columns.push(interleave(&source_arrays, &interleave_indices)?); } + if spill_read_mem > 0 { + self.reservation.shrink(spill_read_mem); + } + Ok(right_columns) } diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs index bc34c351c5e21..3e0237350050d 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs @@
-4724,3 +4724,226 @@ async fn spill_filtered_boundary_loses_outer_rows() -> Result<()> { Ok(()) } + +/// Verifies that `peak_mem_used` reflects spill read-back memory during +/// output materialization. +/// +/// When spilled buffered batches are read back from disk to produce join +/// output, the deserialized data temporarily exists in memory. This test +/// verifies that the read-back is tracked via grow/shrink so the pool +/// accurately reflects the transient spike. +#[tokio::test] +async fn spill_read_back_memory_accounting() -> Result<()> { + use arrow::array::Array; + + let left_batch = build_table_i32( + ("a1", &vec![0, 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![4, 5]), + ); + let size_estimation = left_batch.get_array_memory_size() + + Int32Array::from(vec![1, 1]).get_array_memory_size() + + 2usize.next_power_of_two() * size_of::<usize>() + + size_of::<Range<usize>>() + + size_of::<usize>(); + + // Memory limit too small for a full batch — forces spilling. + let memory_limit = size_estimation / 2; + + // All rows share the same join key (b=1) to force multiple buffered + // batches in the same key group — triggering spill read-back during + // output materialization. + let left_batches: Vec<RecordBatch> = (0..4) + .map(|i| { + build_table_i32( + ("a1", &vec![i * 2, i * 2 + 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![100 + i, 101 + i]), + ) + }) + .collect(); + let left = build_table_from_batches(left_batches); + + let right_batches: Vec<RecordBatch> = (0..4) + .map(|i| { + build_table_i32( + ("a2", &vec![i * 2, i * 2 + 1]), + ("b2", &vec![1, 1]), + ("c2", &vec![200 + i, 201 + i]), + ) + }) + .collect(); + let right = build_table_from_batches(right_batches); + + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?)
as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(memory_limit, 1.0) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), + ) + .build_arc()?; + + let session_config = SessionConfig::default().with_batch_size(50); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config) + .with_runtime(Arc::clone(&runtime)), + ); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + Inner, + sort_options, + NullEquality::NullEqualsNothing, + )?; + + let stream = join.execute(0, task_ctx)?; + let result = common::collect(stream).await.unwrap(); + + assert!(!result.is_empty(), "Expected non-empty join result"); + + let metrics = join.metrics().unwrap(); + assert!( + metrics.spill_count().unwrap() > 0, + "Expected spilling to occur" + ); + + // peak_mem_used should reflect the spill read-back: when buffered + // batches are read from disk during output materialization, grow() + // temporarily reserves size_estimation. This pushes peak above what + // join_arrays_mem alone would show. + let peak_mem = metrics + .sum_by_name("peak_mem_used") + .map(|m| m.as_usize()) + .unwrap_or(0); + assert!( + peak_mem >= size_estimation, + "peak_mem_used ({peak_mem}) should be >= size_estimation ({size_estimation}) \ + because spill read-back temporarily loads full batch into memory" + ); + + // All memory must be released (grow/shrink balanced) + assert_eq!( + runtime.memory_pool.reserved(), + 0, + "All memory should be released after join completes" + ); + + Ok(()) +} + +/// Verifies spill read-back memory tracking for the single-source path. +/// +/// When only ONE buffered batch exists for a key group and it's spilled, +/// `fetch_right_columns_by_idxs` reads it back. This test verifies the +/// grow/shrink around that single-batch read. 
+#[tokio::test] +async fn spill_read_back_single_source() -> Result<()> { + use arrow::array::Array; + + let left_batch = build_table_i32( + ("a1", &vec![0, 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![4, 5]), + ); + let size_estimation = left_batch.get_array_memory_size() + + Int32Array::from(vec![1, 1]).get_array_memory_size() + + 2usize.next_power_of_two() * size_of::<usize>() + + size_of::<Range<usize>>() + + size_of::<usize>(); + + // Memory limit too small for a full batch — forces spilling. + let memory_limit = size_estimation / 2; + + // Multiple distinct keys so each key group has exactly ONE buffered batch. + // This ensures the single-source path is exercised. + let left_batches: Vec<RecordBatch> = (0..4) + .map(|i| { + build_table_i32( + ("a1", &vec![i * 2, i * 2 + 1]), + ("b1", &vec![i, i]), + ("c1", &vec![100 + i, 101 + i]), + ) + }) + .collect(); + let left = build_table_from_batches(left_batches); + + // One batch per key — each key group has single source + let right_batches: Vec<RecordBatch> = (0..4) + .map(|i| { + build_table_i32( + ("a2", &vec![i * 2, i * 2 + 1]), + ("b2", &vec![i, i]), + ("c2", &vec![200 + i, 201 + i]), + ) + }) + .collect(); + let right = build_table_from_batches(right_batches); + + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?)
as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(memory_limit, 1.0) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), + ) + .build_arc()?; + + let session_config = SessionConfig::default().with_batch_size(50); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config) + .with_runtime(Arc::clone(&runtime)), + ); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + Inner, + sort_options, + NullEquality::NullEqualsNothing, + )?; + + let stream = join.execute(0, task_ctx)?; + let result = common::collect(stream).await.unwrap(); + + assert!(!result.is_empty(), "Expected non-empty join result"); + + let metrics = join.metrics().unwrap(); + assert!( + metrics.spill_count().unwrap() > 0, + "Expected spilling to occur" + ); + + // peak_mem_used should reflect the single-batch read-back + let peak_mem = metrics + .sum_by_name("peak_mem_used") + .map(|m| m.as_usize()) + .unwrap_or(0); + assert!( + peak_mem >= size_estimation, + "peak_mem_used ({peak_mem}) should be >= size_estimation ({size_estimation}) \ + because single-source spill read-back loads full batch" + ); + + // All memory must be released + assert_eq!( + runtime.memory_pool.reserved(), + 0, + "All memory should be released after join completes" + ); + + Ok(()) +} From 81be3f976af567e24fa7432ab714d5d877a5559c Mon Sep 17 00:00:00 2001 From: SubhamSinghal Date: Thu, 14 May 2026 17:30:29 +0530 Subject: [PATCH 2/2] Use MemoryReservation for memory tracking --- .../sort_merge_join/materializing_stream.rs | 42 ++++++------------- .../src/joins/sort_merge_join/tests.rs | 12 +++--- 2 files changed, 19 insertions(+), 35 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs b/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs index
bf96e75e5a8ee..9bcc749c23dce 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs @@ -1556,32 +1556,23 @@ impl MaterializingSortMergeJoinStream { as_uint64_array(&compute::concat(&refs)?)?.clone() }; - let spill_read_mem = match &self.buffered_data.batches[first_batch_idx].batch - { - BufferedBatchState::Spilled(_) => { - self.buffered_data.batches[first_batch_idx].size_estimation - } - _ => 0, - }; - - if spill_read_mem > 0 { - self.reservation.grow(spill_read_mem); + let spill_reservation = self.reservation.new_empty(); + if matches!( + &self.buffered_data.batches[first_batch_idx].batch, + BufferedBatchState::Spilled(_) + ) { + spill_reservation + .grow(self.buffered_data.batches[first_batch_idx].size_estimation); self.join_metrics .peak_mem_used() - .set_max(self.reservation.size()); + .set_max(self.reservation.size() + spill_reservation.size()); } - let result = fetch_right_columns_by_idxs( + return fetch_right_columns_by_idxs( &self.buffered_data, first_batch_idx, &combined_right_indices, ); - - if spill_read_mem > 0 { - self.reservation.shrink(spill_read_mem); - } - - return result; } // Multiple source batches: map each buffered_batch_idx to a @@ -1610,12 +1601,11 @@ impl MaterializingSortMergeJoinStream { } let num_right_cols = self.buffered_schema.fields().len(); - let mut right_columns = Vec::with_capacity(num_right_cols); // Read each source batch once (spilled batches require disk I/O). // Track memory for each spilled batch at the point of deserialization // so the pool reflects actual usage as it grows. 
- let mut spill_read_mem: usize = 0; + let spill_reservation = self.reservation.new_empty(); let mut source_data: Vec<Option<RecordBatch>> = Vec::with_capacity(source_batches.len()); for &idx in &source_batches { @@ -1625,12 +1615,10 @@ impl MaterializingSortMergeJoinStream { source_data.push(Some(batch.clone())); } BufferedBatchState::Spilled(spill_file) => { - let batch_mem = bb.size_estimation; - self.reservation.grow(batch_mem); + spill_reservation.grow(bb.size_estimation); self.join_metrics .peak_mem_used() - .set_max(self.reservation.size()); - spill_read_mem += batch_mem; + .set_max(self.reservation.size() + spill_reservation.size()); let file = BufReader::new(File::open(spill_file.path())?); let reader = StreamReader::try_new(file, None)?; @@ -1639,6 +1627,7 @@ impl MaterializingSortMergeJoinStream { } } + let mut right_columns = Vec::with_capacity(num_right_cols); for col_idx in 0..num_right_cols { let dtype = self.buffered_schema.field(col_idx).data_type(); let null_array = new_null_array(dtype, 1); @@ -1657,14 +1646,9 @@ impl MaterializingSortMergeJoinStream { } } } - right_columns.push(interleave(&source_arrays, &interleave_indices)?); } - if spill_read_mem > 0 { - self.reservation.shrink(spill_read_mem); - } - Ok(right_columns) } diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs index 3e0237350050d..c4377b3189ff7 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs @@ -4726,12 +4726,12 @@ async fn spill_filtered_boundary_loses_outer_rows() -> Result<()> { } /// Verifies that `peak_mem_used` reflects spill read-back memory during -/// output materialization. +/// output materialization (multi-source path). /// /// When spilled buffered batches are read back from disk to produce join -/// output, the deserialized data temporarily exists in memory.
This test -/// verifies that the read-back is tracked via grow/shrink so the pool -/// accurately reflects the transient spike. +/// output, a scoped `MemoryReservation` (via `new_empty()`) tracks the +/// transient memory. Its `Drop` guarantees the pool is balanced on every +/// exit path — normal return or early `?` error. #[tokio::test] async fn spill_read_back_memory_accounting() -> Result<()> { use arrow::array::Array; @@ -4842,8 +4842,8 @@ async fn spill_read_back_memory_accounting() -> Result<()> { /// Verifies spill read-back memory tracking for the single-source path. /// /// When only ONE buffered batch exists for a key group and it's spilled, -/// `fetch_right_columns_by_idxs` reads it back. This test verifies the -/// grow/shrink around that single-batch read. +/// `fetch_right_columns_by_idxs` reads it back. A scoped `MemoryReservation` +/// (via `new_empty()`) tracks the transient memory and releases it on drop. #[tokio::test] async fn spill_read_back_single_source() -> Result<()> { use arrow::array::Array;