diff --git a/datafusion/functions-aggregate/src/percentile_cont.rs b/datafusion/functions-aggregate/src/percentile_cont.rs
index b46186bdfcab..1ea896d4b36f 100644
--- a/datafusion/functions-aggregate/src/percentile_cont.rs
+++ b/datafusion/functions-aggregate/src/percentile_cont.rs
@@ -33,14 +33,18 @@ use arrow::{
 use arrow::array::ArrowNativeTypeOp;
 
+use crate::min_max::{max_udaf, min_udaf};
 use datafusion_common::{
-    assert_eq_or_internal_err, internal_datafusion_err, plan_err, DataFusionError,
-    Result, ScalarValue,
+    assert_eq_or_internal_err, internal_datafusion_err, plan_err,
+    utils::take_function_args, DataFusionError, Result, ScalarValue,
 };
-use datafusion_expr::expr::{AggregateFunction, Sort};
-use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
 use datafusion_expr::type_coercion::aggregates::NUMERICS;
 use datafusion_expr::utils::format_state_name;
+use datafusion_expr::{
+    expr::{AggregateFunction, Cast, Sort},
+    function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs},
+    simplify::SimplifyInfo,
+};
 use datafusion_expr::{
     Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature,
     Volatility,
@@ -358,6 +362,12 @@ impl AggregateUDFImpl for PercentileCont {
         }
     }
 
+    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
+        Some(Box::new(|aggregate_function, info| {
+            simplify_percentile_cont_aggregate(aggregate_function, info)
+        }))
+    }
+
     fn supports_within_group_clause(&self) -> bool {
         true
     }
@@ -367,6 +377,88 @@
     }
 }
 
+#[derive(Clone, Copy)]
+enum PercentileRewriteTarget {
+    Min,
+    Max,
+}
+
+#[expect(clippy::needless_pass_by_value)]
+fn simplify_percentile_cont_aggregate(
+    aggregate_function: AggregateFunction,
+    info: &dyn SimplifyInfo,
+) -> Result<Expr> {
+    let original_expr = Expr::AggregateFunction(aggregate_function.clone());
+    let params = &aggregate_function.params;
+
+    let [value, percentile] = take_function_args("percentile_cont", &params.args)?;
+
+    let is_descending = params
+        .order_by
+        .first()
+        .map(|sort| !sort.asc)
+        .unwrap_or(false);
+
+    let rewrite_target = match extract_percentile_literal(percentile) {
+        Some(0.0) => {
+            if is_descending {
+                PercentileRewriteTarget::Max
+            } else {
+                PercentileRewriteTarget::Min
+            }
+        }
+        Some(1.0) => {
+            if is_descending {
+                PercentileRewriteTarget::Min
+            } else {
+                PercentileRewriteTarget::Max
+            }
+        }
+        _ => return Ok(original_expr),
+    };
+
+    let input_type = match info.get_data_type(value) {
+        Ok(data_type) => data_type,
+        Err(_) => return Ok(original_expr),
+    };
+
+    let expected_return_type =
+        match percentile_cont_udaf().return_type(std::slice::from_ref(&input_type)) {
+            Ok(data_type) => data_type,
+            Err(_) => return Ok(original_expr),
+        };
+
+    let mut agg_arg = value.clone();
+    if expected_return_type != input_type {
+        // min/max return the same type as their input. percentile_cont widens
+        // integers to Float64 (and preserves float/decimal types), so ensure the
+        // rewritten aggregate sees an input of the final return type.
+        agg_arg = Expr::Cast(Cast::new(Box::new(agg_arg), expected_return_type.clone()));
+    }
+
+    let udaf = match rewrite_target {
+        PercentileRewriteTarget::Min => min_udaf(),
+        PercentileRewriteTarget::Max => max_udaf(),
+    };
+
+    let rewritten = Expr::AggregateFunction(AggregateFunction::new_udf(
+        udaf,
+        vec![agg_arg],
+        params.distinct,
+        params.filter.clone(),
+        vec![],
+        params.null_treatment,
+    ));
+    Ok(rewritten)
+}
+
+fn extract_percentile_literal(expr: &Expr) -> Option<f64> {
+    match expr {
+        Expr::Literal(ScalarValue::Float64(Some(value)), _) => Some(*value),
+        _ => None,
+    }
+}
+
 /// The percentile_cont accumulator accumulates the raw input values
 /// as native types.
 ///
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index e81bfb72a0ef..22d99108be84 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -3488,6 +3488,59 @@ SELECT percentile_cont(1.0) WITHIN GROUP (ORDER BY c2) FROM aggregate_test_100
 ----
 5
 
+# Ensure percentile_cont simplification rewrites to min/max plans
+query TT
+EXPLAIN SELECT percentile_cont(0.0) WITHIN GROUP (ORDER BY c2) FROM aggregate_test_100;
+----
+logical_plan
+01)Aggregate: groupBy=[[]], aggr=[[min(CAST(aggregate_test_100.c2 AS Float64)) AS percentile_cont(Float64(0)) WITHIN GROUP [aggregate_test_100.c2 ASC NULLS LAST]]]
+02)--TableScan: aggregate_test_100 projection=[c2]
+physical_plan
+01)AggregateExec: mode=Final, gby=[], aggr=[percentile_cont(Float64(0)) WITHIN GROUP [aggregate_test_100.c2 ASC NULLS LAST]]
+02)--CoalescePartitionsExec
+03)----AggregateExec: mode=Partial, gby=[], aggr=[percentile_cont(Float64(0)) WITHIN GROUP [aggregate_test_100.c2 ASC NULLS LAST]]
+04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[c2], file_type=csv, has_header=true
+
+query TT
+EXPLAIN SELECT percentile_cont(0.0) WITHIN GROUP (ORDER BY c2 DESC) FROM aggregate_test_100;
+----
+logical_plan
+01)Aggregate: groupBy=[[]], aggr=[[max(CAST(aggregate_test_100.c2 AS Float64)) AS percentile_cont(Float64(0)) WITHIN GROUP [aggregate_test_100.c2 DESC NULLS FIRST]]]
+02)--TableScan: aggregate_test_100 projection=[c2]
+physical_plan
+01)AggregateExec: mode=Final, gby=[], aggr=[percentile_cont(Float64(0)) WITHIN GROUP [aggregate_test_100.c2 DESC NULLS FIRST]]
+02)--CoalescePartitionsExec
+03)----AggregateExec: mode=Partial, gby=[], aggr=[percentile_cont(Float64(0)) WITHIN GROUP [aggregate_test_100.c2 DESC NULLS FIRST]]
+04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[c2], file_type=csv, has_header=true
+
+query TT
+EXPLAIN SELECT percentile_cont(c2, 0.0) FROM aggregate_test_100;
+----
+logical_plan
+01)Aggregate: groupBy=[[]], aggr=[[min(CAST(aggregate_test_100.c2 AS Float64)) AS percentile_cont(aggregate_test_100.c2,Float64(0))]]
+02)--TableScan: aggregate_test_100 projection=[c2]
+physical_plan
+01)AggregateExec: mode=Final, gby=[], aggr=[percentile_cont(aggregate_test_100.c2,Float64(0))]
+02)--CoalescePartitionsExec
+03)----AggregateExec: mode=Partial, gby=[], aggr=[percentile_cont(aggregate_test_100.c2,Float64(0))]
+04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[c2], file_type=csv, has_header=true
+
+query TT
+EXPLAIN SELECT percentile_cont(c2, 1.0) FROM aggregate_test_100;
+----
+logical_plan
+01)Aggregate: groupBy=[[]], aggr=[[max(CAST(aggregate_test_100.c2 AS Float64)) AS percentile_cont(aggregate_test_100.c2,Float64(1))]]
+02)--TableScan: aggregate_test_100 projection=[c2]
+physical_plan
+01)AggregateExec: mode=Final, gby=[], aggr=[percentile_cont(aggregate_test_100.c2,Float64(1))]
+02)--CoalescePartitionsExec
+03)----AggregateExec: mode=Partial, gby=[], aggr=[percentile_cont(aggregate_test_100.c2,Float64(1))]
+04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[c2], file_type=csv, has_header=true
+
 query R
 SELECT percentile_cont(0.25) WITHIN GROUP (ORDER BY c2) FROM aggregate_test_100
 ----
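
# Reviewer note (not part of the patch): the rewrite is observable from
# DataFusion's public API. Below is a minimal Rust sketch, assuming the
# `datafusion` and `tokio` crates; the table `t` and its contents are made up
# for illustration, and the exact plan text may differ across versions.
#
#     use datafusion::error::Result;
#     use datafusion::prelude::*;
#
#     #[tokio::main]
#     async fn main() -> Result<()> {
#         let ctx = SessionContext::new();
#         // An integer column, so the rewrite must insert a CAST to Float64
#         // to match percentile_cont's return type.
#         ctx.sql("CREATE TABLE t AS VALUES (1), (2), (3)").await?;
#
#         // With the simplification in place, the logical plan should show
#         // min(CAST(column1 AS Float64)) instead of a percentile_cont aggregate.
#         ctx.sql("EXPLAIN SELECT percentile_cont(column1, 0.0) FROM t")
#             .await?
#             .show()
#             .await?;
#         Ok(())
#     }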