Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 5 additions & 24 deletions datafusion/core/src/physical_plan/aggregates/order/full.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
// specific language governing permissions and limitations
// under the License.

use datafusion_execution::memory_pool::proxy::VecAllocExt;

use crate::physical_expr::EmitTo;

/// Tracks grouping state when the data is ordered entirely by its
Expand Down Expand Up @@ -58,8 +56,6 @@ use crate::physical_expr::EmitTo;
#[derive(Debug)]
pub(crate) struct GroupOrderingFull {
state: State,
/// Hash values for groups in 0..current
hashes: Vec<u64>,
}

#[derive(Debug)]
Expand All @@ -79,7 +75,6 @@ impl GroupOrderingFull {
pub fn new() -> Self {
Self {
state: State::Start,
hashes: vec![],
}
}

Expand All @@ -101,19 +96,17 @@ impl GroupOrderingFull {
}

/// remove the first n groups from the internal state, shifting
/// all existing indexes down by `n`. Returns stored hash values
pub fn remove_groups(&mut self, n: usize) -> &[u64] {
/// all existing indexes down by `n`
pub fn remove_groups(&mut self, n: usize) {
match &mut self.state {
State::Start => panic!("invalid state: start"),
State::InProgress { current } => {
// shift down by n
assert!(*current >= n);
*current -= n;
self.hashes.drain(0..n);
}
State::Complete { .. } => panic!("invalid state: complete"),
};
&self.hashes
}
}

/// Note that the input is complete so any outstanding groups are done as well
Expand All @@ -123,20 +116,8 @@ impl GroupOrderingFull {

/// Called when new groups are added in a batch. See documentation
/// on [`super::GroupOrdering::new_groups`]
pub fn new_groups(
&mut self,
group_indices: &[usize],
batch_hashes: &[u64],
total_num_groups: usize,
) {
pub fn new_groups(&mut self, total_num_groups: usize) {
assert_ne!(total_num_groups, 0);
assert_eq!(group_indices.len(), batch_hashes.len());

// copy any hash values
self.hashes.resize(total_num_groups, 0);
for (&group_index, &hash) in group_indices.iter().zip(batch_hashes.iter()) {
self.hashes[group_index] = hash;
}

// Update state
let max_group_index = total_num_groups - 1;
Expand All @@ -158,6 +139,6 @@ impl GroupOrderingFull {
}

pub(crate) fn size(&self) -> usize {
std::mem::size_of::<Self>() + self.hashes.allocated_size()
std::mem::size_of::<Self>()
}
}
9 changes: 3 additions & 6 deletions datafusion/core/src/physical_plan/aggregates/order/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ impl GroupOrdering {

/// remove the first n groups from the internal state, shifting
/// all existing indexes down by `n`
pub fn remove_groups(&mut self, n: usize) -> &[u64] {
pub fn remove_groups(&mut self, n: usize) {
match self {
GroupOrdering::None => &[],
GroupOrdering::None => {}
GroupOrdering::Partial(partial) => partial.remove_groups(n),
GroupOrdering::Full(full) => full.remove_groups(n),
}
Expand All @@ -106,7 +106,6 @@ impl GroupOrdering {
&mut self,
batch_group_values: &[ArrayRef],
group_indices: &[usize],
batch_hashes: &[u64],
total_num_groups: usize,
) -> Result<()> {
match self {
Expand All @@ -115,13 +114,11 @@ impl GroupOrdering {
partial.new_groups(
batch_group_values,
group_indices,
batch_hashes,
total_num_groups,
)?;
}

GroupOrdering::Full(full) => {
full.new_groups(group_indices, batch_hashes, total_num_groups);
full.new_groups(total_num_groups);
}
};
Ok(())
Expand Down
30 changes: 6 additions & 24 deletions datafusion/core/src/physical_plan/aggregates/order/partial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,6 @@ pub(crate) struct GroupOrderingPartial {
/// Converter for the sort key (used on the group columns
/// specified in `order_indexes`)
row_converter: RowConverter,

/// Hash values for groups in 0..completed
hashes: Vec<u64>,
}

#[derive(Debug, Default)]
Expand Down Expand Up @@ -127,7 +124,6 @@ impl GroupOrderingPartial {
state: State::Start,
order_indices: order_indices.to_vec(),
row_converter: RowConverter::new(fields)?,
hashes: vec![],
})
}

Expand Down Expand Up @@ -167,8 +163,8 @@ impl GroupOrderingPartial {
}

/// remove the first n groups from the internal state, shifting
/// all existing indexes down by `n`. Returns stored hash values
pub fn remove_groups(&mut self, n: usize) -> &[u64] {
/// all existing indexes down by `n`
pub fn remove_groups(&mut self, n: usize) {
match &mut self.state {
State::Taken => unreachable!("State previously taken"),
State::Start => panic!("invalid state: start"),
Expand All @@ -182,12 +178,9 @@ impl GroupOrderingPartial {
*current -= n;
assert!(*current_sort >= n);
*current_sort -= n;
// Note sort_key stays the same, we are just translating group indexes
self.hashes.drain(0..n);
}
State::Complete { .. } => panic!("invalid state: complete"),
};
&self.hashes
}
}

/// Note that the input is complete so any outstanding groups are done as well
Expand All @@ -204,18 +197,15 @@ impl GroupOrderingPartial {
&mut self,
batch_group_values: &[ArrayRef],
group_indices: &[usize],
batch_hashes: &[u64],
total_num_groups: usize,
) -> Result<()> {
assert!(total_num_groups > 0);
assert!(!batch_group_values.is_empty());
assert_eq!(group_indices.len(), batch_hashes.len());

let max_group_index = total_num_groups - 1;

// compute the sort key values for each group
let sort_keys = self.compute_sort_keys(batch_group_values)?;
assert_eq!(sort_keys.num_rows(), batch_hashes.len());

let old_state = std::mem::take(&mut self.state);
let (mut current_sort, mut sort_key) = match &old_state {
Expand All @@ -231,16 +221,9 @@ impl GroupOrderingPartial {
}
};

// copy any hash values, and find latest sort key
self.hashes.resize(total_num_groups, 0);
let iter = group_indices
.iter()
.zip(batch_hashes.iter())
.zip(sort_keys.iter());

for ((&group_index, &hash), group_sort_key) in iter {
self.hashes[group_index] = hash;

// Find latest sort key
let iter = group_indices.iter().zip(sort_keys.iter());
for (&group_index, group_sort_key) in iter {
// Has this group seen a new sort_key?
if sort_key != group_sort_key {
current_sort = group_index;
Expand All @@ -262,6 +245,5 @@ impl GroupOrderingPartial {
std::mem::size_of::<Self>()
+ self.order_indices.allocated_size()
+ self.row_converter.size()
+ self.hashes.allocated_size()
}
}
22 changes: 12 additions & 10 deletions datafusion/core/src/physical_plan/aggregates/row_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,6 @@ impl GroupedHashAggregateStream {
self.group_ordering.new_groups(
group_values,
group_indices,
batch_hashes,
total_num_groups,
)?;
}
Expand Down Expand Up @@ -624,15 +623,18 @@ impl GroupedHashAggregateStream {
}
std::mem::swap(&mut new_group_values, &mut self.group_values);

// rebuild hash table (maybe we should remove the
// entries for each group that was emitted rather than
// rebuilding the whole thing

let hashes = self.group_ordering.remove_groups(n);
assert_eq!(hashes.len(), self.group_values.num_rows());
self.map.clear();
for (idx, &hash) in hashes.iter().enumerate() {
self.map.insert(hash, (hash, idx), |(hash, _)| *hash);
self.group_ordering.remove_groups(n);
// SAFETY: self.map outlives iterator and is not modified concurrently
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unsafe {
for bucket in self.map.iter() {
// Decrement group index by n
match bucket.as_ref().1.checked_sub(n) {
// Group index was >= n, shift value down
Some(sub) => bucket.as_mut().1 = sub,
// Group index was < n, so remove from table
None => self.map.erase(bucket),
}
}
}
}
};
Expand Down