Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions datafusion/core/src/physical_plan/sorts/cursor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,14 @@ impl PartialOrd for SortKeyCursor {

impl Ord for SortKeyCursor {
fn cmp(&self, other: &Self) -> Ordering {
self.current()
.cmp(&other.current())
.then_with(|| self.stream_idx.cmp(&other.stream_idx))
match (self.is_finished(), other.is_finished()) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
match (self.is_finished(), other.is_finished()) {
// Order finished cursors last
match (self.is_finished(), other.is_finished()) {

(true, true) => Ordering::Equal,
(_, true) => Ordering::Less,
(true, _) => Ordering::Greater,
_ => self
.current()
.cmp(&other.current())
.then_with(|| self.stream_idx.cmp(&other.stream_idx)),
}
}
}
162 changes: 105 additions & 57 deletions datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
//! Defines the sort preserving merge plan

use std::any::Any;
use std::cmp::Reverse;
use std::collections::{BinaryHeap, VecDeque};
use std::collections::VecDeque;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
Expand Down Expand Up @@ -304,10 +303,6 @@ pub(crate) struct SortPreservingMergeStream {
/// their rows have been yielded to the output
batches: Vec<VecDeque<RecordBatch>>,

/// Maintain a flag for each stream denoting if the current cursor
/// has finished and needs to poll from the stream
cursor_finished: Vec<bool>,

/// The accumulated row indexes for the next record batch
in_progress: Vec<RowIndex>,

Expand All @@ -323,8 +318,17 @@ pub(crate) struct SortPreservingMergeStream {
/// An id to uniquely identify the input stream batch
next_batch_id: usize,

/// Heap that yields [`SortKeyCursor`] in increasing order
heap: BinaryHeap<Reverse<SortKeyCursor>>,
/// Vector that holds all [`SortKeyCursor`]s
cursors: Vec<Option<SortKeyCursor>>,

/// The loser tree that always produces the minimum cursor
///
/// Node 0 stores the top winner, Nodes 1..num_streams store
/// the loser nodes
loser_tree: Vec<usize>,

/// Identify whether the loser tree is adjusted
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// Identify whether the loser tree is adjusted
/// Identify whether the most recently yielded overall winner has been replaced
/// within the loser tree. A value of `false` indicates that the overall winner
/// has been yielded but the loser tree has not yet been updated

Or something to make it clearer what adjusted actually means.

FWIW, a boolean named `should_replace_winner` (or something similar) might be clearer

loser_tree_adjusted: bool,

/// target batch size
batch_size: usize,
Expand Down Expand Up @@ -361,14 +365,15 @@ impl SortPreservingMergeStream {
Ok(Self {
schema,
batches,
cursor_finished: vec![true; stream_count],
streams: MergingStreams::new(wrappers),
column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(),
tracking_metrics,
aborted: false,
in_progress: vec![],
next_batch_id: 0,
heap: BinaryHeap::with_capacity(stream_count),
cursors: (0..stream_count).into_iter().map(|_| None).collect(),
loser_tree: Vec::with_capacity(stream_count),
loser_tree_adjusted: false,
batch_size,
row_converter,
})
Expand All @@ -382,7 +387,11 @@ impl SortPreservingMergeStream {
cx: &mut Context<'_>,
idx: usize,
) -> Poll<ArrowResult<()>> {
if !self.cursor_finished[idx] {
if self.cursors[idx]
.as_ref()
.map(|cursor| !cursor.is_finished())
.unwrap_or(false)
{
// Cursor is not finished - don't need a new RecordBatch yet
return Poll::Ready(Ok(()));
}
Expand Down Expand Up @@ -418,14 +427,12 @@ impl SortPreservingMergeStream {
}
};

let cursor = SortKeyCursor::new(
self.cursors[idx] = Some(SortKeyCursor::new(
idx,
self.next_batch_id, // assign this batch an ID
rows,
);
));
self.next_batch_id += 1;
self.heap.push(Reverse(cursor));
self.cursor_finished[idx] = false;
self.batches[idx].push_back(batch)
} else {
empty_batch = true;
Expand Down Expand Up @@ -551,17 +558,46 @@ impl SortPreservingMergeStream {
if self.aborted {
return Poll::Ready(None);
}
let num_streams = self.streams.num_streams();

// Init all cursors and the loser tree in the first poll
if self.loser_tree.is_empty() {
// Ensure all non-exhausted streams have a cursor from which
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be easier to follow if this logic were split out into a method called `init_loser_tree`, with a doc comment explaining what it does

// rows can be pulled
for i in 0..num_streams {
match futures::ready!(self.maybe_poll_stream(cx, i)) {
Ok(_) => {}
Err(e) => {
self.aborted = true;
return Poll::Ready(Some(Err(e)));
}
}
}

// Ensure all non-exhausted streams have a cursor from which
// rows can be pulled
for i in 0..self.streams.num_streams() {
match futures::ready!(self.maybe_poll_stream(cx, i)) {
Ok(_) => {}
Err(e) => {
self.aborted = true;
return Poll::Ready(Some(Err(e)));
// Init loser tree
self.loser_tree.resize(num_streams, usize::MAX);
for i in 0..num_streams {
let mut winner = i;
let mut cmp_node = (num_streams + i) / 2;
while cmp_node != 0 && self.loser_tree[cmp_node] != usize::MAX {
let challenger = self.loser_tree[cmp_node];
let challenger_win =
match (&self.cursors[winner], &self.cursors[challenger]) {
(None, _) => true,
(_, None) => false,
(Some(winner), Some(challenger)) => challenger < winner,
};
if challenger_win {
self.loser_tree[cmp_node] = winner;
winner = challenger;
} else {
self.loser_tree[cmp_node] = challenger;
}
Comment on lines +590 to +595
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if challenger_win {
self.loser_tree[cmp_node] = winner;
winner = challenger;
} else {
self.loser_tree[cmp_node] = challenger;
}
if challenger_win {
self.loser_tree[cmp_node] = winner;
winner = challenger;
}

cmp_node /= 2;
}
self.loser_tree[cmp_node] = winner;
}
self.loser_tree_adjusted = true;
}

// NB timer records time taken on drop, so there are no
Expand All @@ -570,45 +606,57 @@ impl SortPreservingMergeStream {
let _timer = elapsed_compute.timer();

loop {
match self.heap.pop() {
Some(Reverse(mut cursor)) => {
let stream_idx = cursor.stream_idx();
let batch_idx = self.batches[stream_idx].len() - 1;
let row_idx = cursor.advance();

let mut cursor_finished = false;
// insert the cursor back to heap if the record batch is not exhausted
if !cursor.is_finished() {
self.heap.push(Reverse(cursor));
} else {
cursor_finished = true;
self.cursor_finished[stream_idx] = true;
// Adjust the loser tree if necessary
if !self.loser_tree_adjusted {
let mut winner = self.loser_tree[0];
Copy link
Copy Markdown
Contributor

@tustvold tustvold Nov 21, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be easier to follow if this were moved into a method called `replace_loser_tree_winner`, perhaps with a link to this GIF - https://en.wikipedia.org/wiki/K-way_merge_algorithm#/media/File:Loser_tree_replacement_selection.gif

match futures::ready!(self.maybe_poll_stream(cx, winner)) {
Ok(_) => {}
Err(e) => {
self.aborted = true;
return Poll::Ready(Some(Err(e)));
}
}

self.in_progress.push(RowIndex {
stream_idx,
batch_idx,
row_idx,
});

if self.in_progress.len() == self.batch_size {
return Poll::Ready(Some(self.build_record_batch()));
let mut cmp_node = (num_streams + winner) / 2;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
let mut cmp_node = (num_streams + winner) / 2;
// Replace overall winner by walking tree of losers
let mut cmp_node = (num_streams + winner) / 2;

while cmp_node != 0 {
let challenger = self.loser_tree[cmp_node];
let challenger_win =
match (&self.cursors[winner], &self.cursors[challenger]) {
(None, _) => true,
(_, None) => false,
(Some(winner), Some(challenger)) => challenger < winner,
};
if challenger_win {
self.loser_tree[cmp_node] = winner;
winner = challenger;
}
cmp_node /= 2;
}
self.loser_tree[0] = winner;
self.loser_tree_adjusted = true;
}

// If removed the last row from the cursor, need to fetch a new record
// batch if possible, before looping round again
if cursor_finished {
match futures::ready!(self.maybe_poll_stream(cx, stream_idx)) {
Ok(_) => {}
Err(e) => {
self.aborted = true;
return Poll::Ready(Some(Err(e)));
}
}
}
let min_cursor_idx = self.loser_tree[0];
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this could be made easier to follow if it were written along the lines of

let min_cursor = self.cursors[min_cursor_idx];
if min_cursor.is_finished() {
    // All streams are exhausted
    return Poll::Ready((!self.in_progress.is_empty()).then(|| self.build_record_batch()))
}

self.loser_tree_adjusted = false;
self.in_progress.push(...)
if self.in_progress.len() == self.batch_size {
    return Poll::Ready(Some(self.build_record_batch()));
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I couldn't make this work in #4407

let next = self.cursors[min_cursor_idx]
.as_mut()
.filter(|cursor| !cursor.is_finished())
.map(|cursor| (cursor.stream_idx(), cursor.advance()));

if let Some((stream_idx, row_idx)) = next {
self.loser_tree_adjusted = false;
let batch_idx = self.batches[stream_idx].len() - 1;
self.in_progress.push(RowIndex {
stream_idx,
batch_idx,
row_idx,
});
if self.in_progress.len() == self.batch_size {
return Poll::Ready(Some(self.build_record_batch()));
}
None if self.in_progress.is_empty() => return Poll::Ready(None),
None => return Poll::Ready(Some(self.build_record_batch())),
} else if !self.in_progress.is_empty() {
return Poll::Ready(Some(self.build_record_batch()));
} else {
return Poll::Ready(None);
}
}
}
Expand Down