Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 47 additions & 38 deletions datafusion/functions-aggregate/src/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -757,11 +757,9 @@ where
/// The type of the returned sum
return_data_type: DataType,

/// Count per group (use u64 to make UInt64Array)
counts: Vec<u64>,

/// Sums per group, stored as the native type
sums: Vec<T::Native>,
/// Combined count and sum per group in a single Vec to halve reallocation cost.
/// Each entry stores (count, sum) for one group.
states: Vec<AvgState<T::Native>>,

/// Track nulls in the input / filters
null_state: NullState,
Expand All @@ -770,6 +768,14 @@ where
avg_fn: F,
}

/// Combined per-group state for the AVG groups accumulator.
///
/// One `AvgState` is kept per group, and all of them live in a single `Vec`
/// (rather than separate count and sum vectors) so that growing the number of
/// groups needs only one resize.
///
/// `N` is the native value type of the sum (the accumulator instantiates it as
/// `T::Native`).
#[derive(Debug, Clone, Copy)]
struct AvgState<N> {
/// Number of values accumulated into this group so far; starts at 0 and is
/// incremented per value, or increased by a partial count when merging.
count: u64,
/// Running sum for this group, updated with wrapping addition.
sum: N,
}

impl<T, F> AvgGroupsAccumulator<T, F>
where
T: ArrowNumericType + Send,
Expand All @@ -784,8 +790,7 @@ where
Self {
return_data_type: return_data_type.clone(),
sum_data_type: sum_data_type.clone(),
counts: vec![],
sums: vec![],
states: vec![],
null_state: NullState::new(),
avg_fn,
}
Expand All @@ -808,34 +813,36 @@ where
let values = values[0].as_primitive::<T>();

// increment counts, update sums
self.counts.resize(total_num_groups, 0);
self.sums.resize(total_num_groups, T::default_value());
self.states.resize(
total_num_groups,
AvgState {
count: 0,
sum: T::default_value(),
},
);
self.null_state.accumulate(
group_indices,
values,
opt_filter,
total_num_groups,
|group_index, new_value| {
// SAFETY: group_index is guaranteed to be in bounds
let sum = unsafe { self.sums.get_unchecked_mut(group_index) };
*sum = sum.add_wrapping(new_value);

self.counts[group_index] += 1;
let state = unsafe { self.states.get_unchecked_mut(group_index) };
state.sum = state.sum.add_wrapping(new_value);
state.count += 1;
},
);

Ok(())
}

fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
let counts = emit_to.take_needed(&mut self.counts);
let sums = emit_to.take_needed(&mut self.sums);
let states = emit_to.take_needed(&mut self.states);
let nulls = self.null_state.build(emit_to);

if let Some(nulls) = &nulls {
assert_eq!(nulls.len(), sums.len());
assert_eq!(nulls.len(), states.len());
}
assert_eq!(counts.len(), sums.len());

// don't evaluate averages with null inputs to avoid errors on null values

Expand All @@ -844,21 +851,20 @@ where
{
let mut builder = PrimitiveBuilder::<T>::with_capacity(nulls.len())
.with_data_type(self.return_data_type.clone());
let iter = sums.into_iter().zip(counts).zip(nulls.iter());
let iter = states.into_iter().zip(nulls.iter());

for ((sum, count), is_valid) in iter {
for (state, is_valid) in iter {
if is_valid {
builder.append_value((self.avg_fn)(sum, count)?)
builder.append_value((self.avg_fn)(state.sum, state.count)?)
} else {
builder.append_null();
}
}
builder.finish()
} else {
let averages: Vec<T::Native> = sums
let averages: Vec<T::Native> = states
.into_iter()
.zip(counts.into_iter())
.map(|(sum, count)| (self.avg_fn)(sum, count))
.map(|state| (self.avg_fn)(state.sum, state.count))
.collect::<Result<Vec<_>>>()?;
PrimitiveArray::new(averages.into(), nulls) // no copy
.with_data_type(self.return_data_type.clone())
Expand All @@ -871,11 +877,11 @@ where
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
let nulls = self.null_state.build(emit_to);

let counts = emit_to.take_needed(&mut self.counts);
let counts = UInt64Array::new(counts.into(), nulls.clone()); // zero copy

let sums = emit_to.take_needed(&mut self.sums);
let sums = PrimitiveArray::<T>::new(sums.into(), nulls) // zero copy
let states = emit_to.take_needed(&mut self.states);
let (counts, sums): (Vec<u64>, Vec<T::Native>) =
states.into_iter().map(|s| (s.count, s.sum)).unzip();
let counts = UInt64Array::new(counts.into(), nulls.clone());
let sums = PrimitiveArray::<T>::new(sums.into(), nulls)
.with_data_type(self.sum_data_type.clone());

Ok(vec![
Expand All @@ -895,31 +901,34 @@ where
// first batch is counts, second is partial sums
let partial_counts = values[0].as_primitive::<UInt64Type>();
let partial_sums = values[1].as_primitive::<T>();
// update counts with partial counts
self.counts.resize(total_num_groups, 0);
// single resize for combined state
self.states.resize(
total_num_groups,
AvgState {
count: 0,
sum: T::default_value(),
},
);
self.null_state.accumulate(
group_indices,
partial_counts,
opt_filter,
total_num_groups,
|group_index, partial_count| {
// SAFETY: group_index is guaranteed to be in bounds
let count = unsafe { self.counts.get_unchecked_mut(group_index) };
*count += partial_count;
let state = unsafe { self.states.get_unchecked_mut(group_index) };
state.count += partial_count;
},
);

// update sums
self.sums.resize(total_num_groups, T::default_value());
self.null_state.accumulate(
group_indices,
partial_sums,
opt_filter,
total_num_groups,
|group_index, new_value: <T as ArrowPrimitiveType>::Native| {
// SAFETY: group_index is guaranteed to be in bounds
let sum = unsafe { self.sums.get_unchecked_mut(group_index) };
*sum = sum.add_wrapping(new_value);
let state = unsafe { self.states.get_unchecked_mut(group_index) };
state.sum = state.sum.add_wrapping(new_value);
},
);

Expand Down Expand Up @@ -951,6 +960,6 @@ where
}

fn size(&self) -> usize {
self.counts.capacity() * size_of::<u64>() + self.sums.capacity() * size_of::<T>()
self.states.capacity() * size_of::<AvgState<T::Native>>()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ impl<B: ByteViewType> ByteViewGroupValueBuilder<B> {
Nulls::Some
};

self.views.reserve(rows.len());

match all_null_or_non_null {
Nulls::Some => {
for &row in rows {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ impl<T: ArrowPrimitiveType, const NULLABLE: bool> GroupColumn
Nulls::Some
};

self.group_values.reserve(rows.len());

match (NULLABLE, all_null_or_non_null) {
(true, Nulls::Some) => {
for &row in rows {
Expand Down
Loading