diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index 2c672a18a738..b4c12a75d4df 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -289,10 +289,12 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator { opt_filter, total_num_groups, |group_index, new_value| { - let prod = &mut self.prods[group_index]; - *prod = prod.mul_wrapping(new_value); + if let Some(new_value) = new_value { + let prod = &mut self.prods[group_index]; + *prod = prod.mul_wrapping(new_value); - self.counts[group_index] += 1; + self.counts[group_index] += 1; + } }, ); @@ -319,7 +321,9 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator { opt_filter, total_num_groups, |group_index, partial_count| { - self.counts[group_index] += partial_count; + if let Some(partial_count) = partial_count { + self.counts[group_index] += partial_count; + } }, ); @@ -330,9 +334,12 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator { partial_prods, opt_filter, total_num_groups, - |group_index, new_value: ::Native| { - let prod = &mut self.prods[group_index]; - *prod = prod.mul_wrapping(new_value); + |group_index, + new_value: Option<::Native>| { + if let Some(new_value) = new_value { + let prod = &mut self.prods[group_index]; + *prod = prod.mul_wrapping(new_value); + } }, ); diff --git a/datafusion/physical-expr-common/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/physical-expr-common/src/aggregate/groups_accumulator/accumulate.rs index f109079f6a26..a2711de94add 100644 --- a/datafusion/physical-expr-common/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/physical-expr-common/src/aggregate/groups_accumulator/accumulate.rs @@ -19,7 +19,10 @@ //! //! [`GroupsAccumulator`]: datafusion_expr::GroupsAccumulator -use arrow::array::{Array, BooleanArray, BooleanBufferBuilder, PrimitiveArray}; +use arrow::array::{ + Array, ArrayRef, BooleanArray, BooleanBufferBuilder, Int64BufferBuilder, ListArray, + PrimitiveArray, StringArray, +}; use arrow::buffer::{BooleanBuffer, NullBuffer}; use arrow::datatypes::ArrowPrimitiveType; @@ -59,6 +62,8 @@ pub struct NullState { /// If `seen_values[i]` is false, have not seen any values that /// pass the filter yet for group `i` seen_values: BooleanBufferBuilder, + + seen_nulls: Int64BufferBuilder, } impl Default for NullState { @@ -71,13 +76,14 @@ impl NullState { pub fn new() -> Self { Self { seen_values: BooleanBufferBuilder::new(0), + seen_nulls: Int64BufferBuilder::new(0), } } /// return the size of all buffers allocated by this null state, not including self pub fn size(&self) -> usize { // capacity is in bits, so convert to bytes - self.seen_values.capacity() / 8 + self.seen_values.capacity() / 8 + self.seen_nulls.capacity() / 8 } /// Invokes `value_fn(group_index, value)` for each non null, non @@ -132,7 +138,7 @@ impl NullState { mut value_fn: F, ) where T: ArrowPrimitiveType + Send, - F: FnMut(usize, T::Native) + Send, + F: FnMut(usize, Option) + Send, { let data: &[T::Native] = values.values(); assert_eq!(data.len(), group_indices.len()); @@ -141,14 +147,13 @@ impl NullState { // "not seen" valid) let seen_values = initialize_builder(&mut self.seen_values, total_num_groups, false); - match (values.null_count() > 0, opt_filter) { // no nulls, no filter, (false, None) => { let iter = group_indices.iter().zip(data.iter()); for (&group_index, &new_value) in iter { seen_values.set_bit(group_index, true); - value_fn(group_index, new_value); + value_fn(group_index, Some(new_value)); } } // nulls, no filter @@ -175,7 +180,9 @@ impl NullState { let is_valid = (mask & index_mask) != 0; if is_valid { seen_values.set_bit(group_index, true); - value_fn(group_index, new_value); + value_fn(group_index, Some(new_value)); + } else { + value_fn(group_index, None); } index_mask <<= 1; }, @@ -192,7 +199,9 @@ impl NullState { let is_valid = remainder_bits & (1 << i) != 0; if is_valid { seen_values.set_bit(group_index, true); - value_fn(group_index, new_value); + value_fn(group_index, Some(new_value)); + } else { + value_fn(group_index, None); } }); } @@ -209,7 +218,7 @@ impl NullState { .for_each(|((&group_index, &new_value), filter_value)| { if let Some(true) = filter_value { seen_values.set_bit(group_index, true); - value_fn(group_index, new_value); + value_fn(group_index, Some(new_value)); } }) } @@ -227,7 +236,9 @@ impl NullState { if let Some(true) = filter_value { if let Some(new_value) = new_value { seen_values.set_bit(group_index, true); - value_fn(group_index, new_value) + value_fn(group_index, Some(new_value)) + } else { + value_fn(group_index, None); } } }) @@ -324,6 +335,176 @@ impl NullState { } } + /// Invokes `value_fn(group_index, value)` for each non null, non + /// filtered value in `values`, while tracking which groups have + /// seen null inputs and which groups have seen any inputs, for + /// [`ListArray`]s. + /// + /// See [`Self::accumulate`], which handles `PrimitiveArray`s, for + /// more details on other arguments. + pub fn accumulate_array( + &mut self, + group_indices: &[usize], + values: &ListArray, + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + mut value_fn: F, + mut null_fn: N, + ) where + F: FnMut(usize, ArrayRef) + Send, + N: FnMut(usize) + Send, + { + assert_eq!(values.len(), group_indices.len()); + + // ensure the seen_values is big enough (start everything at + // "not seen" valid) + let seen_values = + initialize_builder(&mut self.seen_values, total_num_groups, false); + + match (values.null_count() > 0, opt_filter) { + // no nulls, no filter, + (false, None) => { + let iter = group_indices.iter().zip(values.iter()); + for (&group_index, new_value) in iter { + seen_values.set_bit(group_index, true); + value_fn(group_index, new_value.unwrap()); + } + } + // nulls, no filter + (true, None) => { + let nulls = values.nulls().unwrap(); + group_indices + .iter() + .zip(values.iter()) + .zip(nulls.iter()) + .for_each(|((&group_index, new_value), is_valid)| { + if is_valid { + seen_values.set_bit(group_index, true); + value_fn(group_index, new_value.unwrap()); + } else { + null_fn(group_index); + } + }) + } + // no nulls, but a filter + (false, Some(filter)) => { + assert_eq!(filter.len(), group_indices.len()); + group_indices + .iter() + .zip(values.iter()) + .zip(filter.iter()) + .for_each(|((&group_index, new_value), filter_value)| { + if let Some(true) = filter_value { + seen_values.set_bit(group_index, true); + value_fn(group_index, new_value.unwrap()); + } + }); + } + // both null values and filters + (true, Some(filter)) => { + assert_eq!(filter.len(), group_indices.len()); + filter + .iter() + .zip(group_indices.iter()) + .zip(values.iter()) + .for_each(|((filter_value, &group_index), new_value)| { + if let Some(true) = filter_value { + if let Some(new_value) = new_value { + seen_values.set_bit(group_index, true); + value_fn(group_index, new_value); + } else { + null_fn(group_index); + } + } + }); + } + } + } + + /// Invokes `value_fn(group_index, value)` for each non-null, + /// non-filtered value in `values`, while tracking which groups have + /// seen null inputs and which groups have seen any inputs, for + /// [`ListArray`]s. + /// + /// See [`Self::accumulate`], which handles `PrimitiveArray`s, for + /// more details on other arguments. + pub fn accumulate_string( + &mut self, + group_indices: &[usize], + values: &StringArray, + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + mut value_fn: F, + ) where + F: FnMut(usize, Option<&str>) + Send, + { + assert_eq!(values.len(), group_indices.len()); + + // ensure the seen_values is big enough (start everything at + // "not seen" valid) + let seen_values = + initialize_builder(&mut self.seen_values, total_num_groups, false); + + match (values.null_count() > 0, opt_filter) { + // no nulls, no filter, + (false, None) => { + let iter = group_indices.iter().zip(values.iter()); + for (&group_index, new_value) in iter { + seen_values.set_bit(group_index, true); + value_fn(group_index, new_value); + } + } + // nulls, no filter + (true, None) => { + let nulls = values.nulls().unwrap(); + group_indices + .iter() + .zip(values.iter()) + .zip(nulls.iter()) + .for_each(|((&group_index, new_value), is_valid)| { + if is_valid { + seen_values.set_bit(group_index, true); + value_fn(group_index, new_value); + } else { + value_fn(group_index, None); + } + }) + } + // no nulls, but a filter + (false, Some(filter)) => { + assert_eq!(filter.len(), group_indices.len()); + group_indices + .iter() + .zip(values.iter()) + .zip(filter.iter()) + .for_each(|((&group_index, new_value), filter_value)| { + if let Some(true) = filter_value { + seen_values.set_bit(group_index, true); + value_fn(group_index, new_value); + } + }); + } + // both null values and filters + (true, Some(filter)) => { + assert_eq!(filter.len(), group_indices.len()); + filter + .iter() + .zip(group_indices.iter()) + .zip(values.iter()) + .for_each(|((filter_value, &group_index), new_value)| { + if let Some(true) = filter_value { + if let Some(new_value) = new_value { + seen_values.set_bit(group_index, true); + value_fn(group_index, Some(new_value)); + } else { + value_fn(group_index, None); + } + } + }); + } + } + } + /// Creates the a [`NullBuffer`] representing which group_indices /// should have null values (because they never saw any values) /// for the `emit_to` rows. @@ -670,7 +851,9 @@ mod test { opt_filter, total_num_groups, |group_index, value| { - accumulated_values.push((group_index, value)); + if let Some(value) = value { + accumulated_values.push((group_index, value)); + } }, ); diff --git a/datafusion/physical-expr-common/src/aggregate/groups_accumulator/prim_op.rs b/datafusion/physical-expr-common/src/aggregate/groups_accumulator/prim_op.rs index debb36852b22..623547c6a7aa 100644 --- a/datafusion/physical-expr-common/src/aggregate/groups_accumulator/prim_op.rs +++ b/datafusion/physical-expr-common/src/aggregate/groups_accumulator/prim_op.rs @@ -103,8 +103,10 @@ where opt_filter, total_num_groups, |group_index, new_value| { - let value = &mut self.values[group_index]; - (self.prim_fn)(value, new_value); + if let Some(new_value) = new_value { + let value = &mut self.values[group_index]; + (self.prim_fn)(value, new_value); + } }, ); diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index a23ba07de44a..d4936aabf112 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -20,14 +20,24 @@ use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; -use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field}; -use arrow_array::Array; +use arrow::array::{ArrayRef, AsArray, ListBuilder, PrimitiveBuilder, StringBuilder}; +use arrow::datatypes::{ + ArrowPrimitiveType, DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, + DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, + DurationSecondType, Field, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, + Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType, + Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, + TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, + TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, +}; +use arrow_array::{Array, BooleanArray}; +use arrow_schema::{IntervalUnit, TimeUnit}; use datafusion_common::cast::as_list_array; use datafusion_common::utils::array_into_list_array; -use datafusion_common::Result; use datafusion_common::ScalarValue; -use datafusion_expr::Accumulator; +use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::{Accumulator, EmitTo, GroupsAccumulator}; +use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::NullState; use std::any::Any; use std::sync::Arc; @@ -42,6 +52,8 @@ pub struct ArrayAgg { expr: Arc, /// If the input expression can have NULLs nullable: bool, + // If the aggregate should ignore NULLs + ignore_nulls: bool, } impl ArrayAgg { @@ -51,12 +63,14 @@ impl ArrayAgg { name: impl Into, data_type: DataType, nullable: bool, + ignore_nulls: bool, ) -> Self { Self { name: name.into(), input_data_type: data_type, expr, nullable, + ignore_nulls, } } } @@ -78,6 +92,7 @@ impl AggregateExpr for ArrayAgg { fn create_accumulator(&self) -> Result> { Ok(Box::new(ArrayAggAccumulator::try_new( &self.input_data_type, + self.ignore_nulls, )?)) } @@ -96,6 +111,189 @@ impl AggregateExpr for ArrayAgg { fn name(&self) -> &str { &self.name } + + fn groups_accumulator_supported(&self) -> bool { + self.input_data_type.is_primitive() || self.input_data_type == DataType::Utf8 + } + + fn create_groups_accumulator(&self) -> Result> { + match self.input_data_type { + DataType::Int8 => Ok(Box::new(ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + self.ignore_nulls, + ))), + DataType::Int16 => Ok(Box::new(ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + self.ignore_nulls, + ))), + DataType::Int32 => Ok(Box::new(ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + self.ignore_nulls, + ))), + DataType::Int64 => Ok(Box::new(ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + self.ignore_nulls, + ))), + DataType::UInt8 => Ok(Box::new(ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + self.ignore_nulls, + ))), + DataType::UInt16 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + self.ignore_nulls, + ))) + } + DataType::UInt32 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + self.ignore_nulls, + ))) + } + DataType::UInt64 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + self.ignore_nulls, + ))) + } + DataType::Float32 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + self.ignore_nulls, + ))) + } + DataType::Float64 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + self.ignore_nulls, + ))) + } + DataType::Decimal128(_, _) => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + self.ignore_nulls, + ))) + } + DataType::Decimal256(_, _) => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + self.ignore_nulls, + ))) + } + DataType::Date32 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + self.ignore_nulls, + ))) + } + DataType::Date64 => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + self.ignore_nulls, + ))) + } + DataType::Timestamp(TimeUnit::Second, _) => Ok(Box::new( + ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + self.ignore_nulls, + ), + )), + DataType::Timestamp(TimeUnit::Millisecond, _) => { + Ok(Box::new(ArrayAggGroupsAccumulator::< + TimestampMillisecondType, + >::new( + &self.input_data_type, self.ignore_nulls + ))) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + Ok(Box::new(ArrayAggGroupsAccumulator::< + TimestampMicrosecondType, + >::new( + &self.input_data_type, self.ignore_nulls + ))) + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => Ok(Box::new( + ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + self.ignore_nulls, + ), + )), + DataType::Time32(TimeUnit::Second) => Ok(Box::new( + ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + self.ignore_nulls, + ), + )), + DataType::Time32(TimeUnit::Millisecond) => Ok(Box::new( + ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + self.ignore_nulls, + ), + )), + DataType::Time64(TimeUnit::Microsecond) => Ok(Box::new( + ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + self.ignore_nulls, + ), + )), + DataType::Time64(TimeUnit::Nanosecond) => Ok(Box::new( + ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + self.ignore_nulls, + ), + )), + DataType::Duration(TimeUnit::Second) => Ok(Box::new( + ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + self.ignore_nulls, + ), + )), + DataType::Duration(TimeUnit::Millisecond) => Ok(Box::new( + ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + self.ignore_nulls, + ), + )), + DataType::Duration(TimeUnit::Microsecond) => Ok(Box::new( + ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + self.ignore_nulls, + ), + )), + DataType::Duration(TimeUnit::Nanosecond) => Ok(Box::new( + ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + self.ignore_nulls, + ), + )), + DataType::Interval(IntervalUnit::YearMonth) => Ok(Box::new( + ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + self.ignore_nulls, + ), + )), + DataType::Interval(IntervalUnit::DayTime) => Ok(Box::new( + ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + self.ignore_nulls, + ), + )), + DataType::Interval(IntervalUnit::MonthDayNano) => { + Ok(Box::new(ArrayAggGroupsAccumulator::< + IntervalMonthDayNanoType, + >::new( + &self.input_data_type, self.ignore_nulls + ))) + } + DataType::Utf8 => Ok(Box::new(StringArrayAggGroupsAccumulator::new( + self.ignore_nulls, + ))), + _ => Err(DataFusionError::Internal(format!( + "ArrayAggGroupsAccumulator not supported for data type {:?}", + self.input_data_type + ))), + } + } } impl PartialEq for ArrayAgg { @@ -115,14 +313,16 @@ impl PartialEq for ArrayAgg { pub(crate) struct ArrayAggAccumulator { values: Vec, datatype: DataType, + ignore_nulls: bool, } impl ArrayAggAccumulator { /// new array_agg accumulator based on given item data type - pub fn try_new(datatype: &DataType) -> Result { + pub fn try_new(datatype: &DataType, ignore_nulls: bool) -> Result { Ok(Self { values: vec![], datatype: datatype.clone(), + ignore_nulls, }) } } @@ -134,7 +334,18 @@ impl Accumulator for ArrayAggAccumulator { return Ok(()); } assert!(values.len() == 1, "array_agg can only take 1 param!"); + let val = values[0].clone(); + + if self.ignore_nulls { + if let Some(nulls) = val.logical_nulls() { + let predicate = BooleanArray::from(nulls.inner().clone()); + let filtered = arrow::compute::filter(val.as_ref(), &predicate)?; + self.values.push(filtered); + return Ok(()); + } + } + self.values.push(val); Ok(()) } @@ -186,3 +397,264 @@ impl Accumulator for ArrayAggAccumulator { - std::mem::size_of_val(&self.datatype) } } + +struct ArrayAggGroupsAccumulator +where + T: ArrowPrimitiveType + Send, +{ + values: Vec>, + data_type: DataType, + null_state: NullState, + ignore_nulls: bool, +} + +impl ArrayAggGroupsAccumulator +where + T: ArrowPrimitiveType + Send, +{ + pub fn new(data_type: &DataType, ignore_nulls: bool) -> Self { + Self { + values: vec![], + data_type: data_type.clone(), + null_state: NullState::new(), + ignore_nulls, + } + } +} + +impl ArrayAggGroupsAccumulator { + fn build_list(&mut self, emit_to: EmitTo) -> Result { + let arrays = emit_to.take_needed(&mut self.values); + let nulls = self.null_state.build(emit_to); + + let len = nulls.len(); + assert_eq!(arrays.len(), len); + + let mut builder = ListBuilder::with_capacity( + PrimitiveBuilder::::new().with_data_type(self.data_type.clone()), + len, + ); + + for (is_valid, mut arr) in nulls.iter().zip(arrays.into_iter()) { + if is_valid { + builder.append_value(arr.finish().into_iter()); + } else { + builder.append_null(); + } + } + + Ok(Arc::new(builder.finish())) + } +} + +impl GroupsAccumulator for ArrayAggGroupsAccumulator +where + T: ArrowPrimitiveType + Send + Sync, +{ + fn update_batch( + &mut self, + new_values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(new_values.len(), 1, "single argument to update_batch"); + let new_values = new_values[0].as_primitive::(); + + for _ in self.values.len()..total_num_groups { + self.values.push( + PrimitiveBuilder::::new().with_data_type(self.data_type.clone()), + ); + } + + self.null_state.accumulate( + group_indices, + new_values, + opt_filter, + total_num_groups, + |group_index, new_value| match new_value { + Some(new_value) => { + self.values[group_index].append_value(new_value); + } + None => { + if !self.ignore_nulls { + self.values[group_index].append_null(); + } + } + }, + ); + + Ok(()) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to merge_batch"); + let values = values[0].as_list(); + + for _ in self.values.len()..total_num_groups { + self.values.push( + PrimitiveBuilder::::new().with_data_type(self.data_type.clone()), + ); + } + + self.null_state.accumulate_array( + group_indices, + values, + opt_filter, + total_num_groups, + |group_index, new_value: ArrayRef| { + let new_value = new_value.as_primitive::(); + self.values[group_index].extend(new_value); + }, + |_group_index| { + // TODO: Should this not do nothing? Null here just means that the group saw no values, right? + }, + ); + + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + self.build_list(emit_to) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + Ok(vec![self.build_list(emit_to)?]) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + + std::mem::size_of::>() * self.values.capacity() + + self.values.iter().map(|arr| arr.capacity()).sum::() + * std::mem::size_of::<::Native>() + + self.null_state.size() + } +} + +struct StringArrayAggGroupsAccumulator { + values: Vec, + null_state: NullState, + ignore_nulls: bool, +} + +impl StringArrayAggGroupsAccumulator { + pub fn new(ignore_nulls: bool) -> Self { + Self { + values: vec![], + null_state: NullState::new(), + ignore_nulls, + } + } +} + +impl StringArrayAggGroupsAccumulator { + fn build_list(&mut self, emit_to: EmitTo) -> Result { + let array = emit_to.take_needed(&mut self.values); + let nulls = self.null_state.build(emit_to); + + assert_eq!(array.len(), nulls.len()); + + let mut builder = ListBuilder::with_capacity(StringBuilder::new(), nulls.len()); + for (mut arr, is_valid) in array.into_iter().zip(nulls.into_iter()) { + if is_valid { + builder.append_value(arr.finish().into_iter()); + } else { + builder.append_null(); + } + } + + Ok(Arc::new(builder.finish())) + } +} + +impl GroupsAccumulator for StringArrayAggGroupsAccumulator { + fn update_batch( + &mut self, + new_values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(new_values.len(), 1, "single argument to update_batch"); + let new_values = new_values[0].as_string(); + + for _ in self.values.len()..total_num_groups { + self.values.push(StringBuilder::new()); + } + + self.null_state.accumulate_string( + group_indices, + new_values, + opt_filter, + total_num_groups, + |group_index, new_value| { + if let Some(new_value) = new_value { + self.values[group_index].append_value(new_value); + } else if !self.ignore_nulls { + self.values[group_index].append_null(); + } + }, + ); + + Ok(()) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to merge_batch"); + let values = values[0].as_list(); + + for _ in self.values.len()..total_num_groups { + self.values.push(StringBuilder::new()); + } + + self.null_state.accumulate_array( + group_indices, + values, + opt_filter, + total_num_groups, + |group_index, new_value: ArrayRef| { + let new_value = new_value.as_string::(); + self.values[group_index] + .extend(new_value.into_iter().map(|s| s.map(|s| s.to_string()))); + }, + |_group_index| {}, + ); + + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + self.build_list(emit_to) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + Ok(vec![self.build_list(emit_to)?]) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + + std::mem::size_of::() * self.values.capacity() + + self + .values + .iter() + .map(|arr| { + std::mem::size_of_val(arr.values_slice()) + + std::mem::size_of_val(arr.offsets_slice()) + + arr.validity_slice().map(std::mem::size_of_val).unwrap_or(0) + }) + .sum::() + + self.null_state.size() + } +} diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs index 80fcc9b70c5f..4ff1bc0d4bd9 100644 --- a/datafusion/physical-expr/src/aggregate/average.rs +++ b/datafusion/physical-expr/src/aggregate/average.rs @@ -459,10 +459,12 @@ where opt_filter, total_num_groups, |group_index, new_value| { - let sum = &mut self.sums[group_index]; - *sum = sum.add_wrapping(new_value); + if let Some(new_value) = new_value { + let sum = &mut self.sums[group_index]; + *sum = sum.add_wrapping(new_value); - self.counts[group_index] += 1; + self.counts[group_index] += 1; + } }, ); @@ -488,7 +490,9 @@ where opt_filter, total_num_groups, |group_index, partial_count| { - self.counts[group_index] += partial_count; + if let Some(partial_count) = partial_count { + self.counts[group_index] += partial_count; + } }, ); @@ -499,9 +503,11 @@ where partial_sums, opt_filter, total_num_groups, - |group_index, new_value: ::Native| { - let sum = &mut self.sums[group_index]; - *sum = sum.add_wrapping(new_value); + |group_index, new_value: Option<::Native>| { + if let Some(new_value) = new_value { + let sum = &mut self.sums[group_index]; + *sum = sum.add_wrapping(new_value); + } }, ); diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 53cfcfb033a1..3655c0564242 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -46,7 +46,7 @@ pub fn create_aggregate_expr( ordering_req: &[PhysicalSortExpr], input_schema: &Schema, name: impl Into, - _ignore_nulls: bool, + ignore_nulls: bool, ) -> Result> { let name = name.into(); // get the result data type for this aggregate function @@ -71,7 +71,13 @@ pub fn create_aggregate_expr( let nullable = expr.nullable(input_schema)?; if ordering_req.is_empty() { - Arc::new(expressions::ArrayAgg::new(expr, name, data_type, nullable)) + Arc::new(expressions::ArrayAgg::new( + expr, + name, + data_type, + nullable, + ignore_nulls, + )) } else { Arc::new(expressions::OrderSensitiveArrayAgg::new( expr, diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 378cab206240..64b39571f358 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -2770,6 +2770,19 @@ select array_agg(c1) from t; statement ok drop table t; +# array_agg_str + +statement ok +create table t (c1 string) as values ('a'), ('b'), ('c'), ('d'), ('e'); + +query ? +select array_agg(c1) from t; +---- +[a, b, c, d, e] + +statement ok +drop table t; + # array_agg_nested statement ok create table t as values (make_array([1, 2, 3], [4, 5])), (make_array([6], [7, 8])), (make_array([9])); @@ -3694,6 +3707,13 @@ SELECT tag, array_agg(millis - arrow_cast(secs, 'Timestamp(Millisecond, None)')) X [0 days 0 hours 0 mins 0.011 secs, 0 days 0 hours 0 mins 0.123 secs] Y [, 0 days 0 hours 0 mins 0.432 secs] +# aggregate_duration_array_agg_ignore_nulls +query T? +SELECT tag, array_agg(millis - arrow_cast(secs, 'Timestamp(Millisecond, None)')) IGNORE NULLS FROM t GROUP BY tag ORDER BY tag; +---- +X [0 days 0 hours 0 mins 0.011 secs, 0 days 0 hours 0 mins 0.123 secs] +Y [0 days 0 hours 0 mins 0.432 secs] + statement ok drop table t_source;