diff --git a/datafusion/core/src/physical_plan/sorts/cursor.rs b/datafusion/core/src/physical_plan/sorts/cursor.rs index c9313fbfe2596..51110403f5c44 100644 --- a/datafusion/core/src/physical_plan/sorts/cursor.rs +++ b/datafusion/core/src/physical_plan/sorts/cursor.rs @@ -109,8 +109,14 @@ impl PartialOrd for SortKeyCursor { impl Ord for SortKeyCursor { fn cmp(&self, other: &Self) -> Ordering { - self.current() - .cmp(&other.current()) - .then_with(|| self.stream_idx.cmp(&other.stream_idx)) + match (self.is_finished(), other.is_finished()) { + (true, true) => Ordering::Equal, + (_, true) => Ordering::Less, + (true, _) => Ordering::Greater, + _ => self + .current() + .cmp(&other.current()) + .then_with(|| self.stream_idx.cmp(&other.stream_idx)), + } } } diff --git a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs index 2559f6c58c289..212c4c955b327 100644 --- a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs +++ b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs @@ -18,8 +18,7 @@ //! Defines the sort preserving merge plan use std::any::Any; -use std::cmp::Reverse; -use std::collections::{BinaryHeap, VecDeque}; +use std::collections::VecDeque; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; @@ -304,10 +303,6 @@ pub(crate) struct SortPreservingMergeStream { /// their rows have been yielded to the output batches: Vec>, - /// Maintain a flag for each stream denoting if the current cursor - /// has finished and needs to poll from the stream - cursor_finished: Vec, - /// The accumulated row indexes for the next record batch in_progress: Vec, @@ -323,8 +318,17 @@ pub(crate) struct SortPreservingMergeStream { /// An id to uniquely identify the input stream batch next_batch_id: usize, - /// Heap that yields [`SortKeyCursor`] in increasing order - heap: BinaryHeap>, + /// Vector that holds all [`SortKeyCursor`]s + cursors: Vec>, + + /// The loser tree that always produces the minimum cursor + /// + /// Node 0 stores the top winner, Nodes 1..num_streams store + /// the loser nodes + loser_tree: Vec, + + /// Identify whether the loser tree is adjusted + loser_tree_adjusted: bool, /// target batch size batch_size: usize, @@ -361,14 +365,15 @@ impl SortPreservingMergeStream { Ok(Self { schema, batches, - cursor_finished: vec![true; stream_count], streams: MergingStreams::new(wrappers), column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(), tracking_metrics, aborted: false, in_progress: vec![], next_batch_id: 0, - heap: BinaryHeap::with_capacity(stream_count), + cursors: (0..stream_count).into_iter().map(|_| None).collect(), + loser_tree: Vec::with_capacity(stream_count), + loser_tree_adjusted: false, batch_size, row_converter, }) @@ -382,7 +387,11 @@ impl SortPreservingMergeStream { cx: &mut Context<'_>, idx: usize, ) -> Poll> { - if !self.cursor_finished[idx] { + if self.cursors[idx] + .as_ref() + .map(|cursor| !cursor.is_finished()) + .unwrap_or(false) + { // Cursor is not finished - don't need a new RecordBatch yet return Poll::Ready(Ok(())); } @@ -418,14 +427,12 @@ impl SortPreservingMergeStream { } }; - let cursor = SortKeyCursor::new( + self.cursors[idx] = Some(SortKeyCursor::new( idx, self.next_batch_id, // assign this batch an ID rows, - ); + )); self.next_batch_id += 1; - self.heap.push(Reverse(cursor)); - self.cursor_finished[idx] = false; self.batches[idx].push_back(batch) } else { empty_batch = true; @@ -551,17 +558,46 @@ impl SortPreservingMergeStream { if self.aborted { return Poll::Ready(None); } + let num_streams = self.streams.num_streams(); + + // Init all cursors and the loser tree in the first poll + if self.loser_tree.is_empty() { + // Ensure all non-exhausted streams have a cursor from which + // rows can be pulled + for i in 0..num_streams { + match futures::ready!(self.maybe_poll_stream(cx, i)) { + Ok(_) => {} + Err(e) => { + self.aborted = true; + return Poll::Ready(Some(Err(e))); + } + } + } - // Ensure all non-exhausted streams have a cursor from which - // rows can be pulled - for i in 0..self.streams.num_streams() { - match futures::ready!(self.maybe_poll_stream(cx, i)) { - Ok(_) => {} - Err(e) => { - self.aborted = true; - return Poll::Ready(Some(Err(e))); + // Init loser tree + self.loser_tree.resize(num_streams, usize::MAX); + for i in 0..num_streams { + let mut winner = i; + let mut cmp_node = (num_streams + i) / 2; + while cmp_node != 0 && self.loser_tree[cmp_node] != usize::MAX { + let challenger = self.loser_tree[cmp_node]; + let challenger_win = + match (&self.cursors[winner], &self.cursors[challenger]) { + (None, _) => true, + (_, None) => false, + (Some(winner), Some(challenger)) => challenger < winner, + }; + if challenger_win { + self.loser_tree[cmp_node] = winner; + winner = challenger; + } else { + self.loser_tree[cmp_node] = challenger; + } + cmp_node /= 2; } + self.loser_tree[cmp_node] = winner; } + self.loser_tree_adjusted = true; } // NB timer records time taken on drop, so there are no @@ -570,45 +606,57 @@ impl SortPreservingMergeStream { let _timer = elapsed_compute.timer(); loop { - match self.heap.pop() { - Some(Reverse(mut cursor)) => { - let stream_idx = cursor.stream_idx(); - let batch_idx = self.batches[stream_idx].len() - 1; - let row_idx = cursor.advance(); - - let mut cursor_finished = false; - // insert the cursor back to heap if the record batch is not exhausted - if !cursor.is_finished() { - self.heap.push(Reverse(cursor)); - } else { - cursor_finished = true; - self.cursor_finished[stream_idx] = true; + // Adjust the loser tree if necessary + if !self.loser_tree_adjusted { + let mut winner = self.loser_tree[0]; + match futures::ready!(self.maybe_poll_stream(cx, winner)) { + Ok(_) => {} + Err(e) => { + self.aborted = true; + return Poll::Ready(Some(Err(e))); } + } - self.in_progress.push(RowIndex { - stream_idx, - batch_idx, - row_idx, - }); - - if self.in_progress.len() == self.batch_size { - return Poll::Ready(Some(self.build_record_batch())); + let mut cmp_node = (num_streams + winner) / 2; + while cmp_node != 0 { + let challenger = self.loser_tree[cmp_node]; + let challenger_win = + match (&self.cursors[winner], &self.cursors[challenger]) { + (None, _) => true, + (_, None) => false, + (Some(winner), Some(challenger)) => challenger < winner, + }; + if challenger_win { + self.loser_tree[cmp_node] = winner; + winner = challenger; } + cmp_node /= 2; + } + self.loser_tree[0] = winner; + self.loser_tree_adjusted = true; + } - // If removed the last row from the cursor, need to fetch a new record - // batch if possible, before looping round again - if cursor_finished { - match futures::ready!(self.maybe_poll_stream(cx, stream_idx)) { - Ok(_) => {} - Err(e) => { - self.aborted = true; - return Poll::Ready(Some(Err(e))); - } - } - } + let min_cursor_idx = self.loser_tree[0]; + let next = self.cursors[min_cursor_idx] + .as_mut() + .filter(|cursor| !cursor.is_finished()) + .map(|cursor| (cursor.stream_idx(), cursor.advance())); + + if let Some((stream_idx, row_idx)) = next { + self.loser_tree_adjusted = false; + let batch_idx = self.batches[stream_idx].len() - 1; + self.in_progress.push(RowIndex { + stream_idx, + batch_idx, + row_idx, + }); + if self.in_progress.len() == self.batch_size { + return Poll::Ready(Some(self.build_record_batch())); } - None if self.in_progress.is_empty() => return Poll::Ready(None), - None => return Poll::Ready(Some(self.build_record_batch())), + } else if !self.in_progress.is_empty() { + return Poll::Ready(Some(self.build_record_batch())); + } else { + return Poll::Ready(None); } } }