-
Notifications
You must be signed in to change notification settings - Fork 2k
Fix sort merge interleave overflow #20922
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,9 +18,12 @@ | |
| use crate::spill::get_record_batch_memory_size; | ||
| use arrow::compute::interleave; | ||
| use arrow::datatypes::SchemaRef; | ||
| use arrow::error::ArrowError; | ||
| use arrow::record_batch::RecordBatch; | ||
| use datafusion_common::Result; | ||
| use datafusion_common::{DataFusionError, Result}; | ||
| use datafusion_execution::memory_pool::MemoryReservation; | ||
| use log::warn; | ||
| use std::panic::{AssertUnwindSafe, catch_unwind}; | ||
| use std::sync::Arc; | ||
|
|
||
| #[derive(Debug, Copy, Clone, Default)] | ||
|
|
@@ -126,49 +129,97 @@ impl BatchBuilder { | |
| &self.schema | ||
| } | ||
|
|
||
| /// Try to interleave all columns using the given index slice. | ||
| fn try_interleave_columns( | ||
| &self, | ||
| indices: &[(usize, usize)], | ||
| ) -> Result<Vec<Arc<dyn arrow::array::Array>>> { | ||
| (0..self.schema.fields.len()) | ||
| .map(|column_idx| { | ||
| let arrays: Vec<_> = self | ||
| .batches | ||
| .iter() | ||
| .map(|(_, batch)| batch.column(column_idx).as_ref()) | ||
| .collect(); | ||
| // Arrow's interleave panics on i32 offset overflow with | ||
| // `.expect("overflow")`. Catch that panic so the caller | ||
| // can retry with fewer rows. | ||
| match catch_unwind(AssertUnwindSafe(|| interleave(&arrays, indices))) { | ||
| Ok(result) => Ok(result?), | ||
| Err(panic_payload) => { | ||
| if is_overflow_panic(&panic_payload) { | ||
| Err(DataFusionError::ArrowError( | ||
| Box::new(ArrowError::OffsetOverflowError(0)), | ||
| None, | ||
| )) | ||
| } else { | ||
| std::panic::resume_unwind(panic_payload); | ||
| } | ||
| } | ||
| } | ||
| }) | ||
| .collect::<Result<Vec<_>>>() | ||
| } | ||
|
|
||
| /// Drains the in_progress row indexes, and builds a new RecordBatch from them | ||
| /// | ||
| /// Will then drop any batches for which all rows have been yielded to the output | ||
| /// Will then drop any batches for which all rows have been yielded to the output. | ||
| /// If an offset overflow occurs (e.g. string/list offsets exceed i32::MAX), | ||
| /// retries with progressively fewer rows until it succeeds. | ||
| /// | ||
| /// Returns `None` if no pending rows | ||
| pub fn build_record_batch(&mut self) -> Result<Option<RecordBatch>> { | ||
| if self.is_empty() { | ||
| return Ok(None); | ||
| } | ||
|
|
||
| let columns = (0..self.schema.fields.len()) | ||
| .map(|column_idx| { | ||
| let arrays: Vec<_> = self | ||
| .batches | ||
| .iter() | ||
| .map(|(_, batch)| batch.column(column_idx).as_ref()) | ||
| .collect(); | ||
| Ok(interleave(&arrays, &self.indices)?) | ||
| }) | ||
| .collect::<Result<Vec<_>>>()?; | ||
|
|
||
| self.indices.clear(); | ||
|
|
||
| // New cursors are only created once the previous cursor for the stream | ||
| // is finished. This means all remaining rows from all but the last batch | ||
| // for each stream have been yielded to the newly created record batch | ||
| // | ||
| // We can therefore drop all but the last batch for each stream | ||
| let mut batch_idx = 0; | ||
| let mut retained = 0; | ||
| self.batches.retain(|(stream_idx, batch)| { | ||
| let stream_cursor = &mut self.cursors[*stream_idx]; | ||
| let retain = stream_cursor.batch_idx == batch_idx; | ||
| batch_idx += 1; | ||
|
|
||
| if retain { | ||
| stream_cursor.batch_idx = retained; | ||
| retained += 1; | ||
| } else { | ||
| self.batches_mem_used -= get_record_batch_memory_size(batch); | ||
| // Try interleaving all indices. On offset overflow, halve and retry. | ||
| let mut end = self.indices.len(); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The retry loop is clear, but I think |
||
| let columns = loop { | ||
| match self.try_interleave_columns(&self.indices[..end]) { | ||
| Ok(cols) => break cols, | ||
| Err(e) if is_offset_overflow(&e) => { | ||
| end /= 2; | ||
| if end == 0 { | ||
| return Err(e); | ||
| } | ||
| warn!( | ||
| "Interleave offset overflow with {} rows, retrying with {}", | ||
| self.indices.len(), | ||
| end | ||
| ); | ||
| } | ||
| Err(e) => return Err(e), | ||
| } | ||
| retain | ||
| }); | ||
| }; | ||
|
|
||
| // Remove consumed indices, keeping any remaining for the next call. | ||
| self.indices.drain(..end); | ||
|
|
||
| // Only clean up fully-consumed batches when all indices are drained, | ||
| // because remaining indices may still reference earlier batches. | ||
| if self.indices.is_empty() { | ||
| // New cursors are only created once the previous cursor for the stream | ||
| // is finished. This means all remaining rows from all but the last batch | ||
| // for each stream have been yielded to the newly created record batch | ||
| // | ||
| // We can therefore drop all but the last batch for each stream | ||
| let mut batch_idx = 0; | ||
| let mut retained = 0; | ||
| self.batches.retain(|(stream_idx, batch)| { | ||
| let stream_cursor = &mut self.cursors[*stream_idx]; | ||
| let retain = stream_cursor.batch_idx == batch_idx; | ||
| batch_idx += 1; | ||
|
|
||
| if retain { | ||
| stream_cursor.batch_idx = retained; | ||
| retained += 1; | ||
| } else { | ||
| self.batches_mem_used -= get_record_batch_memory_size(batch); | ||
| } | ||
| retain | ||
| }); | ||
| } | ||
|
|
||
| // Release excess memory back to the pool, but never shrink below | ||
| // initial_reservation to maintain the anti-starvation guarantee | ||
|
|
@@ -200,3 +251,75 @@ pub(crate) fn try_grow_reservation_to_at_least( | |
| } | ||
| Ok(()) | ||
| } | ||
|
|
||
| /// Returns true if the error is an Arrow offset overflow. | ||
| fn is_offset_overflow(e: &DataFusionError) -> bool { | ||
| matches!( | ||
| e, | ||
| DataFusionError::ArrowError(boxed, _) | ||
| if matches!(boxed.as_ref(), ArrowError::OffsetOverflowError(_)) | ||
| ) | ||
| } | ||
|
|
||
| /// Returns true if a caught panic payload looks like an Arrow offset overflow. | ||
| fn is_overflow_panic(payload: &Box<dyn std::any::Any + Send>) -> bool { | ||
| if let Some(msg) = payload.downcast_ref::<&str>() { | ||
| return msg.contains("overflow"); | ||
| } | ||
| if let Some(msg) = payload.downcast_ref::<String>() { | ||
| return msg.contains("overflow"); | ||
| } | ||
| false | ||
| } | ||
|
|
||
| #[cfg(test)] | ||
| mod tests { | ||
| use super::*; | ||
| use arrow::array::StringArray; | ||
| use arrow::datatypes::{DataType, Field, Schema}; | ||
| use datafusion_execution::memory_pool::{ | ||
| MemoryConsumer, MemoryPool, UnboundedMemoryPool, | ||
| }; | ||
|
|
||
| /// Test that interleaving string columns whose combined byte length | ||
| /// exceeds i32::MAX does not panic. Arrow's `interleave` panics with | ||
| /// `.expect("overflow")` in this case; `BatchBuilder` catches the | ||
| /// panic and retries with fewer rows until the output fits in i32 | ||
| /// offsets. | ||
| #[test] | ||
| fn test_interleave_overflow_is_caught() { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. this and In practice that means several gigabytes of heap allocation per test, which is likely to make CI flaky or OOM outright. The coverage is important, but I think these tests are better replaced with a lower-memory reproduction, for example by constructing the overflow condition with a purpose-built array fixture/helper instead of copying multi-GB payloads into |
||
| // Each string is ~768 MB. Three rows total → ~2.3 GB > i32::MAX. | ||
| let big_str: String = "x".repeat(768 * 1024 * 1024); | ||
|
|
||
| let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, false)])); | ||
|
|
||
| let pool: Arc<dyn MemoryPool> = Arc::new(UnboundedMemoryPool::default()); | ||
| let reservation = MemoryConsumer::new("test").register(&pool); | ||
| let mut builder = BatchBuilder::new( | ||
| Arc::clone(&schema), | ||
| /* stream_count */ 3, | ||
| /* batch_size */ 16, | ||
| reservation, | ||
| ); | ||
|
|
||
| // Push one batch per stream, each containing one large string. | ||
| for stream_idx in 0..3 { | ||
| let array = StringArray::from(vec![big_str.as_str()]); | ||
| let batch = | ||
| RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap(); | ||
| builder.push_batch(stream_idx, batch).unwrap(); | ||
| builder.push_row(stream_idx); | ||
| } | ||
|
|
||
| // 3 rows total; interleaving all 3 would overflow i32 offsets. | ||
| // The retry loop should halve until it succeeds. | ||
| let batch = builder.build_record_batch().unwrap().unwrap(); | ||
| assert!(batch.num_rows() > 0); | ||
| assert!(batch.num_rows() < 3); | ||
|
|
||
| // Drain remaining rows. | ||
| let batch2 = builder.build_record_batch().unwrap().unwrap(); | ||
| assert!(batch2.num_rows() > 0); | ||
| assert_eq!(batch.num_rows() + batch2.num_rows(), 3); | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -53,6 +53,14 @@ pub(crate) struct SortPreservingMergeStream<C: CursorValues> { | |
| /// `fetch` limit. | ||
| done: bool, | ||
|
|
||
| /// Whether buffered rows should be drained after `done` is set. | ||
| /// | ||
| /// This is enabled when we stop because the `fetch` limit has been | ||
| /// reached, allowing partial batches left over after overflow handling to | ||
| /// be emitted on subsequent polls. It remains disabled for terminal | ||
| /// errors so the stream does not yield data after returning `Err`. | ||
| drain_in_progress_on_done: bool, | ||
|
|
||
| /// A loser tree that always produces the minimum cursor | ||
| /// | ||
| /// Node 0 stores the top winner, Nodes 1..num_streams store | ||
|
|
@@ -164,6 +172,7 @@ impl<C: CursorValues> SortPreservingMergeStream<C> { | |
| streams, | ||
| metrics, | ||
| done: false, | ||
| drain_in_progress_on_done: false, | ||
| cursors: (0..stream_count).map(|_| None).collect(), | ||
| prev_cursors: (0..stream_count).map(|_| None).collect(), | ||
| round_robin_tie_breaker_mode: false, | ||
|
|
@@ -208,6 +217,19 @@ impl<C: CursorValues> SortPreservingMergeStream<C> { | |
| cx: &mut Context<'_>, | ||
| ) -> Poll<Option<Result<RecordBatch>>> { | ||
| if self.done { | ||
| // When `build_record_batch()` hits an i32 offset overflow (e.g. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This feels like it wants a small helper on |
||
| // combined string offsets exceed 2 GB), it emits a partial batch | ||
| // and keeps the remaining rows in `self.in_progress.indices`. | ||
| // Drain those leftover rows before terminating the stream, | ||
| // otherwise they would be silently dropped. | ||
| // Repeated overflows are fine — each poll emits another partial | ||
| // batch until `in_progress` is fully drained. | ||
| if self.drain_in_progress_on_done && !self.in_progress.is_empty() { | ||
| let before = self.in_progress.len(); | ||
| let result = self.in_progress.build_record_batch(); | ||
| self.produced += before - self.in_progress.len(); | ||
| return Poll::Ready(result.transpose()); | ||
| } | ||
| return Poll::Ready(None); | ||
| } | ||
| // Once all partitions have set their corresponding cursors for the loser tree, | ||
|
|
@@ -283,14 +305,17 @@ impl<C: CursorValues> SortPreservingMergeStream<C> { | |
| // stop sorting if fetch has been reached | ||
| if self.fetch_reached() { | ||
| self.done = true; | ||
| self.drain_in_progress_on_done = true; | ||
| } else if self.in_progress.len() < self.batch_size { | ||
| continue; | ||
| } | ||
| } | ||
|
|
||
| self.produced += self.in_progress.len(); | ||
| let before = self.in_progress.len(); | ||
| let result = self.in_progress.build_record_batch(); | ||
| self.produced += before - self.in_progress.len(); | ||
|
|
||
| return Poll::Ready(self.in_progress.build_record_batch().transpose()); | ||
| return Poll::Ready(result.transpose()); | ||
| } | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Catching any panic whose message merely contains
"overflow"is too broad for a recovery path in the merge operator.This now converts unrelated bugs such as Rust arithmetic overflows (
"attempt to multiply with overflow") or allocation failures like"capacity overflow"into a syntheticOffsetOverflowError, causing the stream to silently split batches instead of surfacing the real defect.Since this code is on the hot path and intentionally swallows panics, I think we need a tighter discriminator before merging. Ideally the overflow detection should match the specific Arrow panic we expect, or be isolated behind a smaller helper/API so we are not turning arbitrary panics into data-dependent control flow.