diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index 3986984b2630..a091ed34da70 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -25,9 +25,14 @@ use crate::array_agg::ArrayAgg; use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field, FieldRef}; -use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; -use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue}; +use datafusion_common::cast::{ + as_generic_string_array, as_string_array, as_string_view_array, +}; +use datafusion_common::{ + internal_datafusion_err, internal_err, not_impl_err, Result, ScalarValue, +}; use datafusion_expr::function::AccumulatorArgs; +use datafusion_expr::utils::format_state_name; use datafusion_expr::{ Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, Volatility, }; @@ -120,6 +125,8 @@ impl Default for StringAgg { } } +/// If there is no `distinct` and `order by` required by the `string_agg` call, a +/// more efficient accumulator `SimpleStringAggAccumulator` will be used. impl AggregateUDFImpl for StringAgg { fn as_any(&self) -> &dyn Any { self @@ -138,7 +145,21 @@ impl AggregateUDFImpl for StringAgg { } fn state_fields(&self, args: StateFieldsArgs) -> Result> { - self.array_agg.state_fields(args) + // See comments in `impl AggregateUDFImpl ...` for more detail + let no_order_no_distinct = + (args.ordering_fields.is_empty()) && (!args.is_distinct); + if no_order_no_distinct { + // Case `SimpleStringAggAccumulator` + Ok(vec![Field::new( + format_state_name(args.name, "string_agg"), + DataType::LargeUtf8, + true, + ) + .into()]) + } else { + // Case `StringAggAccumulator` + self.array_agg.state_fields(args) + } } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { @@ -161,21 +182,31 @@ impl AggregateUDFImpl for StringAgg { ); }; - let array_agg_acc = self.array_agg.accumulator(AccumulatorArgs { - return_field: Field::new( - "f", - DataType::new_list(acc_args.return_field.data_type().clone(), true), - true, - ) - .into(), - exprs: &filter_index(acc_args.exprs, 1), - ..acc_args - })?; + // See comments in `impl AggregateUDFImpl ...` for more detail + let no_order_no_distinct = + acc_args.order_bys.is_empty() && (!acc_args.is_distinct); - Ok(Box::new(StringAggAccumulator::new( - array_agg_acc, - delimiter, - ))) + if no_order_no_distinct { + // simple case (more efficient) + Ok(Box::new(SimpleStringAggAccumulator::new(delimiter))) + } else { + // general case + let array_agg_acc = self.array_agg.accumulator(AccumulatorArgs { + return_field: Field::new( + "f", + DataType::new_list(acc_args.return_field.data_type().clone(), true), + true, + ) + .into(), + exprs: &filter_index(acc_args.exprs, 1), + ..acc_args + })?; + + Ok(Box::new(StringAggAccumulator::new( + array_agg_acc, + delimiter, + ))) + } } fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { @@ -187,6 +218,7 @@ impl AggregateUDFImpl for StringAgg { } } +/// StringAgg accumulator for the general case (with order or distinct specified) #[derive(Debug)] pub(crate) struct StringAggAccumulator { array_agg_acc: Box, @@ -269,6 +301,105 @@ fn filter_index(values: &[T], index: usize) -> Vec { .collect::>() } +/// StringAgg accumulator for the simple case (no order or distinct specified) +/// This accumulator is more efficient than `StringAggAccumulator` +/// because it accumulates the string directly, +/// whereas `StringAggAccumulator` uses `ArrayAggAccumulator`. +#[derive(Debug)] +pub(crate) struct SimpleStringAggAccumulator { + delimiter: String, + /// Updated during `update_batch()`. e.g. "foo,bar" + accumulated_string: String, + has_value: bool, +} + +impl SimpleStringAggAccumulator { + pub fn new(delimiter: &str) -> Self { + Self { + delimiter: delimiter.to_string(), + accumulated_string: "".to_string(), + has_value: false, + } + } + + #[inline] + fn append_strings<'a, I>(&mut self, iter: I) + where + I: Iterator>, + { + for value in iter.flatten() { + if self.has_value { + self.accumulated_string.push_str(&self.delimiter); + } + + self.accumulated_string.push_str(value); + self.has_value = true; + } + } +} + +impl Accumulator for SimpleStringAggAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let string_arr = values.first().ok_or_else(|| { + internal_datafusion_err!( + "Planner should ensure its first arg is Utf8/Utf8View" + ) + })?; + + match string_arr.data_type() { + DataType::Utf8 => { + let array = as_string_array(string_arr)?; + self.append_strings(array.iter()); + } + DataType::LargeUtf8 => { + let array = as_generic_string_array::(string_arr)?; + self.append_strings(array.iter()); + } + DataType::Utf8View => { + let array = as_string_view_array(string_arr)?; + self.append_strings(array.iter()); + } + other => { + return internal_err!( + "Planner should ensure string_agg first argument is Utf8-like, found {other}" + ); + } + } + + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let result = if self.has_value { + ScalarValue::LargeUtf8(Some(std::mem::take(&mut self.accumulated_string))) + } else { + ScalarValue::LargeUtf8(None) + }; + + self.has_value = false; + Ok(result) + } + + fn size(&self) -> usize { + size_of_val(self) + self.delimiter.capacity() + self.accumulated_string.capacity() + } + + fn state(&mut self) -> Result> { + let result = if self.has_value { + ScalarValue::LargeUtf8(Some(std::mem::take(&mut self.accumulated_string))) + } else { + ScalarValue::LargeUtf8(None) + }; + self.has_value = false; + + Ok(vec![result]) + } + + fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.update_batch(values) + } +} + #[cfg(test)] mod tests { use super::*;