From 6269fb7cded6d014b8037017e3e97a78d3a50947 Mon Sep 17 00:00:00 2001 From: kamille Date: Sun, 11 Aug 2024 05:55:17 +0800 Subject: [PATCH 1/9] draft. --- .../src/aggregates/group_values/mod.rs | 31 +++- .../src/aggregates/group_values/row.rs | 157 ++++++++++++------ .../physical-plan/src/aggregates/row_hash.rs | 2 +- 3 files changed, 135 insertions(+), 55 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index be7ac934d7bc..2d0468bbf533 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -33,6 +33,31 @@ mod bytes_view; use bytes::GroupValuesByes; use datafusion_physical_expr::binary_map::OutputType; +const GROUP_IDX_HIGH_16_BITS_MASK: u64 = 0xffff000000000000; +const GROUP_IDX_LOW_48_BITS_MASK: u64 = 0x0000ffffffffffff; + +#[derive(Debug, Clone, Copy)] +pub struct GroupIdx(u64); + +impl GroupIdx { + pub fn new(block_id: u16, block_offset: u64) -> Self { + let group_idx_high_part = ((block_id as u64) << 48) & GROUP_IDX_HIGH_16_BITS_MASK; + let group_idx_low_part = block_offset & GROUP_IDX_LOW_48_BITS_MASK; + + Self(group_idx_high_part | group_idx_low_part) + } + + #[inline] + pub fn block_id(&self) -> usize { + ((self.0 & GROUP_IDX_HIGH_16_BITS_MASK) >> 48) as usize + } + + #[inline] + pub fn block_offset(&self) -> usize { + (self.0 & GROUP_IDX_LOW_48_BITS_MASK) as usize + } +} + /// An interning store for group keys pub trait GroupValues: Send { /// Calculates the `groups` for each input row of `cols` @@ -48,13 +73,13 @@ pub trait GroupValues: Send { fn len(&self) -> usize; /// Emits the group values - fn emit(&mut self, emit_to: EmitTo) -> Result>; + fn emit(&mut self, emit_to: EmitTo) -> Result>>; /// Clear the contents and shrink the capacity to the size of the batch (free up memory usage) fn clear_shrink(&mut self, batch: &RecordBatch); } -pub fn new_group_values(schema: SchemaRef) -> Result> { +pub fn new_group_values(schema: SchemaRef, batch_size: usize) -> Result> { if schema.fields.len() == 1 { let d = schema.fields[0].data_type(); @@ -92,5 +117,5 @@ pub fn new_group_values(schema: SchemaRef) -> Result> { } } - Ok(Box::new(GroupValuesRows::try_new(schema)?)) + Ok(Box::new(GroupValuesRows::try_new(schema, batch_size)?)) } diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index dc948e28bb2d..e3decd4d66be 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -15,7 +15,10 @@ // specific language governing permissions and limitations // under the License. -use crate::aggregates::group_values::GroupValues; +use std::collections::VecDeque; +use std::mem; + +use crate::aggregates::group_values::{GroupBlock, GroupIdx, GroupValues}; use ahash::RandomState; use arrow::compute::cast; use arrow::record_batch::RecordBatch; @@ -44,7 +47,7 @@ pub struct GroupValuesRows { /// /// keys: u64 hashes of the GroupValue /// values: (hash, group_index) - map: RawTable<(u64, usize)>, + map: RawTable<(u64, GroupIdx)>, /// The size of `map` in bytes map_size: usize, @@ -57,7 +60,7 @@ pub struct GroupValuesRows { /// important for multi-column group keys. 
/// /// [`Row`]: arrow::row::Row - group_values: Option, + group_values_blocks: VecDeque, /// reused buffer to store hashes hashes_buffer: Vec, @@ -67,10 +70,14 @@ pub struct GroupValuesRows { /// Random state for creating hashes random_state: RandomState, + + max_block_size: usize, + + cur_block_id: u16, } impl GroupValuesRows { - pub fn try_new(schema: SchemaRef) -> Result { + pub fn try_new(schema: SchemaRef, page_size: usize) -> Result { let row_converter = RowConverter::new( schema .fields() @@ -90,27 +97,32 @@ impl GroupValuesRows { row_converter, map, map_size: 0, - group_values: None, + group_values_blocks: VecDeque::new(), hashes_buffer: Default::default(), rows_buffer, random_state: Default::default(), + max_block_size: page_size, + cur_block_id: 0, }) } } impl GroupValues for GroupValuesRows { - fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { + fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { // Convert the group keys into the row format let group_rows = &mut self.rows_buffer; group_rows.clear(); self.row_converter.append(group_rows, cols)?; let n_rows = group_rows.num_rows(); - let mut group_values = match self.group_values.take() { - Some(group_values) => group_values, - None => self.row_converter.empty_rows(0, 0), + if self.group_values_blocks.is_empty() { + // TODO: calc and use the capacity to init. + let block = self.row_converter.empty_rows(0, 0); + self.group_values_blocks.push_back(block); }; + let mut group_values_blocks = mem::take(&mut self.group_values_blocks); + // tracks to which group each of the input rows belongs groups.clear(); @@ -126,11 +138,17 @@ impl GroupValues for GroupValuesRows { // hash doesn't match, so check the hash first with an integer // comparison first avoid the more expensive comparison with // group value. https://github.com/apache/datafusion/pull/11718 - target_hash == *exist_hash - // verify that the group that we are inserting with hash is - // actually the same key value as the group in - // existing_idx (aka group_values @ row) - && group_rows.row(row) == group_values.row(*group_idx) + if target_hash != *exist_hash { + return false; + } + + // verify that the group that we are inserting with hash is + // actually the same key value as the group in + // existing_idx (aka group_values @ row) + let block_id = group_idx.block_id(); + let block_offset = group_idx.block_offset(); + let group_value = group_values_blocks[block_id].row(block_offset); + group_rows.row(row) == group_value }); let group_idx = match entry { @@ -138,9 +156,20 @@ impl GroupValues for GroupValuesRows { Some((_hash, group_idx)) => *group_idx, // 1.2 Need to create new entry for the group None => { + // Check if the block size has reached the limit, if so we switch to next block. + let block_size = group_values_blocks.back().unwrap().num_rows(); + if block_size == self.max_block_size { + self.cur_block_id += 1; + // TODO: calc and use the capacity to init. 
+ let block = self.row_converter.empty_rows(0, 0); + self.group_values_blocks.push_back(block); + } + // Add new entry to aggr_state and save newly created index - let group_idx = group_values.num_rows(); - group_values.push(group_rows.row(row)); + let cur_group_values = self.group_values_blocks.back_mut().unwrap(); + let block_offset = cur_group_values.num_rows(); + let group_idx = GroupIdx::new(self.cur_block_id, block_offset as u64); + cur_group_values.push(group_rows.row(row)); // for hasher function, use precomputed hash value self.map.insert_accounted( @@ -154,13 +183,13 @@ impl GroupValues for GroupValuesRows { groups.push(group_idx); } - self.group_values = Some(group_values); + self.group_values_blocks = group_values_blocks; Ok(()) } fn size(&self) -> usize { - let group_values_size = self.group_values.as_ref().map(|v| v.size()).unwrap_or(0); + let group_values_size = self.group_values_blocks.as_ref().map(|v| v.size()).unwrap_or(0); self.row_converter.size() + group_values_size + self.map_size @@ -173,72 +202,98 @@ impl GroupValues for GroupValuesRows { } fn len(&self) -> usize { - self.group_values + self.group_values_blocks .as_ref() .map(|group_values| group_values.num_rows()) .unwrap_or(0) } - fn emit(&mut self, emit_to: EmitTo) -> Result> { - let mut group_values = self - .group_values - .take() - .expect("Can not emit from empty rows"); + fn emit(&mut self, emit_to: EmitTo) -> Result>> { + let mut group_values_blocks = mem::take(&mut self + .group_values_blocks); + + if group_values_blocks.is_empty() { + return Ok(Vec::new()); + } let mut output = match emit_to { EmitTo::All => { - let output = self.row_converter.convert_rows(&group_values)?; - group_values.clear(); - output + group_values_blocks.iter_mut().map(|rows_block| { + let output = self.row_converter.convert_rows(rows_block.iter())?; + rows_block.clear(); + Ok(output) + }).collect::>>()? } EmitTo::First(n) => { - let groups_rows = group_values.iter().take(n); - let output = self.row_converter.convert_rows(groups_rows)?; - // Clear out first n group keys by copying them to a new Rows. - // TODO file some ticket in arrow-rs to make this more efficient? - let mut new_group_values = self.row_converter.empty_rows(0, 0); - for row in group_values.iter().skip(n) { - new_group_values.push(row); + // convert it to block + let num_emitted_blocks = if n > self.max_block_size { + n / self.max_block_size + } else { + 1 + }; + + let mut emitted_blocks = Vec::with_capacity(num_emitted_blocks); + for _ in 0..num_emitted_blocks { + let block = group_values_blocks.pop_front().unwrap(); + let converted_block = self.row_converter.convert_rows(block.into_iter())?; + emitted_blocks.push(converted_block); } - std::mem::swap(&mut new_group_values, &mut group_values); + + // let groups_rows = group_values.iter().take(n); + // let output = self.row_converter.convert_rows(groups_rows)?; + // // Clear out first n group keys by copying them to a new Rows. + // // TODO file some ticket in arrow-rs to make this more efficient? 
+ // let mut new_group_values = self.row_converter.empty_rows(0, 0); + // for row in group_values.iter().skip(n) { + // new_group_values.push(row); + // } + // std::mem::swap(&mut new_group_values, &mut group_values); // SAFETY: self.map outlives iterator and is not modified concurrently unsafe { for bucket in self.map.iter() { - // Decrement group index by n - match bucket.as_ref().1.checked_sub(n) { + // Decrement block id by `num_emitted_blocks` + let (_, group_idx, ) = bucket.as_ref(); + let new_block_id = group_idx.block_id().checked_sub(num_emitted_blocks); + match new_block_id { // Group index was >= n, shift value down - Some(sub) => bucket.as_mut().1 = sub, + Some(bid) => { + let block_offset = group_idx.block_offset(); + bucket.as_mut().1 = GroupIdx::new(bid as u16, block_offset as u64); + }, // Group index was < n, so remove from table - None => self.map.erase(bucket), + None => self.map.erase(bucket), } } } - output + emitted_blocks } }; // TODO: Materialize dictionaries in group keys (#7647) - for (field, array) in self.schema.fields.iter().zip(&mut output) { - let expected = field.data_type(); - if let DataType::Dictionary(_, v) = expected { - let actual = array.data_type(); - if v.as_ref() != actual { - return Err(DataFusionError::Internal(format!( - "Converted group rows expected dictionary of {v} got {actual}" - ))); + for one_output in output.iter_mut() { + for (field, array) in self.schema.fields.iter().zip(one_output) { + let expected = field.data_type(); + if let DataType::Dictionary(_, v) = expected { + let actual = array.data_type(); + if v.as_ref() != actual { + return Err(DataFusionError::Internal(format!( + "Converted group rows expected dictionary of {v} got {actual}" + ))); + } + *array = cast(array.as_ref(), expected)?; } - *array = cast(array.as_ref(), expected)?; } } - self.group_values = Some(group_values); + self.group_values_blocks = group_values_blocks; + Ok(output) } fn clear_shrink(&mut self, batch: &RecordBatch) { let count = batch.num_rows(); - self.group_values = self.group_values.take().map(|mut rows| { + self.group_values_blocks = self.group_values_blocks.take().map(|mut rows| { rows.clear(); rows }); diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index ed3d6d49f9f3..00263a9e5509 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -463,7 +463,7 @@ impl GroupedHashAggregateStream { ordering.as_slice(), )?; - let group_values = new_group_values(group_schema)?; + let group_values = new_group_values(group_schema, batch_size)?; timer.done(); let exec_state = ExecutionState::ReadingInput; From 7f5ccd265f9c4110df0eb7600a94882f03c1176f Mon Sep 17 00:00:00 2001 From: kamille Date: Sun, 11 Aug 2024 13:55:27 +0800 Subject: [PATCH 2/9] impl the new interface functions for all `GroupValues` impls. 
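The existing single-block implementations (`GroupValuesByes`, `GroupValuesBytesView`,
`GroupValuesPrimitive`) keep producing flat indices internally and simply wrap them into
block 0 with `GroupIdx::new(0, group_idx as u64)`. For reference, here is a minimal
standalone sketch of the packing scheme introduced in patch 1 (block id in the high 16
bits, block offset in the low 48 bits). It only mirrors the masks in
`group_values/mod.rs`; the `pack`/`unpack` names are made up for the illustration:

    // Round trip of the GroupIdx packing (constants copied from
    // group_values/mod.rs; not the patched file itself).
    const HIGH_16: u64 = 0xffff_0000_0000_0000; // block id
    const LOW_48: u64 = 0x0000_ffff_ffff_ffff; // block offset

    fn pack(block_id: u16, block_offset: u64) -> u64 {
        (((block_id as u64) << 48) & HIGH_16) | (block_offset & LOW_48)
    }

    fn unpack(packed: u64) -> (usize, usize) {
        (((packed & HIGH_16) >> 48) as usize, (packed & LOW_48) as usize)
    }

    fn main() {
        // Single-block stores always use block 0, so the packed value equals
        // the original flat group index.
        assert_eq!(pack(0, 42), 42);
        // A blocked store can address 2^16 blocks of up to 2^48 rows each.
        assert_eq!(unpack(pack(3, 7)), (3, 7));
    }

The useful property is that block 0 packing is the identity, so the single-block
implementations keep their behaviour and only change the type they hand back.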
--- .../src/aggregates/group_values/bytes.rs | 12 ++++++------ .../src/aggregates/group_values/bytes_view.rs | 12 ++++++------ .../src/aggregates/group_values/mod.rs | 2 +- .../src/aggregates/group_values/primitive.rs | 10 +++++----- .../src/aggregates/group_values/row.rs | 14 ++++++-------- 5 files changed, 24 insertions(+), 26 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/group_values/bytes.rs b/datafusion/physical-plan/src/aggregates/group_values/bytes.rs index f789af8b8a02..427e002a4455 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/bytes.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/bytes.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::aggregates::group_values::GroupValues; +use crate::aggregates::group_values::{GroupIdx, GroupValues}; use arrow_array::{Array, ArrayRef, OffsetSizeTrait, RecordBatch}; use datafusion_expr::EmitTo; use datafusion_physical_expr_common::binary_map::{ArrowBytesMap, OutputType}; @@ -44,7 +44,7 @@ impl GroupValues for GroupValuesByes { fn intern( &mut self, cols: &[ArrayRef], - groups: &mut Vec, + groups: &mut Vec, ) -> datafusion_common::Result<()> { assert_eq!(cols.len(), 1); @@ -63,7 +63,7 @@ impl GroupValues for GroupValuesByes { }, // called for each group |group_idx| { - groups.push(group_idx); + groups.push(GroupIdx::new(0, group_idx as u64)); }, ); @@ -84,7 +84,7 @@ impl GroupValues for GroupValuesByes { self.num_groups } - fn emit(&mut self, emit_to: EmitTo) -> datafusion_common::Result> { + fn emit(&mut self, emit_to: EmitTo) -> datafusion_common::Result>> { // Reset the map to default, and convert it into a single array let map_contents = self.map.take().into_state(); @@ -111,13 +111,13 @@ impl GroupValues for GroupValuesByes { self.intern(&[remaining_group_values], &mut group_indexes)?; // Verify that the group indexes were assigned in the correct order - assert_eq!(0, group_indexes[0]); + assert_eq!(0, group_indexes[0].block_offset()); emit_group_values } }; - Ok(vec![group_values]) + Ok(vec![vec![group_values]]) } fn clear_shrink(&mut self, _batch: &RecordBatch) { diff --git a/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs b/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs index 1a0cb90a16d4..beca77eca9cf 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use crate::aggregates::group_values::GroupValues; +use crate::aggregates::group_values::{GroupIdx, GroupValues}; use arrow_array::{Array, ArrayRef, RecordBatch}; use datafusion_expr::EmitTo; use datafusion_physical_expr::binary_map::OutputType; @@ -45,7 +45,7 @@ impl GroupValues for GroupValuesBytesView { fn intern( &mut self, cols: &[ArrayRef], - groups: &mut Vec, + groups: &mut Vec, ) -> datafusion_common::Result<()> { assert_eq!(cols.len(), 1); @@ -64,7 +64,7 @@ impl GroupValues for GroupValuesBytesView { }, // called for each group |group_idx| { - groups.push(group_idx); + groups.push(GroupIdx::new(0, group_idx as u64)); }, ); @@ -85,7 +85,7 @@ impl GroupValues for GroupValuesBytesView { self.num_groups } - fn emit(&mut self, emit_to: EmitTo) -> datafusion_common::Result> { + fn emit(&mut self, emit_to: EmitTo) -> datafusion_common::Result>> { // Reset the map to default, and convert it into a single array let map_contents = self.map.take().into_state(); @@ -112,13 +112,13 @@ impl GroupValues for GroupValuesBytesView { self.intern(&[remaining_group_values], &mut group_indexes)?; // Verify that the group indexes were assigned in the correct order - assert_eq!(0, group_indexes[0]); + assert_eq!(0, group_indexes[0].block_offset()); emit_group_values } }; - Ok(vec![group_values]) + Ok(vec![vec![group_values]]) } fn clear_shrink(&mut self, _batch: &RecordBatch) { diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index 2d0468bbf533..727647e9cb65 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -61,7 +61,7 @@ impl GroupIdx { /// An interning store for group keys pub trait GroupValues: Send { /// Calculates the `groups` for each input row of `cols` - fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()>; + fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()>; /// Returns the number of bytes used by this [`GroupValues`] fn size(&self) -> usize; diff --git a/datafusion/physical-plan/src/aggregates/group_values/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/primitive.rs index d5b7f1b11ac5..1ccaa9f5b7c1 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/primitive.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use crate::aggregates::group_values::GroupValues; +use crate::aggregates::group_values::{GroupIdx, GroupValues}; use ahash::RandomState; use arrow::array::BooleanBufferBuilder; use arrow::buffer::NullBuffer; @@ -111,7 +111,7 @@ impl GroupValues for GroupValuesPrimitive where T::Native: HashValue, { - fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { + fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { assert_eq!(cols.len(), 1); groups.clear(); @@ -145,7 +145,7 @@ where } } }; - groups.push(group_id) + groups.push(GroupIdx::new(0, group_id as u64)) } Ok(()) } @@ -162,7 +162,7 @@ where self.values.len() } - fn emit(&mut self, emit_to: EmitTo) -> Result> { + fn emit(&mut self, emit_to: EmitTo) -> Result>> { fn build_primitive( values: Vec, null_idx: Option, @@ -207,7 +207,7 @@ where build_primitive(split, null_group) } }; - Ok(vec![Arc::new(array.with_data_type(self.data_type.clone()))]) + Ok(vec![vec![Arc::new(array.with_data_type(self.data_type.clone()))]]) } fn clear_shrink(&mut self, batch: &RecordBatch) { diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index e3decd4d66be..b1a38d653d1c 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -18,7 +18,7 @@ use std::collections::VecDeque; use std::mem; -use crate::aggregates::group_values::{GroupBlock, GroupIdx, GroupValues}; +use crate::aggregates::group_values::{GroupIdx, GroupValues}; use ahash::RandomState; use arrow::compute::cast; use arrow::record_batch::RecordBatch; @@ -189,7 +189,7 @@ impl GroupValues for GroupValuesRows { } fn size(&self) -> usize { - let group_values_size = self.group_values_blocks.as_ref().map(|v| v.size()).unwrap_or(0); + let group_values_size = self.group_values_blocks.iter().map(|v| v.size()).sum::(); self.row_converter.size() + group_values_size + self.map_size @@ -203,9 +203,9 @@ impl GroupValues for GroupValuesRows { fn len(&self) -> usize { self.group_values_blocks - .as_ref() + .iter() .map(|group_values| group_values.num_rows()) - .unwrap_or(0) + .sum::() } fn emit(&mut self, emit_to: EmitTo) -> Result>> { @@ -293,10 +293,8 @@ impl GroupValues for GroupValuesRows { fn clear_shrink(&mut self, batch: &RecordBatch) { let count = batch.num_rows(); - self.group_values_blocks = self.group_values_blocks.take().map(|mut rows| { - rows.clear(); - rows - }); + let mut old_blocks = mem::replace(&mut self.group_values_blocks, VecDeque::new()); + old_blocks.iter_mut().for_each(|block| block.clear()); self.map.clear(); self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared self.map_size = self.map.capacity() * std::mem::size_of::<(u64, usize)>(); From c35536882062d8d2e6628d4df517edf91c796c51 Mon Sep 17 00:00:00 2001 From: kamille Date: Sun, 11 Aug 2024 15:46:49 +0800 Subject: [PATCH 3/9] just make `GroupIdx` an internal concept first. 
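Callers of `GroupValues::intern` / `emit` keep working with plain `usize` group indexes;
only `GroupValuesRows` knows about blocks internally. The correspondence between the two
views is simple arithmetic on `max_block_size`. A sketch of that mapping follows; the
forward direction matches `GroupIdx::as_flat_group_idx` added in the next patch, the
inverse is only implied, and the helper names are made up for the illustration:

    // Flat <-> blocked index arithmetic assumed by the series.
    fn to_flat(block_id: usize, block_offset: usize, max_block_size: usize) -> usize {
        block_id * max_block_size + block_offset
    }

    fn to_blocked(flat: usize, max_block_size: usize) -> (usize, usize) {
        (flat / max_block_size, flat % max_block_size)
    }

    fn main() {
        let max_block_size = 8192; // e.g. the configured batch size
        let flat = to_flat(2, 100, max_block_size);
        assert_eq!(flat, 2 * max_block_size + 100);
        assert_eq!(to_blocked(flat, max_block_size), (2, 100));
    }

This round trip only holds because `intern` switches to a new block exactly when the
current one reaches `max_block_size`, so every block except the last one is full.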
--- .../physical-plan/src/aggregates/group_values/bytes.rs | 8 ++++---- .../src/aggregates/group_values/bytes_view.rs | 10 +++++----- .../physical-plan/src/aggregates/group_values/mod.rs | 2 +- .../src/aggregates/group_values/primitive.rs | 6 +++--- .../physical-plan/src/aggregates/group_values/row.rs | 4 ++-- datafusion/physical-plan/src/aggregates/row_hash.rs | 4 ++-- 6 files changed, 17 insertions(+), 17 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/group_values/bytes.rs b/datafusion/physical-plan/src/aggregates/group_values/bytes.rs index 427e002a4455..024bce858c62 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/bytes.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/bytes.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::aggregates::group_values::{GroupIdx, GroupValues}; +use crate::aggregates::group_values::GroupValues; use arrow_array::{Array, ArrayRef, OffsetSizeTrait, RecordBatch}; use datafusion_expr::EmitTo; use datafusion_physical_expr_common::binary_map::{ArrowBytesMap, OutputType}; @@ -44,7 +44,7 @@ impl GroupValues for GroupValuesByes { fn intern( &mut self, cols: &[ArrayRef], - groups: &mut Vec, + groups: &mut Vec, ) -> datafusion_common::Result<()> { assert_eq!(cols.len(), 1); @@ -63,7 +63,7 @@ impl GroupValues for GroupValuesByes { }, // called for each group |group_idx| { - groups.push(GroupIdx::new(0, group_idx as u64)); + groups.push(group_idx); }, ); @@ -111,7 +111,7 @@ impl GroupValues for GroupValuesByes { self.intern(&[remaining_group_values], &mut group_indexes)?; // Verify that the group indexes were assigned in the correct order - assert_eq!(0, group_indexes[0].block_offset()); + assert_eq!(0, group_indexes[0]); emit_group_values } diff --git a/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs b/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs index beca77eca9cf..34b6a74e70d6 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use crate::aggregates::group_values::{GroupIdx, GroupValues}; +use crate::aggregates::group_values::GroupValues; use arrow_array::{Array, ArrayRef, RecordBatch}; use datafusion_expr::EmitTo; use datafusion_physical_expr::binary_map::OutputType; @@ -45,7 +45,7 @@ impl GroupValues for GroupValuesBytesView { fn intern( &mut self, cols: &[ArrayRef], - groups: &mut Vec, + groups: &mut Vec, ) -> datafusion_common::Result<()> { assert_eq!(cols.len(), 1); @@ -64,7 +64,7 @@ impl GroupValues for GroupValuesBytesView { }, // called for each group |group_idx| { - groups.push(GroupIdx::new(0, group_idx as u64)); + groups.push(group_idx); }, ); @@ -85,7 +85,7 @@ impl GroupValues for GroupValuesBytesView { self.num_groups } - fn emit(&mut self, emit_to: EmitTo) -> datafusion_common::Result>> { + fn emit(&mut self, emit_to: EmitTo) -> datafusion_common::Result> { // Reset the map to default, and convert it into a single array let map_contents = self.map.take().into_state(); @@ -112,7 +112,7 @@ impl GroupValues for GroupValuesBytesView { self.intern(&[remaining_group_values], &mut group_indexes)?; // Verify that the group indexes were assigned in the correct order - assert_eq!(0, group_indexes[0].block_offset()); + assert_eq!(0, group_indexes[0]); emit_group_values } diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index 727647e9cb65..2d0468bbf533 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -61,7 +61,7 @@ impl GroupIdx { /// An interning store for group keys pub trait GroupValues: Send { /// Calculates the `groups` for each input row of `cols` - fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()>; + fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()>; /// Returns the number of bytes used by this [`GroupValues`] fn size(&self) -> usize; diff --git a/datafusion/physical-plan/src/aggregates/group_values/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/primitive.rs index 1ccaa9f5b7c1..b4cc5421e5dd 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/primitive.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use crate::aggregates::group_values::{GroupIdx, GroupValues}; +use crate::aggregates::group_values::GroupValues; use ahash::RandomState; use arrow::array::BooleanBufferBuilder; use arrow::buffer::NullBuffer; @@ -111,7 +111,7 @@ impl GroupValues for GroupValuesPrimitive where T::Native: HashValue, { - fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { + fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { assert_eq!(cols.len(), 1); groups.clear(); @@ -145,7 +145,7 @@ where } } }; - groups.push(GroupIdx::new(0, group_id as u64)) + groups.push(group_id) } Ok(()) } diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index b1a38d653d1c..1fccf447b348 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -108,7 +108,7 @@ impl GroupValuesRows { } impl GroupValues for GroupValuesRows { - fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { + fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { // Convert the group keys into the row format let group_rows = &mut self.rows_buffer; group_rows.clear(); @@ -121,7 +121,7 @@ impl GroupValues for GroupValuesRows { self.group_values_blocks.push_back(block); }; - let mut group_values_blocks = mem::take(&mut self.group_values_blocks); + let group_values_blocks = mem::take(&mut self.group_values_blocks); // tracks to which group each of the input rows belongs groups.clear(); diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 00263a9e5509..e4a8f99f960e 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use std::task::{Context, Poll}; use std::vec; -use crate::aggregates::group_values::{new_group_values, GroupValues}; +use crate::aggregates::group_values::{new_group_values, GroupIdx, GroupValues}; use crate::aggregates::order::GroupOrderingFull; use crate::aggregates::{ evaluate_group_by, evaluate_many, evaluate_optional, group_schema, AggregateMode, @@ -353,7 +353,7 @@ pub(crate) struct GroupedHashAggregateStream { /// scratch space for the current input [`RecordBatch`] being /// processed. Reused across batches here to avoid reallocations - current_group_indices: Vec, + current_group_indices: Vec, /// Tracks if this stream is generating input or output exec_state: ExecutionState, From f4a71bf8a75843802076c1ca37d0fe7f8f465c2e Mon Sep 17 00:00:00 2001 From: kamille Date: Sun, 11 Aug 2024 16:25:56 +0800 Subject: [PATCH 4/9] apply it to the main procedure. 
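`ExecutionState::ProducingOutput` now carries one `RecordBatch` per emitted block, and
`emit` has to cut each accumulator's flat state/value array into per-block pieces that
line up with the emitted group values. A sketch of that alignment with made-up block
lengths and a made-up helper name (note that Arrow's `Array::slice` takes an offset and
a length):

    // Align one accumulator's flat output with the emitted group blocks.
    // Block lengths and values are assumptions for the example.
    use std::sync::Arc;
    use arrow::array::{Array, ArrayRef, Int64Array};

    fn split_by_blocks(state: &ArrayRef, block_lens: &[usize]) -> Vec<ArrayRef> {
        let mut offset = 0;
        block_lens
            .iter()
            .map(|len| {
                // Array::slice is zero-copy and takes (offset, length).
                let part = state.slice(offset, *len);
                offset += *len;
                part
            })
            .collect()
    }

    fn main() {
        let state: ArrayRef = Arc::new(Int64Array::from(vec![1_i64, 2, 3, 4, 5]));
        let parts = split_by_blocks(&state, &[2, 3]);
        assert_eq!(parts[0].len(), 2);
        assert_eq!(parts[1].len(), 3);
    }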
--- .../src/aggregates/group_values/bytes_view.rs | 2 +- .../src/aggregates/group_values/mod.rs | 7 +++ .../src/aggregates/group_values/row.rs | 2 +- .../physical-plan/src/aggregates/row_hash.rs | 60 ++++++++++++++++--- 4 files changed, 60 insertions(+), 11 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs b/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs index 34b6a74e70d6..27c65756c7a8 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs @@ -85,7 +85,7 @@ impl GroupValues for GroupValuesBytesView { self.num_groups } - fn emit(&mut self, emit_to: EmitTo) -> datafusion_common::Result> { + fn emit(&mut self, emit_to: EmitTo) -> datafusion_common::Result>> { // Reset the map to default, and convert it into a single array let map_contents = self.map.take().into_state(); diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index 2d0468bbf533..4838c8849927 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -56,6 +56,13 @@ impl GroupIdx { pub fn block_offset(&self) -> usize { (self.0 & GROUP_IDX_LOW_48_BITS_MASK) as usize } + + pub fn as_flat_group_idx(&self, max_block_size: usize) -> usize { + let block_id = self.block_id(); + let block_offset = self.block_offset(); + + block_id * max_block_size + block_offset + } } /// An interning store for group keys diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index 1fccf447b348..4e814f94f688 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -180,7 +180,7 @@ impl GroupValues for GroupValuesRows { group_idx } }; - groups.push(group_idx); + groups.push(group_idx.as_flat_group_idx(self.max_block_size)); } self.group_values_blocks = group_values_blocks; diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index e4a8f99f960e..33ef05823519 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -61,7 +61,7 @@ pub(crate) enum ExecutionState { ReadingInput, /// When producing output, the remaining rows to output are stored /// here and are sliced off as needed in batch_size chunks - ProducingOutput(RecordBatch), + ProducingOutput(Vec), /// Produce intermediate aggregate state for each input row without /// aggregation. /// @@ -353,7 +353,7 @@ pub(crate) struct GroupedHashAggregateStream { /// scratch space for the current input [`RecordBatch`] being /// processed. 
Reused across batches here to avoid reallocations - current_group_indices: Vec, + current_group_indices: Vec, /// Tracks if this stream is generating input or output exec_state: ExecutionState, @@ -798,14 +798,14 @@ impl GroupedHashAggregateStream { /// Create an output RecordBatch with the group keys and /// accumulator states/values specified in emit_to - fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result { + fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result> { let schema = if spilling { Arc::clone(&self.spill_state.spill_schema) } else { self.schema() }; if self.group_values.is_empty() { - return Ok(RecordBatch::new_empty(schema)); + return Ok(vec![RecordBatch::new_empty(schema)]); } let mut output = self.group_values.emit(emit_to)?; @@ -816,24 +816,66 @@ impl GroupedHashAggregateStream { // Next output each aggregate value for acc in self.accumulators.iter_mut() { match self.mode { - AggregateMode::Partial => output.extend(acc.state(emit_to)?), + AggregateMode::Partial => { + let states = acc.state(emit_to)?; + let mut rows_count_before_cur_block = 0; + for output_block in output.iter_mut() { + let block_start = rows_count_before_cur_block; + let block_end = + rows_count_before_cur_block + output_block[0].len(); + output_block.reserve(states.len()); + for state in states.iter() { + output_block.push(state.slice(block_start, block_end)) + } + + rows_count_before_cur_block = output_block[0].len(); + } + } _ if spilling => { // If spilling, output partial state because the spilled data will be // merged and re-evaluated later. - output.extend(acc.state(emit_to)?) + let states = acc.state(emit_to)?; + let mut rows_count_before_cur_block = 0; + for output_block in output.iter_mut() { + let block_start = rows_count_before_cur_block; + let block_end = + rows_count_before_cur_block + output_block[0].len(); + output_block.reserve(states.len()); + for state in states.iter() { + output_block.push(state.slice(block_start, block_end)) + } + + rows_count_before_cur_block = output_block[0].len(); + } } AggregateMode::Final | AggregateMode::FinalPartitioned | AggregateMode::Single - | AggregateMode::SinglePartitioned => output.push(acc.evaluate(emit_to)?), + | AggregateMode::SinglePartitioned => { + let state = acc.evaluate(emit_to)?; + let mut rows_count_before_cur_block = 0; + for output_block in output.iter_mut() { + let block_start = rows_count_before_cur_block; + let block_end = + rows_count_before_cur_block + output_block[0].len(); + output_block.push(state.slice(block_start, block_end)); + rows_count_before_cur_block = output_block[0].len(); + } + } } } // emit reduces the memory usage. Ignore Err from update_memory_reservation. Even if it is // over the target memory size after emission, we can emit again rather than returning Err. let _ = self.update_memory_reservation(); - let batch = RecordBatch::try_new(schema, output)?; - Ok(batch) + let batches = output + .into_iter() + .map(|o| { + RecordBatch::try_new(schema.clone(), o) + .map_err(|e| DataFusionError::ArrowError(e, None)) + }) + .collect::>>()?; + Ok(batches) } /// Optimistically, [`Self::group_aggregate_batch`] allows to exceed the memory target slightly From 8deecab4280453a2056544896cfd8f0b1edd49ff Mon Sep 17 00:00:00 2001 From: kamille Date: Sun, 11 Aug 2024 17:36:35 +0800 Subject: [PATCH 5/9] fix spilling(but maybe lead to performance problem currently). 
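Spilling currently emits every block, concatenates the blocks back into a single batch
with `concat_batches`, and then sorts and writes that batch; that concatenation is
presumably the performance concern mentioned in the subject (see the TODO in `spill`
about doing it gradually during `emit`). A self-contained sketch of the concatenation
step, with an assumed one-column schema that is not the real plan schema:

    // Glue per-block emit output back into one batch before sort + spill.
    use std::sync::Arc;
    use arrow::array::{ArrayRef, Int64Array};
    use arrow::compute::concat_batches;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;

    fn main() -> Result<(), arrow::error::ArrowError> {
        let schema = Arc::new(Schema::new(vec![Field::new("g", DataType::Int64, false)]));
        let blocks = vec![
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![Arc::new(Int64Array::from(vec![1_i64, 2])) as ArrayRef],
            )?,
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![Arc::new(Int64Array::from(vec![3_i64])) as ArrayRef],
            )?,
        ];

        // One contiguous batch that can be sorted on the spill expressions and
        // then written out to the spill file in smaller chunks.
        let single = concat_batches(&schema, &blocks)?;
        assert_eq!(single.num_rows(), 3);
        Ok(())
    }

`update_merged_stream` avoids the concatenation by pushing one sorted stream per emitted
block into the streaming merge instead.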
--- .../src/aggregates/group_values/row.rs | 8 +++-- .../physical-plan/src/aggregates/row_hash.rs | 31 ++++++++++++------- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index 4e814f94f688..a3a0c5c6284c 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -216,14 +216,15 @@ impl GroupValues for GroupValuesRows { return Ok(Vec::new()); } - let mut output = match emit_to { + let vec = match emit_to { EmitTo::All => { group_values_blocks.iter_mut().map(|rows_block| { let output = self.row_converter.convert_rows(rows_block.iter())?; rows_block.clear(); Ok(output) }).collect::>>()? - } + }, + EmitTo::First(n) => { // convert it to block let num_emitted_blocks = if n > self.max_block_size { @@ -267,8 +268,9 @@ impl GroupValues for GroupValuesRows { } } emitted_blocks - } + }, }; + let mut output = vec; // TODO: Materialize dictionaries in group keys (#7647) for one_output in output.iter_mut() { diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 33ef05823519..18ec433d3bdc 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -37,6 +37,7 @@ use crate::{aggregates, metrics, ExecutionPlan, PhysicalExpr}; use crate::{RecordBatchStream, SendableRecordBatchStream}; use arrow::array::*; +use arrow::compute::concat_batches; use arrow::datatypes::SchemaRef; use arrow_schema::SortOptions; use datafusion_common::{internal_datafusion_err, DataFusionError, Result}; @@ -901,10 +902,12 @@ impl GroupedHashAggregateStream { /// Emit all rows, sort them, and store them on disk. fn spill(&mut self) -> Result<()> { - let emit = self.emit(EmitTo::All, true)?; - let sorted = sort_batch(&emit, &self.spill_state.spill_expr, None)?; + let emitteds = self.emit(EmitTo::All, true)?; + // TODO: maybe we should concat it gradually in `emit` for saving memory? + let single_batch = concat_batches(&emitteds[0].schema(), &emitteds)?; + let sorted = sort_batch(&single_batch, &self.spill_state.spill_expr, None)?; let spillfile = self.runtime.disk_manager.create_tmp_file("HashAggSpill")?; - let mut writer = IPCWriter::new(spillfile.path(), &emit.schema())?; + let mut writer = IPCWriter::new(spillfile.path(), &single_batch.schema())?; // TODO: slice large `sorted` and write to multiple files in parallel let mut offset = 0; let total_rows = sorted.num_rows(); @@ -955,23 +958,29 @@ impl GroupedHashAggregateStream { /// Conduct a streaming merge sort between the batch and spilled data. Since the stream is fully /// sorted, set `self.group_ordering` to Full, then later we can read with [`EmitTo::First`]. 
fn update_merged_stream(&mut self) -> Result<()> { - let batch = self.emit(EmitTo::All, true)?; + let emitteds = self.emit(EmitTo::All, true)?; // clear up memory for streaming_merge self.clear_all(); self.update_memory_reservation()?; let mut streams: Vec = vec![]; let expr = self.spill_state.spill_expr.clone(); - let schema = batch.schema(); - streams.push(Box::pin(RecordBatchStreamAdapter::new( - Arc::clone(&schema), - futures::stream::once(futures::future::lazy(move |_| { - sort_batch(&batch, &expr, None) - })), - ))); + let schema = emitteds[0].schema(); + + for batch in emitteds { + let expr_clone = expr.clone(); + streams.push(Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&schema), + futures::stream::once(futures::future::lazy(move |_| { + sort_batch(&batch, &expr_clone, None) + })), + ))); + } + for spill in self.spill_state.spills.drain(..) { let stream = read_spill_as_stream(spill, Arc::clone(&schema), 2)?; streams.push(stream); } + self.spill_state.is_stream_merging = true; self.input = streaming_merge( streams, From 92be0ac0dd39ebbc73aa6318539f3f257ae472ad Mon Sep 17 00:00:00 2001 From: kamille Date: Sun, 11 Aug 2024 18:03:11 +0800 Subject: [PATCH 6/9] fix compile. --- .../physical-plan/src/aggregates/row_hash.rs | 63 +++++++++++-------- 1 file changed, 37 insertions(+), 26 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 18ec433d3bdc..9bd9f9cb2810 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -17,11 +17,12 @@ //! Hash aggregation +use std::collections::VecDeque; use std::sync::Arc; use std::task::{Context, Poll}; -use std::vec; +use std::{mem, vec}; -use crate::aggregates::group_values::{new_group_values, GroupIdx, GroupValues}; +use crate::aggregates::group_values::{new_group_values, GroupValues}; use crate::aggregates::order::GroupOrderingFull; use crate::aggregates::{ evaluate_group_by, evaluate_many, evaluate_optional, group_schema, AggregateMode, @@ -62,7 +63,7 @@ pub(crate) enum ExecutionState { ReadingInput, /// When producing output, the remaining rows to output are stored /// here and are sliced off as needed in batch_size chunks - ProducingOutput(Vec), + ProducingOutput(VecDeque), /// Produce intermediate aggregate state for each input row without /// aggregation. /// @@ -570,9 +571,10 @@ impl Stream for GroupedHashAggregateStream { cx: &mut Context<'_>, ) -> Poll> { let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); + let batch_size = self.batch_size; loop { - match &self.exec_state { + match &mut self.exec_state { ExecutionState::ReadingInput => 'reading_input: { match ready!(self.input.poll_next_unpin(cx)) { // new batch to aggregate @@ -649,31 +651,39 @@ impl Stream for GroupedHashAggregateStream { } } - ExecutionState::ProducingOutput(batch) => { - // slice off a part of the batch, if needed - let output_batch; - let size = self.batch_size; - (self.exec_state, output_batch) = if batch.num_rows() <= size { - ( - if self.input_done { + ExecutionState::ProducingOutput(batches) => { + // If the buffered batches have been empty, we just switch and state and continue the loop. 
+ if batches.is_empty() { + self.exec_state = if self.input_done { ExecutionState::Done } else if self.should_skip_aggregation() { ExecutionState::SkippingAggregation } else { ExecutionState::ReadingInput - }, - batch.clone(), - ) - } else { - // output first batch_size rows - let size = self.batch_size; - let num_remaining = batch.num_rows() - size; - let remaining = batch.slice(size, num_remaining); - let output = batch.slice(0, size); - (ExecutionState::ProducingOutput(remaining), output) - }; + }; + continue; + } + + // If `cur_record`'s size has been smaller than `batch_size`, + // just pop and return it. + let cur_record = batches.front().unwrap(); + let cur_record_size = cur_record.num_rows(); + if cur_record_size <= batch_size { + let output = batches.pop_front().unwrap(); + + return Poll::Ready(Some(Ok( + output.record_output(&self.baseline_metrics) + ))); + } + + // If `cur_record`'s size is bigger than `batch_size`, we can just return part of it. + let num_remaining = cur_record_size - batch_size; + let mut cur_remaining = cur_record.slice(batch_size, num_remaining); + let output = cur_record.slice(0, batch_size); + mem::swap(&mut cur_remaining, batches.front_mut().unwrap()); + return Poll::Ready(Some(Ok( - output_batch.record_output(&self.baseline_metrics) + output.record_output(&self.baseline_metrics) ))); } @@ -799,14 +809,14 @@ impl GroupedHashAggregateStream { /// Create an output RecordBatch with the group keys and /// accumulator states/values specified in emit_to - fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result> { + fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result> { let schema = if spilling { Arc::clone(&self.spill_state.spill_schema) } else { self.schema() }; if self.group_values.is_empty() { - return Ok(vec![RecordBatch::new_empty(schema)]); + return Ok(VecDeque::from([RecordBatch::new_empty(schema)])); } let mut output = self.group_values.emit(emit_to)?; @@ -875,7 +885,8 @@ impl GroupedHashAggregateStream { RecordBatch::try_new(schema.clone(), o) .map_err(|e| DataFusionError::ArrowError(e, None)) }) - .collect::>>()?; + .collect::>>()?; + Ok(batches) } From 73a7503eb27e7cd1eb722c54509c01a1efce128c Mon Sep 17 00:00:00 2001 From: kamille Date: Sun, 11 Aug 2024 18:11:05 +0800 Subject: [PATCH 7/9] fix test. --- .../physical-plan/src/aggregates/group_values/row.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index a3a0c5c6284c..3740c7cefe60 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -121,7 +121,7 @@ impl GroupValues for GroupValuesRows { self.group_values_blocks.push_back(block); }; - let group_values_blocks = mem::take(&mut self.group_values_blocks); + let mut group_values_blocks = mem::take(&mut self.group_values_blocks); // tracks to which group each of the input rows belongs groups.clear(); @@ -157,16 +157,16 @@ impl GroupValues for GroupValuesRows { // 1.2 Need to create new entry for the group None => { // Check if the block size has reached the limit, if so we switch to next block. - let block_size = group_values_blocks.back().unwrap().num_rows(); + let block_size = group_values_blocks.back().unwrap().num_rows(); if block_size == self.max_block_size { self.cur_block_id += 1; // TODO: calc and use the capacity to init. 
let block = self.row_converter.empty_rows(0, 0); - self.group_values_blocks.push_back(block); + group_values_blocks.push_back(block); } // Add new entry to aggr_state and save newly created index - let cur_group_values = self.group_values_blocks.back_mut().unwrap(); + let cur_group_values = group_values_blocks.back_mut().unwrap(); let block_offset = cur_group_values.num_rows(); let group_idx = GroupIdx::new(self.cur_block_id, block_offset as u64); cur_group_values.push(group_rows.row(row)); From 06e73423073eaea33965b7cc66e67bd1f0eeb227 Mon Sep 17 00:00:00 2001 From: kamille Date: Sun, 11 Aug 2024 18:54:39 +0800 Subject: [PATCH 8/9] make emit first n exact. --- .../src/aggregates/group_values/row.rs | 99 ++++++++++++------- 1 file changed, 66 insertions(+), 33 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index 3740c7cefe60..981c92dca6b2 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -189,7 +189,11 @@ impl GroupValues for GroupValuesRows { } fn size(&self) -> usize { - let group_values_size = self.group_values_blocks.iter().map(|v| v.size()).sum::(); + let group_values_size = self + .group_values_blocks + .iter() + .map(|v| v.size()) + .sum::(); self.row_converter.size() + group_values_size + self.map_size @@ -209,66 +213,95 @@ impl GroupValues for GroupValuesRows { } fn emit(&mut self, emit_to: EmitTo) -> Result>> { - let mut group_values_blocks = mem::take(&mut self - .group_values_blocks); + let mut group_values_blocks = mem::take(&mut self.group_values_blocks); if group_values_blocks.is_empty() { return Ok(Vec::new()); } let vec = match emit_to { - EmitTo::All => { - group_values_blocks.iter_mut().map(|rows_block| { + EmitTo::All => group_values_blocks + .iter_mut() + .map(|rows_block| { let output = self.row_converter.convert_rows(rows_block.iter())?; rows_block.clear(); Ok(output) - }).collect::>>()? - }, - + }) + .collect::>>()?, + EmitTo::First(n) => { // convert it to block - let num_emitted_blocks = if n > self.max_block_size { - n / self.max_block_size - } else { - 1 - }; - - let mut emitted_blocks = Vec::with_capacity(num_emitted_blocks); - for _ in 0..num_emitted_blocks { + let num_blocks = n / self.max_block_size; + let num_first_rows_in_last_block = n % self.max_block_size; + + let mut emitted_blocks = Vec::with_capacity(num_blocks + 1); + + // Collect the complete emitted blocks + for _ in 0..num_blocks { let block = group_values_blocks.pop_front().unwrap(); - let converted_block = self.row_converter.convert_rows(block.into_iter())?; + let converted_block = + self.row_converter.convert_rows(block.into_iter())?; emitted_blocks.push(converted_block); } - // let groups_rows = group_values.iter().take(n); - // let output = self.row_converter.convert_rows(groups_rows)?; - // // Clear out first n group keys by copying them to a new Rows. - // // TODO file some ticket in arrow-rs to make this more efficient? 
- // let mut new_group_values = self.row_converter.empty_rows(0, 0); - // for row in group_values.iter().skip(n) { - // new_group_values.push(row); - // } - // std::mem::swap(&mut new_group_values, &mut group_values); + // Cut off the last block and collect if needed + if num_first_rows_in_last_block > 0 { + let last_output_rows = group_values_blocks + .front() + .unwrap() + .iter() + .take(num_first_rows_in_last_block); + let last_output_block = + self.row_converter.convert_rows(last_output_rows)?; + + let mut remaining_rows = self.row_converter.empty_rows(0, 0); + // TODO file some ticket in arrow-rs to make this more efficient? + for row in group_values_blocks + .front() + .unwrap() + .iter() + .skip(num_first_rows_in_last_block) + { + remaining_rows.push(row); + } + + std::mem::swap( + group_values_blocks.front_mut().unwrap(), + &mut remaining_rows, + ); + + emitted_blocks.push(last_output_block); + } // SAFETY: self.map outlives iterator and is not modified concurrently + let num_emitted_blocks = emitted_blocks.len(); unsafe { for bucket in self.map.iter() { // Decrement block id by `num_emitted_blocks` - let (_, group_idx, ) = bucket.as_ref(); - let new_block_id = group_idx.block_id().checked_sub(num_emitted_blocks); + let (_, group_idx) = bucket.as_ref(); + let new_block_id = + group_idx.block_id().checked_sub(num_emitted_blocks); + match new_block_id { // Group index was >= n, shift value down Some(bid) => { - let block_offset = group_idx.block_offset(); - bucket.as_mut().1 = GroupIdx::new(bid as u16, block_offset as u64); - }, + if bid == 0 && num_first_rows_in_last_block > 0 { + let new_block_offset = group_idx.block_offset() - num_first_rows_in_last_block; + bucket.as_mut().1 = + GroupIdx::new(bid as u16, new_block_offset as u64); + } else { + let block_offset = group_idx.block_offset(); + bucket.as_mut().1 = + GroupIdx::new(bid as u16, block_offset as u64); + } + } // Group index was < n, so remove from table - None => self.map.erase(bucket), + None => self.map.erase(bucket), } } } emitted_blocks - }, + } }; let mut output = vec; From 56399f7bd87c70b94be8b39d38fb5484289d3915 Mon Sep 17 00:00:00 2001 From: kamille Date: Sun, 11 Aug 2024 19:07:25 +0800 Subject: [PATCH 9/9] fix some cases in `streaming_aggregate_test` but still some faileds... 
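The main fix here is in how `emit` slices the accumulator state/value arrays per output
block: Arrow's `Array::slice(offset, length)` takes a length, not an end offset, and the
running offset now accumulates across blocks (`rows_count_before_cur_block += block_len`)
instead of being reset each iteration. A small illustration of the index arithmetic, with
block lengths made up for the example:

    // Why slice(offset, len) and `offset += len` are the right pair here.
    fn main() {
        let block_lens = [4_usize, 4, 2]; // 10 accumulator rows over 3 blocks
        let total: usize = block_lens.iter().sum();

        let mut offset = 0;
        for len in block_lens {
            // Fixed behaviour: a slice of `len` rows starting at `offset`
            // stays inside the array.
            assert!(offset + len <= total);

            // The previous code passed `offset + len` (an end offset) as the
            // length, so every block after the first got `offset` extra rows
            // and, with block sizes like these, ran past the end of the array.
            let buggy_len = offset + len;
            if offset > 0 {
                assert!(offset + buggy_len > total);
            }

            offset += len;
        }
        assert_eq!(offset, total);
    }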
--- .../physical-plan/src/aggregates/row_hash.rs | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 9bd9f9cb2810..9e36439e033d 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -815,6 +815,7 @@ impl GroupedHashAggregateStream { } else { self.schema() }; + if self.group_values.is_empty() { return Ok(VecDeque::from([RecordBatch::new_empty(schema)])); } @@ -832,14 +833,14 @@ impl GroupedHashAggregateStream { let mut rows_count_before_cur_block = 0; for output_block in output.iter_mut() { let block_start = rows_count_before_cur_block; - let block_end = - rows_count_before_cur_block + output_block[0].len(); + let block_len= output_block[0].len(); + output_block.reserve(states.len()); for state in states.iter() { - output_block.push(state.slice(block_start, block_end)) + output_block.push(state.slice(block_start, block_len)) } - rows_count_before_cur_block = output_block[0].len(); + rows_count_before_cur_block += block_len; } } _ if spilling => { @@ -849,14 +850,14 @@ impl GroupedHashAggregateStream { let mut rows_count_before_cur_block = 0; for output_block in output.iter_mut() { let block_start = rows_count_before_cur_block; - let block_end = - rows_count_before_cur_block + output_block[0].len(); + let block_len = output_block[0].len(); + output_block.reserve(states.len()); for state in states.iter() { - output_block.push(state.slice(block_start, block_end)) + output_block.push(state.slice(block_start, block_len)) } - rows_count_before_cur_block = output_block[0].len(); + rows_count_before_cur_block += block_len; } } AggregateMode::Final @@ -867,10 +868,11 @@ impl GroupedHashAggregateStream { let mut rows_count_before_cur_block = 0; for output_block in output.iter_mut() { let block_start = rows_count_before_cur_block; - let block_end = - rows_count_before_cur_block + output_block[0].len(); - output_block.push(state.slice(block_start, block_end)); - rows_count_before_cur_block = output_block[0].len(); + let block_len = output_block[0].len(); + + output_block.push(state.slice(block_start, block_len)); + + rows_count_before_cur_block += block_len; } } }