From b999feafb8877c06bbeea3153570119bc9d7117c Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Wed, 26 Nov 2025 11:20:37 +0530 Subject: [PATCH 1/4] feat: Support decimal for variance --- .../functions-aggregate/src/variance.rs | 76 +++++++++++++++++-- .../sqllogictest/test_files/aggregate.slt | 17 +++++ 2 files changed, 88 insertions(+), 5 deletions(-) diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 846c145cb11e..33dcae872fa7 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -29,8 +29,8 @@ use datafusion_common::{downcast_value, not_impl_err, plan_err, Result, ScalarVa use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, utils::format_state_name, - Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature, - Volatility, + Accumulator, AggregateUDFImpl, Coercion, Documentation, GroupsAccumulator, Signature, + TypeSignature, TypeSignatureClass, Volatility, }; use datafusion_functions_aggregate_common::{ aggregate::groups_accumulator::accumulate::accumulate, stats::StatsType, @@ -55,6 +55,29 @@ make_udaf_expr_and_func!( var_pop_udaf ); +fn variance_signature() -> Signature { + Signature::one_of( + vec![ + TypeSignature::Numeric(1), + TypeSignature::Coercible(vec![Coercion::new_exact( + TypeSignatureClass::Decimal, + )]), + ], + Volatility::Immutable, + ) +} + +fn is_numeric_or_decimal(data_type: &DataType) -> bool { + data_type.is_numeric() + || matches!( + data_type, + DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) + ) +} + #[user_doc( doc_section(label = "General Functions"), description = "Returns the statistical sample variance of a set of numbers.", @@ -86,7 +109,7 @@ impl VarianceSample { pub fn new() -> Self { Self { aliases: vec![String::from("var_sample"), String::from("var_samp")], - signature: Signature::numeric(1, Volatility::Immutable), + signature: variance_signature(), } } } @@ -179,7 +202,7 @@ impl VariancePopulation { pub fn new() -> Self { Self { aliases: vec![String::from("var_population")], - signature: Signature::numeric(1, Volatility::Immutable), + signature: variance_signature(), } } } @@ -198,7 +221,7 @@ impl AggregateUDFImpl for VariancePopulation { } fn return_type(&self, arg_types: &[DataType]) -> Result { - if !arg_types[0].is_numeric() { + if !is_numeric_or_decimal(&arg_types[0]) { return plan_err!("Variance requires numeric input types"); } @@ -583,10 +606,53 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { #[cfg(test)] mod tests { + use arrow::array::Decimal128Builder; use datafusion_expr::EmitTo; + use std::sync::Arc; use super::*; + #[test] + fn variance_population_accepts_decimal() -> Result<()> { + let variance = VariancePopulation::new(); + variance.return_type(&[DataType::Decimal128(10, 3)])?; + Ok(()) + } + + #[test] + fn variance_decimal_input() -> Result<()> { + let mut builder = Decimal128Builder::with_capacity(20); + for i in 0..10 { + builder.append_value(110000 + i); + } + for i in 0..10 { + builder.append_value(-((100000 + i) as i128)); + } + let decimal_array = builder.finish().with_precision_and_scale(10, 3).unwrap(); + let array: ArrayRef = Arc::new(decimal_array); + + let mut pop_acc = VarianceAccumulator::try_new(StatsType::Population)?; + let pop_input = [Arc::clone(&array)]; + pop_acc.update_batch(&pop_input)?; + assert_variance(pop_acc.evaluate()?, 11025.9450285); + + let mut sample_acc = VarianceAccumulator::try_new(StatsType::Sample)?; + let sample_input = [array]; + sample_acc.update_batch(&sample_input)?; + assert_variance(sample_acc.evaluate()?, 11606.257924736841); + + Ok(()) + } + + fn assert_variance(value: ScalarValue, expected: f64) { + match value { + ScalarValue::Float64(Some(actual)) => { + assert!((actual - expected).abs() < 1e-9) + } + other => panic!("expected Float64 result, got {other:?}"), + } + } + #[test] fn test_groups_accumulator_merge_empty_states() -> Result<()> { let state_1 = vec![ diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index a1b868b0b028..51e2c4e75c38 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -5629,6 +5629,23 @@ select c2, avg(c1), arrow_typeof(avg(c1)) from d_table GROUP BY c2 ORDER BY c2 A 110.0045 Decimal128(14, 7) B -100.0045 Decimal128(14, 7) +# aggregate_decimal_variance +query RT +select var_pop(c1), arrow_typeof(var_pop(c1)) from d_table +---- +11025.945028500004 Float64 + +query RT +select var(c1), arrow_typeof(var(c1)) from d_table +---- +11606.257924736847 Float64 + +query TRT +select c2, var_pop(c1), arrow_typeof(var_pop(c1)) from d_table GROUP BY c2 ORDER BY c2 +---- +A 0.00000825 Float64 +B 0.00000825 Float64 + # aggregate_decimal_count_distinct query I select count(DISTINCT cast(c1 AS DECIMAL(10, 2))) from d_table From abc67804756ac1b9cb432691f363326446949f5a Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Wed, 26 Nov 2025 14:28:07 +0530 Subject: [PATCH 2/4] native decimal support using accumulator --- .../functions-aggregate/src/variance.rs | 575 +++++++++++++++++- .../sqllogictest/test_files/aggregate.slt | 8 +- 2 files changed, 573 insertions(+), 10 deletions(-) diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 33dcae872fa7..6673e09f2799 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -20,12 +20,23 @@ use arrow::datatypes::FieldRef; use arrow::{ - array::{Array, ArrayRef, BooleanArray, Float64Array, UInt64Array}, + array::{ + Array, ArrayRef, AsArray, BooleanArray, FixedSizeBinaryArray, + FixedSizeBinaryBuilder, Float64Array, Float64Builder, PrimitiveArray, + UInt64Array, UInt64Builder, + }, buffer::NullBuffer, compute::kernels::cast, - datatypes::{DataType, Field}, + datatypes::i256, + datatypes::{ + ArrowNumericType, DataType, Decimal128Type, Decimal256Type, Decimal32Type, + Decimal64Type, DecimalType, Field, DECIMAL256_MAX_SCALE, + }, +}; +use datafusion_common::{ + downcast_value, exec_err, not_impl_err, plan_err, DataFusionError, Result, + ScalarValue, }; -use datafusion_common::{downcast_value, not_impl_err, plan_err, Result, ScalarValue}; use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, utils::format_state_name, @@ -36,8 +47,9 @@ use datafusion_functions_aggregate_common::{ aggregate::groups_accumulator::accumulate::accumulate, stats::StatsType, }; use datafusion_macros::user_doc; +use std::convert::TryInto; use std::mem::{size_of, size_of_val}; -use std::{fmt::Debug, sync::Arc}; +use std::{fmt::Debug, marker::PhantomData, ops::Neg, sync::Arc}; make_udaf_expr_and_func!( VarianceSample, @@ -67,6 +79,61 @@ fn variance_signature() -> Signature { ) } +const DECIMAL_VARIANCE_BINARY_SIZE: i32 = 32; + +fn decimal_overflow_err() -> DataFusionError { + DataFusionError::Execution("Decimal variance overflow".to_string()) +} + +fn i256_to_f64_lossy(value: i256) -> f64 { + const SCALE: f64 = 18446744073709551616.0; // 2^64 + let mut abs = value; + let negative = abs < i256::ZERO; + if negative { + abs = abs.neg(); + } + let bytes = abs.to_le_bytes(); + let mut result = 0f64; + for chunk in bytes.chunks_exact(8).rev() { + let chunk_val = u64::from_le_bytes(chunk.try_into().unwrap()); + result = result * SCALE + chunk_val as f64; + } + if negative { + -result + } else { + result + } +} + +fn decimal_scale(dt: &DataType) -> Option { + match dt { + DataType::Decimal32(_, scale) + | DataType::Decimal64(_, scale) + | DataType::Decimal128(_, scale) + | DataType::Decimal256(_, scale) => Some(*scale), + _ => None, + } +} + +fn decimal_variance_state_fields(name: &str) -> Vec { + vec![ + Field::new(format_state_name(name, "count"), DataType::UInt64, true), + Field::new( + format_state_name(name, "sum"), + DataType::FixedSizeBinary(DECIMAL_VARIANCE_BINARY_SIZE), + true, + ), + Field::new( + format_state_name(name, "sum_squares"), + DataType::FixedSizeBinary(DECIMAL_VARIANCE_BINARY_SIZE), + true, + ), + ] + .into_iter() + .map(Arc::new) + .collect() +} + fn is_numeric_or_decimal(data_type: &DataType) -> bool { data_type.is_numeric() || matches!( @@ -78,6 +145,460 @@ fn is_numeric_or_decimal(data_type: &DataType) -> bool { ) } +fn i256_from_bytes(bytes: &[u8]) -> Result { + if bytes.len() != DECIMAL_VARIANCE_BINARY_LEN { + return exec_err!( + "Decimal variance state expected {} bytes got {}", + DECIMAL_VARIANCE_BINARY_LEN, + bytes.len() + ); + } + let mut buffer = [0u8; DECIMAL_VARIANCE_BINARY_LEN]; + buffer.copy_from_slice(bytes); + Ok(i256::from_le_bytes(buffer)) +} + +const DECIMAL_VARIANCE_BINARY_LEN: usize = DECIMAL_VARIANCE_BINARY_SIZE as usize; + +fn i256_to_scalar(value: i256) -> ScalarValue { + ScalarValue::FixedSizeBinary( + DECIMAL_VARIANCE_BINARY_SIZE, + Some(value.to_le_bytes().to_vec()), + ) +} + +fn create_decimal_variance_accumulator( + data_type: &DataType, + stats_type: StatsType, +) -> Result>> { + let accumulator = match data_type { + DataType::Decimal32(_, scale) => Some(Box::new(DecimalVarianceAccumulator::< + Decimal32Type, + >::try_new( + *scale, stats_type + )?) as Box), + DataType::Decimal64(_, scale) => Some(Box::new(DecimalVarianceAccumulator::< + Decimal64Type, + >::try_new( + *scale, stats_type + )?) as Box), + DataType::Decimal128(_, scale) => Some(Box::new(DecimalVarianceAccumulator::< + Decimal128Type, + >::try_new( + *scale, stats_type + )?) as Box), + DataType::Decimal256(_, scale) => Some(Box::new(DecimalVarianceAccumulator::< + Decimal256Type, + >::try_new( + *scale, stats_type + )?) as Box), + _ => None, + }; + Ok(accumulator) +} + +fn create_decimal_variance_groups_accumulator( + data_type: &DataType, + stats_type: StatsType, +) -> Result>> { + let accumulator = match data_type { + DataType::Decimal32(_, scale) => Some(Box::new( + DecimalVarianceGroupsAccumulator::::new(*scale, stats_type), + ) as Box), + DataType::Decimal64(_, scale) => Some(Box::new( + DecimalVarianceGroupsAccumulator::::new(*scale, stats_type), + ) as Box), + DataType::Decimal128(_, scale) => Some(Box::new( + DecimalVarianceGroupsAccumulator::::new(*scale, stats_type), + ) as Box), + DataType::Decimal256(_, scale) => Some(Box::new( + DecimalVarianceGroupsAccumulator::::new(*scale, stats_type), + ) as Box), + _ => None, + }; + Ok(accumulator) +} + +trait DecimalNative: Copy { + fn to_i256(self) -> i256; +} + +impl DecimalNative for i32 { + fn to_i256(self) -> i256 { + i256::from(self) + } +} + +impl DecimalNative for i64 { + fn to_i256(self) -> i256 { + i256::from(self) + } +} + +impl DecimalNative for i128 { + fn to_i256(self) -> i256 { + i256::from_i128(self) + } +} + +impl DecimalNative for i256 { + fn to_i256(self) -> i256 { + self + } +} + +#[derive(Clone, Debug, Default)] +struct DecimalVarianceState { + count: u64, + sum: i256, + sum_squares: i256, +} + +impl DecimalVarianceState { + fn update(&mut self, value: i256) -> Result<()> { + self.count = self.count.checked_add(1).ok_or_else(decimal_overflow_err)?; + self.sum = self + .sum + .checked_add(value) + .ok_or_else(decimal_overflow_err)?; + let square = value.checked_mul(value).ok_or_else(decimal_overflow_err)?; + self.sum_squares = self + .sum_squares + .checked_add(square) + .ok_or_else(decimal_overflow_err)?; + Ok(()) + } + + fn retract(&mut self, value: i256) -> Result<()> { + if self.count == 0 { + return exec_err!("Decimal variance retract underflow"); + } + self.count -= 1; + self.sum = self + .sum + .checked_sub(value) + .ok_or_else(decimal_overflow_err)?; + let square = value.checked_mul(value).ok_or_else(decimal_overflow_err)?; + self.sum_squares = self + .sum_squares + .checked_sub(square) + .ok_or_else(decimal_overflow_err)?; + Ok(()) + } + + fn merge(&mut self, other: &Self) -> Result<()> { + self.count = self + .count + .checked_add(other.count) + .ok_or_else(decimal_overflow_err)?; + self.sum = self + .sum + .checked_add(other.sum) + .ok_or_else(decimal_overflow_err)?; + self.sum_squares = self + .sum_squares + .checked_add(other.sum_squares) + .ok_or_else(decimal_overflow_err)?; + Ok(()) + } + + fn variance(&self, stats_type: StatsType, scale: i8) -> Result> { + if self.count == 0 { + return Ok(None); + } + if matches!(stats_type, StatsType::Sample) && self.count <= 1 { + return Ok(None); + } + + let count_i256 = i256::from_i128(self.count as i128); + let scaled_sum_squares = self + .sum_squares + .checked_mul(count_i256) + .ok_or_else(decimal_overflow_err)?; + let sum_squared = self + .sum + .checked_mul(self.sum) + .ok_or_else(decimal_overflow_err)?; + let numerator = scaled_sum_squares + .checked_sub(sum_squared) + .ok_or_else(decimal_overflow_err)?; + + let numerator = if numerator < i256::ZERO { + i256::ZERO + } else { + numerator + }; + + let denominator_counts = match stats_type { + StatsType::Population => { + let count = self.count as f64; + count * count + } + StatsType::Sample => { + let count = self.count as f64; + count * ((self.count - 1) as f64) + } + }; + + if denominator_counts == 0.0 { + return Ok(None); + } + + let numerator_f64 = i256_to_f64_lossy(numerator); + let scale_factor = 10f64.powi(2 * scale as i32); + Ok(Some(numerator_f64 / (denominator_counts * scale_factor))) + } + + fn to_scalar_state(&self) -> Vec { + vec![ + ScalarValue::from(self.count), + i256_to_scalar(self.sum), + i256_to_scalar(self.sum_squares), + ] + } +} + +#[derive(Debug)] +struct DecimalVarianceAccumulator +where + T: DecimalType + ArrowNumericType + Debug, + T::Native: DecimalNative, +{ + state: DecimalVarianceState, + scale: i8, + stats_type: StatsType, + _marker: PhantomData, +} + +impl DecimalVarianceAccumulator +where + T: DecimalType + ArrowNumericType + Debug, + T::Native: DecimalNative, +{ + fn try_new(scale: i8, stats_type: StatsType) -> Result { + if scale > DECIMAL256_MAX_SCALE { + return exec_err!( + "Decimal variance does not support scale {} greater than {}", + scale, + DECIMAL256_MAX_SCALE + ); + } + Ok(Self { + state: DecimalVarianceState::default(), + scale, + stats_type, + _marker: PhantomData, + }) + } + + fn convert_array(values: &ArrayRef) -> &PrimitiveArray { + values.as_primitive::() + } +} + +impl Accumulator for DecimalVarianceAccumulator +where + T: DecimalType + ArrowNumericType + Debug, + T::Native: DecimalNative, +{ + fn state(&mut self) -> Result> { + Ok(self.state.to_scalar_state()) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let array = Self::convert_array(&values[0]); + for value in array.iter().flatten() { + self.state.update(value.to_i256())?; + } + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let array = Self::convert_array(&values[0]); + for value in array.iter().flatten() { + self.state.retract(value.to_i256())?; + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let counts = downcast_value!(states[0], UInt64Array); + let sums = downcast_value!(states[1], FixedSizeBinaryArray); + let sum_squares = downcast_value!(states[2], FixedSizeBinaryArray); + + for i in 0..counts.len() { + if counts.is_null(i) { + continue; + } + let count = counts.value(i); + if count == 0 { + continue; + } + let sum = i256_from_bytes(sums.value(i))?; + let sum_sq = i256_from_bytes(sum_squares.value(i))?; + let other = DecimalVarianceState { + count, + sum, + sum_squares: sum_sq, + }; + self.state.merge(&other)?; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + match self.state.variance(self.stats_type, self.scale)? { + Some(v) => Ok(ScalarValue::Float64(Some(v))), + None => Ok(ScalarValue::Float64(None)), + } + } + + fn size(&self) -> usize { + size_of_val(self) + } + + fn supports_retract_batch(&self) -> bool { + true + } +} + +#[derive(Debug)] +struct DecimalVarianceGroupsAccumulator +where + T: DecimalType + ArrowNumericType + Debug, + T::Native: DecimalNative, +{ + states: Vec, + scale: i8, + stats_type: StatsType, + _marker: PhantomData, +} + +impl DecimalVarianceGroupsAccumulator +where + T: DecimalType + ArrowNumericType + Debug, + T::Native: DecimalNative, +{ + fn new(scale: i8, stats_type: StatsType) -> Self { + Self { + states: Vec::new(), + scale, + stats_type, + _marker: PhantomData, + } + } + + fn resize(&mut self, total_num_groups: usize) { + if self.states.len() < total_num_groups { + self.states + .resize(total_num_groups, DecimalVarianceState::default()); + } + } +} + +impl GroupsAccumulator for DecimalVarianceGroupsAccumulator +where + T: DecimalType + ArrowNumericType + Debug, + T::Native: DecimalNative, +{ + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + let array = values[0].as_primitive::(); + self.resize(total_num_groups); + for (row, group_index) in group_indices.iter().enumerate() { + if let Some(filter) = opt_filter { + if !filter.value(row) { + continue; + } + } + if array.is_null(row) { + continue; + } + let value = array.value(row).to_i256(); + self.states[*group_index].update(value)?; + } + Ok(()) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + _opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + let counts = downcast_value!(values[0], UInt64Array); + let sums = downcast_value!(values[1], FixedSizeBinaryArray); + let sum_squares = downcast_value!(values[2], FixedSizeBinaryArray); + self.resize(total_num_groups); + + for (row, group_index) in group_indices.iter().enumerate() { + if counts.is_null(row) { + continue; + } + let count = counts.value(row); + if count == 0 { + continue; + } + let sum = i256_from_bytes(sums.value(row))?; + let sum_sq = i256_from_bytes(sum_squares.value(row))?; + let other = DecimalVarianceState { + count, + sum, + sum_squares: sum_sq, + }; + self.states[*group_index].merge(&other)?; + } + Ok(()) + } + + fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result { + let states = emit_to.take_needed(&mut self.states); + let mut builder = Float64Builder::with_capacity(states.len()); + for state in &states { + match state.variance(self.stats_type, self.scale)? { + Some(value) => builder.append_value(value), + None => builder.append_null(), + } + } + Ok(Arc::new(builder.finish())) + } + + fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result> { + let states = emit_to.take_needed(&mut self.states); + let mut counts = UInt64Builder::with_capacity(states.len()); + let mut sums = FixedSizeBinaryBuilder::with_capacity( + states.len(), + DECIMAL_VARIANCE_BINARY_SIZE, + ); + let mut sum_squares = FixedSizeBinaryBuilder::with_capacity( + states.len(), + DECIMAL_VARIANCE_BINARY_SIZE, + ); + + for state in states { + counts.append_value(state.count); + sums.append_value(state.sum.to_le_bytes())?; + sum_squares.append_value(state.sum_squares.to_le_bytes())?; + } + + Ok(vec![ + Arc::new(counts.finish()), + Arc::new(sums.finish()), + Arc::new(sum_squares.finish()), + ]) + } + + fn size(&self) -> usize { + self.states.capacity() * size_of::() + } +} + #[user_doc( doc_section(label = "General Functions"), description = "Returns the statistical sample variance of a set of numbers.", @@ -133,6 +654,14 @@ impl AggregateUDFImpl for VarianceSample { fn state_fields(&self, args: StateFieldsArgs) -> Result> { let name = args.name; + if args + .input_fields + .first() + .and_then(|field| decimal_scale(field.data_type())) + .is_some() + { + return Ok(decimal_variance_state_fields(name)); + } Ok(vec![ Field::new(format_state_name(name, "count"), DataType::UInt64, true), Field::new(format_state_name(name, "mean"), DataType::Float64, true), @@ -148,6 +677,13 @@ impl AggregateUDFImpl for VarianceSample { return not_impl_err!("VAR(DISTINCT) aggregations are not available"); } + if let Some(acc) = create_decimal_variance_accumulator( + acc_args.expr_fields[0].data_type(), + StatsType::Sample, + )? { + return Ok(acc); + } + Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?)) } @@ -161,8 +697,14 @@ impl AggregateUDFImpl for VarianceSample { fn create_groups_accumulator( &self, - _args: AccumulatorArgs, + args: AccumulatorArgs, ) -> Result> { + if let Some(acc) = create_decimal_variance_groups_accumulator( + args.expr_fields[0].data_type(), + StatsType::Sample, + )? { + return Ok(acc); + } Ok(Box::new(VarianceGroupsAccumulator::new(StatsType::Sample))) } @@ -230,6 +772,14 @@ impl AggregateUDFImpl for VariancePopulation { fn state_fields(&self, args: StateFieldsArgs) -> Result> { let name = args.name; + if args + .input_fields + .first() + .and_then(|field| decimal_scale(field.data_type())) + .is_some() + { + return Ok(decimal_variance_state_fields(name)); + } Ok(vec![ Field::new(format_state_name(name, "count"), DataType::UInt64, true), Field::new(format_state_name(name, "mean"), DataType::Float64, true), @@ -245,6 +795,13 @@ impl AggregateUDFImpl for VariancePopulation { return not_impl_err!("VAR_POP(DISTINCT) aggregations are not available"); } + if let Some(acc) = create_decimal_variance_accumulator( + acc_args.expr_fields[0].data_type(), + StatsType::Population, + )? { + return Ok(acc); + } + Ok(Box::new(VarianceAccumulator::try_new( StatsType::Population, )?)) @@ -260,8 +817,14 @@ impl AggregateUDFImpl for VariancePopulation { fn create_groups_accumulator( &self, - _args: AccumulatorArgs, + args: AccumulatorArgs, ) -> Result> { + if let Some(acc) = create_decimal_variance_groups_accumulator( + args.expr_fields[0].data_type(), + StatsType::Population, + )? { + return Ok(acc); + } Ok(Box::new(VarianceGroupsAccumulator::new( StatsType::Population, ))) diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 51e2c4e75c38..a669f26fbc30 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -5633,18 +5633,18 @@ B -100.0045 Decimal128(14, 7) query RT select var_pop(c1), arrow_typeof(var_pop(c1)) from d_table ---- -11025.945028500004 Float64 +11025.9450285 Float64 query RT select var(c1), arrow_typeof(var(c1)) from d_table ---- -11606.257924736847 Float64 +11606.257924736841 Float64 query TRT select c2, var_pop(c1), arrow_typeof(var_pop(c1)) from d_table GROUP BY c2 ORDER BY c2 ---- -A 0.00000825 Float64 -B 0.00000825 Float64 +A 0.000008249999999997783 Float64 +B 0.000008249999999997783 Float64 # aggregate_decimal_count_distinct query I From 5b92361c75e4f185e8860f09a17895575bae7b39 Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Wed, 26 Nov 2025 15:34:22 +0530 Subject: [PATCH 3/4] fixed aggregate test --- datafusion/sqllogictest/test_files/aggregate.slt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index a669f26fbc30..b248b86e533f 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -5643,8 +5643,8 @@ select var(c1), arrow_typeof(var(c1)) from d_table query TRT select c2, var_pop(c1), arrow_typeof(var_pop(c1)) from d_table GROUP BY c2 ORDER BY c2 ---- -A 0.000008249999999997783 Float64 -B 0.000008249999999997783 Float64 +A 0.00000825 Float64 +B 0.00000825 Float64 # aggregate_decimal_count_distinct query I From 18af96445979d3e8384c6dfbfd4ebff6016b3126 Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Thu, 27 Nov 2025 10:41:40 +0530 Subject: [PATCH 4/4] fixed incorrect tests and handled edge cases --- .../functions-aggregate/src/variance.rs | 165 ++++++++++++++++-- 1 file changed, 153 insertions(+), 12 deletions(-) diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 6673e09f2799..654a6625fc5e 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -203,16 +203,24 @@ fn create_decimal_variance_groups_accumulator( ) -> Result>> { let accumulator = match data_type { DataType::Decimal32(_, scale) => Some(Box::new( - DecimalVarianceGroupsAccumulator::::new(*scale, stats_type), + DecimalVarianceGroupsAccumulator::::try_new( + *scale, stats_type, + )?, ) as Box), DataType::Decimal64(_, scale) => Some(Box::new( - DecimalVarianceGroupsAccumulator::::new(*scale, stats_type), + DecimalVarianceGroupsAccumulator::::try_new( + *scale, stats_type, + )?, ) as Box), DataType::Decimal128(_, scale) => Some(Box::new( - DecimalVarianceGroupsAccumulator::::new(*scale, stats_type), + DecimalVarianceGroupsAccumulator::::try_new( + *scale, stats_type, + )?, ) as Box), DataType::Decimal256(_, scale) => Some(Box::new( - DecimalVarianceGroupsAccumulator::::new(*scale, stats_type), + DecimalVarianceGroupsAccumulator::::try_new( + *scale, stats_type, + )?, ) as Box), _ => None, }; @@ -323,7 +331,12 @@ impl DecimalVarianceState { .checked_sub(sum_squared) .ok_or_else(decimal_overflow_err)?; - let numerator = if numerator < i256::ZERO { + let negative_numerator = numerator < i256::ZERO; + debug_assert!( + !negative_numerator, + "Decimal variance numerator became negative: {numerator:?}. This indicates precision loss or overflow in intermediate calculations." + ); + let numerator = if negative_numerator { i256::ZERO } else { numerator @@ -479,13 +492,20 @@ where T: DecimalType + ArrowNumericType + Debug, T::Native: DecimalNative, { - fn new(scale: i8, stats_type: StatsType) -> Self { - Self { + fn try_new(scale: i8, stats_type: StatsType) -> Result { + if scale > DECIMAL256_MAX_SCALE { + return exec_err!( + "Decimal variance does not support scale {} greater than {}", + scale, + DECIMAL256_MAX_SCALE + ); + } + Ok(Self { states: Vec::new(), scale, stats_type, _marker: PhantomData, - } + }) } fn resize(&mut self, total_num_groups: usize) { @@ -512,7 +532,7 @@ where self.resize(total_num_groups); for (row, group_index) in group_indices.iter().enumerate() { if let Some(filter) = opt_filter { - if !filter.value(row) { + if !filter.is_valid(row) || !filter.value(row) { continue; } } @@ -1169,7 +1189,8 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { #[cfg(test)] mod tests { - use arrow::array::Decimal128Builder; + use arrow::array::{Decimal128Array, Decimal128Builder, Float64Array}; + use arrow::datatypes::DECIMAL256_MAX_PRECISION; use datafusion_expr::EmitTo; use std::sync::Arc; @@ -1194,12 +1215,16 @@ mod tests { let decimal_array = builder.finish().with_precision_and_scale(10, 3).unwrap(); let array: ArrayRef = Arc::new(decimal_array); - let mut pop_acc = VarianceAccumulator::try_new(StatsType::Population)?; + let mut pop_acc = DecimalVarianceAccumulator::::try_new( + 3, + StatsType::Population, + )?; let pop_input = [Arc::clone(&array)]; pop_acc.update_batch(&pop_input)?; assert_variance(pop_acc.evaluate()?, 11025.9450285); - let mut sample_acc = VarianceAccumulator::try_new(StatsType::Sample)?; + let mut sample_acc = + DecimalVarianceAccumulator::::try_new(3, StatsType::Sample)?; let sample_input = [array]; sample_acc.update_batch(&sample_input)?; assert_variance(sample_acc.evaluate()?, 11606.257924736841); @@ -1207,6 +1232,122 @@ mod tests { Ok(()) } + #[test] + fn variance_decimal_handles_nulls() -> Result<()> { + let mut builder = Decimal128Builder::with_capacity(3); + builder.append_value(100); + builder.append_null(); + builder.append_value(300); + let array = builder.finish().with_precision_and_scale(10, 2).unwrap(); + let array: ArrayRef = Arc::new(array); + + let mut acc = DecimalVarianceAccumulator::::try_new( + 2, + StatsType::Population, + )?; + acc.update_batch(&[Arc::clone(&array)])?; + assert_variance(acc.evaluate()?, 1.0); + Ok(()) + } + + #[test] + fn variance_decimal_empty_input() -> Result<()> { + let array = Decimal128Array::from(Vec::>::new()) + .with_precision_and_scale(10, 2) + .unwrap(); + let array: ArrayRef = Arc::new(array); + + let mut acc = DecimalVarianceAccumulator::::try_new( + 2, + StatsType::Population, + )?; + acc.update_batch(&[array])?; + match acc.evaluate()? { + ScalarValue::Float64(None) => Ok(()), + other => panic!("expected NULL variance for empty input, got {other:?}"), + } + } + + #[test] + fn variance_decimal_single_value_sample() -> Result<()> { + let array = Decimal128Array::from(vec![Some(500)]) + .with_precision_and_scale(10, 2) + .unwrap(); + let array: ArrayRef = Arc::new(array); + let mut acc = + DecimalVarianceAccumulator::::try_new(2, StatsType::Sample)?; + acc.update_batch(&[array])?; + match acc.evaluate()? { + ScalarValue::Float64(None) => Ok(()), + other => { + panic!("expected NULL sample variance for single value, got {other:?}") + } + } + } + + #[test] + fn variance_decimal_groups_mixed_values() -> Result<()> { + let array = + Decimal128Array::from(vec![Some(100), Some(300), Some(-200), Some(-400)]) + .with_precision_and_scale(10, 2) + .unwrap(); + let array: ArrayRef = Arc::new(array); + let mut groups = DecimalVarianceGroupsAccumulator::::try_new( + 2, + StatsType::Population, + )?; + let group_indices = vec![0, 0, 1, 1]; + groups.update_batch(&[Arc::clone(&array)], &group_indices, None, 2)?; + let result = groups.evaluate(EmitTo::All)?; + let result = result.as_any().downcast_ref::().unwrap(); + assert!((result.value(0) - 1.0).abs() < 1e-9); + assert!((result.value(1) - 1.0).abs() < 1e-9); + Ok(()) + } + + #[test] + fn variance_decimal_max_scale() -> Result<()> { + let values = vec![ + ScalarValue::Decimal256( + Some(i256::from_i128(1)), + DECIMAL256_MAX_PRECISION, + DECIMAL256_MAX_SCALE, + ), + ScalarValue::Decimal256( + Some(i256::from_i128(-1)), + DECIMAL256_MAX_PRECISION, + DECIMAL256_MAX_SCALE, + ), + ]; + let array = ScalarValue::iter_to_array(values).unwrap(); + let mut acc = DecimalVarianceAccumulator::::try_new( + DECIMAL256_MAX_SCALE, + StatsType::Population, + )?; + acc.update_batch(&[array])?; + assert_variance(acc.evaluate()?, 1e-152); + Ok(()) + } + + #[test] + fn variance_decimal_retract_batch() -> Result<()> { + let update = Decimal128Array::from(vec![Some(100), Some(200), Some(300)]) + .with_precision_and_scale(10, 2) + .unwrap(); + let retract = Decimal128Array::from(vec![Some(100), Some(200)]) + .with_precision_and_scale(10, 2) + .unwrap(); + + let mut acc = DecimalVarianceAccumulator::::try_new( + 2, + StatsType::Population, + )?; + acc.update_batch(&[Arc::new(update)])?; + acc.retract_batch(&[Arc::new(retract)])?; + assert_variance(acc.evaluate()?, 0.0); + Ok(()) + } + fn assert_variance(value: ScalarValue, expected: f64) { match value { ScalarValue::Float64(Some(actual)) => {