diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index 958553d78ca5..aad5d744d3e4 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -108,6 +108,18 @@ macro_rules! downcast_sum { }; } +/// Properties for [`Sum`] +#[derive(Default, Debug, PartialEq, Eq, Hash)] +pub struct SumProperties { + /// Whether to maintain the precision of decimal types + /// If `false`, the new precision of this [`Sum`] will be calculated as + /// `MIN(max_precision, precision + 10)` (similar to Spark). + /// If `true`, the new precision will be the same as the precision of the first argument. + /// + /// The default value is `false`. + pub maintains_decimal_precision: bool, +} + #[user_doc( doc_section(label = "General Functions"), description = "Returns the sum of all values in the specified column.", @@ -125,12 +137,21 @@ macro_rules! downcast_sum { #[derive(Debug, PartialEq, Eq, Hash)] pub struct Sum { signature: Signature, + properties: SumProperties, } impl Sum { pub fn new() -> Self { Self { signature: Signature::user_defined(Volatility::Immutable), + properties: SumProperties::default(), + } + } + + pub fn new_with_properties(properties: SumProperties) -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + properties, } } } @@ -180,34 +201,46 @@ impl AggregateUDFImpl for Sum { } fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { + macro_rules! 
new_with_precision { + ($dt:expr,$max:expr,$precision:expr,$scale:expr) => { + if self.properties.maintains_decimal_precision { + $dt($precision, $scale) + } else { + // In Spark, the resulting decimal precision is bounded + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + let new_precision = $max.min($precision + 10); + $dt(new_precision, $scale) + } + }; + } match &arg_types[0] { DataType::Int64 => Ok(DataType::Int64), DataType::UInt64 => Ok(DataType::UInt64), DataType::Float64 => Ok(DataType::Float64), - DataType::Decimal32(precision, scale) => { - // in the spark, the result type is DECIMAL(min(38,precision+10), s) - // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 - let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10); - Ok(DataType::Decimal32(new_precision, *scale)) - } - DataType::Decimal64(precision, scale) => { - // in the spark, the result type is DECIMAL(min(38,precision+10), s) - // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 - let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10); - Ok(DataType::Decimal64(new_precision, *scale)) - } - DataType::Decimal128(precision, scale) => { - // in the spark, the result type is DECIMAL(min(38,precision+10), s) - // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 - let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10); - Ok(DataType::Decimal128(new_precision, *scale)) - } - DataType::Decimal256(precision, scale) => { - // in the spark, the result type is 
DECIMAL(min(38,precision+10), s) - // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 - let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10); - Ok(DataType::Decimal256(new_precision, *scale)) - } + DataType::Decimal32(precision, scale) => Ok(new_with_precision!( + DataType::Decimal32, + DECIMAL32_MAX_PRECISION, + *precision, + *scale + )), + DataType::Decimal64(precision, scale) => Ok(new_with_precision!( + DataType::Decimal64, + DECIMAL64_MAX_PRECISION, + *precision, + *scale + )), + DataType::Decimal128(precision, scale) => Ok(new_with_precision!( + DataType::Decimal128, + DECIMAL128_MAX_PRECISION, + *precision, + *scale + )), + DataType::Decimal256(precision, scale) => Ok(new_with_precision!( + DataType::Decimal256, + DECIMAL256_MAX_PRECISION, + *precision, + *scale + )), other => { exec_err!("[return_type] SUM not supported for {}", other) } diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index f10510e0973c..fb92fd87e2cb 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -46,6 +46,7 @@ chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } datafusion-expr-common = { workspace = true } +datafusion-functions-aggregate = { workspace = true } datafusion-physical-expr = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true } diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index e9a23c7c4dc5..354e63914cbf 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -27,12 +27,14 @@ use datafusion_common::{ }; use datafusion_expr::builder::project; use datafusion_expr::expr::AggregateFunctionParams; +use 
datafusion_expr::AggregateUDF; use datafusion_expr::{ col, expr::AggregateFunction, logical_plan::{Aggregate, LogicalPlan}, Expr, }; +use datafusion_functions_aggregate::sum::{Sum, SumProperties}; /// single distinct to group by optimizer rule /// ```text @@ -219,7 +221,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { .alias(&alias_str), ); Ok(Expr::AggregateFunction(AggregateFunction::new_udf( - func, + rewrite_outer_aggregate_func(func), vec![col(&alias_str)], false, None, @@ -277,13 +279,34 @@ impl OptimizerRule for SingleDistinctToGroupBy { } } +/// Rewrite the outer aggregate functions that may require special handling +/// when duplicated to accommodate two-phase aggregation. +fn rewrite_outer_aggregate_func(func: Arc<AggregateUDF>) -> Arc<AggregateUDF> { + let inner = func.inner(); + + if inner.as_any().is::<Sum>() { + // For SUM, we should maintain the precision from the initial aggregation. + // There should be no precision expansion in the second phase. + return Arc::new(AggregateUDF::new_from_impl(Sum::new_with_properties( + SumProperties { + maintains_decimal_precision: true, + }, + ))); + } + + func +} + #[cfg(test)] mod tests { use super::*; use crate::assert_optimized_plan_eq_display_indent_snapshot; use crate::test::*; + use arrow::datatypes::{DataType, Field, Schema}; use datafusion_expr::expr::GroupingSet; + use datafusion_expr::table_scan; use datafusion_expr::ExprFunctionExt; + use datafusion_expr::LogicalPlanBuilderOptions; use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::expr_fn::{count, count_distinct, max, min, sum}; @@ -719,6 +742,36 @@ mod tests { ) } + #[test] + fn sum_maintains_decimal_precision() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("o_orderkey", DataType::Int32, false), + Field::new("o_totalprice", DataType::Decimal128(15, 2), false), + ]); + + let table_scan = table_scan(Some("test"), &schema, None)?.build()?; + let builder = 
LogicalPlanBuilder::from(table_scan).with_options( + LogicalPlanBuilderOptions::new().with_add_implicit_group_by_exprs(true), + ); + + let plan = builder + .aggregate( + Vec::<Expr>::new(), + vec![sum(col("o_totalprice")), count_distinct(col("o_orderkey"))], + )? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Projection: sum(alias2) AS sum(test.o_totalprice), count(alias1) AS count(DISTINCT test.o_orderkey) [sum(test.o_totalprice):Decimal128(25, 2);N, count(DISTINCT test.o_orderkey):Int64] + Aggregate: groupBy=[[]], aggr=[[sum(alias2), count(alias1)]] [sum(alias2):Decimal128(25, 2);N, count(alias1):Int64] + Aggregate: groupBy=[[test.o_orderkey AS alias1]], aggr=[[sum(test.o_totalprice) AS alias2]] [alias1:Int32, alias2:Decimal128(25, 2);N] + TableScan: test [o_orderkey:Int32, o_totalprice:Decimal128(15, 2)] + " + ) + } + #[test] fn aggregate_with_filter_and_order_by() -> Result<()> { let table_scan = test_table_scan()?; diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index a5973afc0a93..8d6f156fed04 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -7931,3 +7931,31 @@ NULL NULL NULL NULL statement ok drop table distinct_avg; + + +# Regression test for https://github.com/apache/datafusion/issues/17699 + +statement ok +CREATE TABLE orders ( + o_orderkey INT, + o_totalprice DECIMAL(15, 2) +); + +statement ok +INSERT INTO orders VALUES (1, 10.00); + +query R +SELECT total_spent +FROM ( + SELECT + SUM(o_totalprice) AS total_spent, + COUNT(DISTINCT o_orderkey) AS order_count + FROM orders +) t +WHERE total_spent > 0; +---- +10 + + +statement ok +DROP TABLE orders;