From 89f2213f16b11bacd08545da9cf0469d8ef42a2b Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Thu, 20 Nov 2025 12:38:19 +0530 Subject: [PATCH 1/8] Simplify percentile_cont for 0/1 percentiles --- .../src/percentile_cont.rs | 235 +++++++++++++++++- 1 file changed, 233 insertions(+), 2 deletions(-) diff --git a/datafusion/functions-aggregate/src/percentile_cont.rs b/datafusion/functions-aggregate/src/percentile_cont.rs index b46186bdfcab..6a696f06abde 100644 --- a/datafusion/functions-aggregate/src/percentile_cont.rs +++ b/datafusion/functions-aggregate/src/percentile_cont.rs @@ -33,14 +33,18 @@ use arrow::{ use arrow::array::ArrowNativeTypeOp; +use crate::min_max::{max_udaf, min_udaf}; use datafusion_common::{ assert_eq_or_internal_err, internal_datafusion_err, plan_err, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::expr::{AggregateFunction, Sort}; -use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::utils::format_state_name; +use datafusion_expr::{ + expr::{AggregateFunction, Cast, Sort}, + function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs}, + simplify::SimplifyInfo, +}; use datafusion_expr::{ Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature, Volatility, @@ -358,6 +362,12 @@ impl AggregateUDFImpl for PercentileCont { } } + fn simplify(&self) -> Option { + Some(Box::new(|aggregate_function, info| { + simplify_percentile_cont_aggregate(aggregate_function, info) + })) + } + fn supports_within_group_clause(&self) -> bool { true } @@ -367,6 +377,150 @@ impl AggregateUDFImpl for PercentileCont { } } +const PERCENTILE_LITERAL_EPSILON: f64 = 1e-12; + +#[derive(Clone, Copy)] +enum PercentileRewriteTarget { + Min, + Max, +} + +#[allow(clippy::needless_pass_by_value)] +fn simplify_percentile_cont_aggregate( + aggregate_function: AggregateFunction, + info: &dyn SimplifyInfo, +) -> Result { + let original_expr = 
Expr::AggregateFunction(aggregate_function.clone()); + let params = &aggregate_function.params; + + if params.args.len() != 2 { + return Ok(original_expr); + } + + let percentile_value = match extract_percentile_literal(¶ms.args[1]) { + Some(value) => value, + None => return Ok(original_expr), + }; + + let is_descending = params + .order_by + .first() + .map(|sort| !sort.asc) + .unwrap_or(false); + + let rewrite_target = match classify_rewrite_target(percentile_value, is_descending) { + Some(target) => target, + None => return Ok(original_expr), + }; + + let value_expr = params.args[0].clone(); + let input_type = match info.get_data_type(&value_expr) { + Ok(data_type) => data_type, + Err(_) => return Ok(original_expr), + }; + + let expected_return_type = match percentile_cont_result_type(&input_type) { + Some(data_type) => data_type, + None => return Ok(original_expr), + }; + + let udaf = match rewrite_target { + PercentileRewriteTarget::Min => min_udaf(), + PercentileRewriteTarget::Max => max_udaf(), + }; + + let mut rewritten = Expr::AggregateFunction(AggregateFunction::new_udf( + udaf, + vec![value_expr], + params.distinct, + params.filter.clone(), + vec![], + params.null_treatment, + )); + + if expected_return_type != input_type { + rewritten = Expr::Cast(Cast::new(Box::new(rewritten), expected_return_type)); + } + + Ok(rewritten) +} + +fn classify_rewrite_target( + percentile_value: f64, + is_descending: bool, +) -> Option { + if nearly_equals_fraction(percentile_value, 0.0) { + Some(if is_descending { + PercentileRewriteTarget::Max + } else { + PercentileRewriteTarget::Min + }) + } else if nearly_equals_fraction(percentile_value, 1.0) { + Some(if is_descending { + PercentileRewriteTarget::Min + } else { + PercentileRewriteTarget::Max + }) + } else { + None + } +} + +fn nearly_equals_fraction(value: f64, target: f64) -> bool { + (value - target).abs() < PERCENTILE_LITERAL_EPSILON +} + +fn percentile_cont_result_type(input_type: &DataType) -> Option { + if 
!input_type.is_numeric() { + return None; + } + + let result_type = match input_type { + DataType::Float16 | DataType::Float32 | DataType::Float64 => input_type.clone(), + DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) => input_type.clone(), + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 => DataType::Float64, + _ => return None, + }; + + Some(result_type) +} + +fn extract_percentile_literal(expr: &Expr) -> Option { + match expr { + Expr::Literal(value, _) => literal_scalar_to_f64(value), + Expr::Alias(alias) => extract_percentile_literal(alias.expr.as_ref()), + Expr::Cast(cast) => extract_percentile_literal(cast.expr.as_ref()), + Expr::TryCast(cast) => extract_percentile_literal(cast.expr.as_ref()), + _ => None, + } +} + +fn literal_scalar_to_f64(value: &ScalarValue) -> Option { + match value { + ScalarValue::Float64(Some(v)) => Some(*v), + ScalarValue::Float32(Some(v)) => Some(*v as f64), + ScalarValue::Int64(Some(v)) => Some(*v as f64), + ScalarValue::Int32(Some(v)) => Some(*v as f64), + ScalarValue::Int16(Some(v)) => Some(*v as f64), + ScalarValue::Int8(Some(v)) => Some(*v as f64), + ScalarValue::UInt64(Some(v)) => Some(*v as f64), + ScalarValue::UInt32(Some(v)) => Some(*v as f64), + ScalarValue::UInt16(Some(v)) => Some(*v as f64), + ScalarValue::UInt8(Some(v)) => Some(*v as f64), + _ => None, + } +} + /// The percentile_cont accumulator accumulates the raw input values /// as native types. 
/// @@ -760,3 +914,80 @@ fn calculate_percentile( } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::min_max::{max, min}; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::{DFSchema, Result}; + use datafusion_expr::{ + col, execution_props::ExecutionProps, expr::Cast as ExprCast, expr::Sort, lit, + simplify::SimplifyContext, Expr, + }; + use std::sync::Arc; + + fn run_simplifier(expr: &Expr, schema: Arc) -> Result { + let props = ExecutionProps::new(); + let context = SimplifyContext::new(&props).with_schema(schema); + let simplifier = percentile_cont_udaf() + .simplify() + .expect("simplifier should be available"); + + match expr.clone() { + Expr::AggregateFunction(agg) => simplifier(agg, &context), + _ => panic!("expected aggregate expression"), + } + } + + fn schema_for(field: Field) -> Arc { + Arc::new(DFSchema::try_from(Schema::new(vec![field])).unwrap()) + } + + #[test] + fn simplify_percentile_cont_zero_to_min_with_cast() -> Result<()> { + let schema = schema_for(Field::new("value", DataType::Int32, true)); + let expr = percentile_cont(Sort::new(col("value"), true, true), lit(0_f64)); + let simplified = run_simplifier(&expr, Arc::clone(&schema))?; + + let expected_min = min(col("value")); + let expected = + Expr::Cast(ExprCast::new(Box::new(expected_min), DataType::Float64)); + + assert_eq!(simplified, expected); + Ok(()) + } + + #[test] + fn simplify_percentile_cont_zero_desc_to_max() -> Result<()> { + let schema = schema_for(Field::new("value", DataType::Float64, true)); + let expr = percentile_cont(Sort::new(col("value"), false, true), lit(0_f64)); + let simplified = run_simplifier(&expr, schema)?; + + let expected = max(col("value")); + assert_eq!(simplified, expected); + Ok(()) + } + + #[test] + fn simplify_percentile_cont_one_desc_to_min() -> Result<()> { + let schema = schema_for(Field::new("value", DataType::Float64, true)); + let expr = percentile_cont(Sort::new(col("value"), false, true), lit(1_f64)); + let 
simplified = run_simplifier(&expr, schema)?; + + let expected = min(col("value")); + assert_eq!(simplified, expected); + Ok(()) + } + + #[test] + fn percentile_cont_not_simplified_for_other_percentiles() -> Result<()> { + let schema = schema_for(Field::new("value", DataType::Float64, true)); + let expr = percentile_cont(Sort::new(col("value"), true, true), lit(0.5_f64)); + let expected = expr.clone(); + + let simplified = run_simplifier(&expr, schema)?; + assert_eq!(simplified, expected); + Ok(()) + } +} From e158fd41c1112aba0ef5b3f881280b2cfaeaf578 Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Thu, 20 Nov 2025 13:35:17 +0530 Subject: [PATCH 2/8] fix sqllogic error --- .../src/percentile_cont.rs | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/datafusion/functions-aggregate/src/percentile_cont.rs b/datafusion/functions-aggregate/src/percentile_cont.rs index 6a696f06abde..a529e41f18ce 100644 --- a/datafusion/functions-aggregate/src/percentile_cont.rs +++ b/datafusion/functions-aggregate/src/percentile_cont.rs @@ -429,19 +429,19 @@ fn simplify_percentile_cont_aggregate( PercentileRewriteTarget::Max => max_udaf(), }; - let mut rewritten = Expr::AggregateFunction(AggregateFunction::new_udf( + let mut agg_arg = value_expr; + if expected_return_type != input_type { + agg_arg = Expr::Cast(Cast::new(Box::new(agg_arg), expected_return_type.clone())); + } + + let rewritten = Expr::AggregateFunction(AggregateFunction::new_udf( udaf, - vec![value_expr], + vec![agg_arg], params.distinct, params.filter.clone(), vec![], params.null_treatment, )); - - if expected_return_type != input_type { - rewritten = Expr::Cast(Cast::new(Box::new(rewritten), expected_return_type)); - } - Ok(rewritten) } @@ -950,9 +950,9 @@ mod tests { let expr = percentile_cont(Sort::new(col("value"), true, true), lit(0_f64)); let simplified = run_simplifier(&expr, Arc::clone(&schema))?; - let expected_min = min(col("value")); - let expected = - 
Expr::Cast(ExprCast::new(Box::new(expected_min), DataType::Float64)); + let casted_value = + Expr::Cast(ExprCast::new(Box::new(col("value")), DataType::Float64)); + let expected = min(casted_value); assert_eq!(simplified, expected); Ok(()) From 64b047d2f6181d80c8225e2845f0420a8a80c027 Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Thu, 20 Nov 2025 15:31:55 +0530 Subject: [PATCH 3/8] added the test in sqllogictests files --- .../sqllogictest/test_files/aggregate.slt | 53 +++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index e81bfb72a0ef..22d99108be84 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -3488,6 +3488,59 @@ SELECT percentile_cont(1.0) WITHIN GROUP (ORDER BY c2) FROM aggregate_test_100 ---- 5 +# Ensure percentile_cont simplification rewrites to min/max plans +query TT +EXPLAIN SELECT percentile_cont(0.0) WITHIN GROUP (ORDER BY c2) FROM aggregate_test_100; +---- +logical_plan +01)Aggregate: groupBy=[[]], aggr=[[min(CAST(aggregate_test_100.c2 AS Float64)) AS percentile_cont(Float64(0)) WITHIN GROUP [aggregate_test_100.c2 ASC NULLS LAST]]] +02)--TableScan: aggregate_test_100 projection=[c2] +physical_plan +01)AggregateExec: mode=Final, gby=[], aggr=[percentile_cont(Float64(0)) WITHIN GROUP [aggregate_test_100.c2 ASC NULLS LAST]] +02)--CoalescePartitionsExec +03)----AggregateExec: mode=Partial, gby=[], aggr=[percentile_cont(Float64(0)) WITHIN GROUP [aggregate_test_100.c2 ASC NULLS LAST]] +04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[c2], file_type=csv, has_header=true + +query TT +EXPLAIN SELECT percentile_cont(0.0) WITHIN GROUP (ORDER BY c2 DESC) FROM aggregate_test_100; +---- +logical_plan +01)Aggregate: 
groupBy=[[]], aggr=[[max(CAST(aggregate_test_100.c2 AS Float64)) AS percentile_cont(Float64(0)) WITHIN GROUP [aggregate_test_100.c2 DESC NULLS FIRST]]] +02)--TableScan: aggregate_test_100 projection=[c2] +physical_plan +01)AggregateExec: mode=Final, gby=[], aggr=[percentile_cont(Float64(0)) WITHIN GROUP [aggregate_test_100.c2 DESC NULLS FIRST]] +02)--CoalescePartitionsExec +03)----AggregateExec: mode=Partial, gby=[], aggr=[percentile_cont(Float64(0)) WITHIN GROUP [aggregate_test_100.c2 DESC NULLS FIRST]] +04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[c2], file_type=csv, has_header=true + +query TT +EXPLAIN SELECT percentile_cont(c2, 0.0) FROM aggregate_test_100; +---- +logical_plan +01)Aggregate: groupBy=[[]], aggr=[[min(CAST(aggregate_test_100.c2 AS Float64)) AS percentile_cont(aggregate_test_100.c2,Float64(0))]] +02)--TableScan: aggregate_test_100 projection=[c2] +physical_plan +01)AggregateExec: mode=Final, gby=[], aggr=[percentile_cont(aggregate_test_100.c2,Float64(0))] +02)--CoalescePartitionsExec +03)----AggregateExec: mode=Partial, gby=[], aggr=[percentile_cont(aggregate_test_100.c2,Float64(0))] +04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[c2], file_type=csv, has_header=true + +query TT +EXPLAIN SELECT percentile_cont(c2, 1.0) FROM aggregate_test_100; +---- +logical_plan +01)Aggregate: groupBy=[[]], aggr=[[max(CAST(aggregate_test_100.c2 AS Float64)) AS percentile_cont(aggregate_test_100.c2,Float64(1))]] +02)--TableScan: aggregate_test_100 projection=[c2] +physical_plan +01)AggregateExec: mode=Final, gby=[], aggr=[percentile_cont(aggregate_test_100.c2,Float64(1))] +02)--CoalescePartitionsExec +03)----AggregateExec: mode=Partial, 
gby=[], aggr=[percentile_cont(aggregate_test_100.c2,Float64(1))] +04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[c2], file_type=csv, has_header=true + query R SELECT percentile_cont(0.25) WITHIN GROUP (ORDER BY c2) FROM aggregate_test_100 ---- From 812d3705408f7aeabb3e91296a88c297df62be0b Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Thu, 20 Nov 2025 21:06:22 +0530 Subject: [PATCH 4/8] minor changes Co-authored-by: Martin Grigorov --- datafusion/functions-aggregate/src/percentile_cont.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/datafusion/functions-aggregate/src/percentile_cont.rs b/datafusion/functions-aggregate/src/percentile_cont.rs index a529e41f18ce..2ba5978e656e 100644 --- a/datafusion/functions-aggregate/src/percentile_cont.rs +++ b/datafusion/functions-aggregate/src/percentile_cont.rs @@ -385,7 +385,7 @@ enum PercentileRewriteTarget { Max, } -#[allow(clippy::needless_pass_by_value)] +#[expect(clippy::needless_pass_by_value)] fn simplify_percentile_cont_aggregate( aggregate_function: AggregateFunction, info: &dyn SimplifyInfo, @@ -398,7 +398,7 @@ fn simplify_percentile_cont_aggregate( } let percentile_value = match extract_percentile_literal(¶ms.args[1]) { - Some(value) => value, + Some(value) if value >= 0.0 && value <= 1.0 => value, None => return Ok(original_expr), }; @@ -509,6 +509,7 @@ fn literal_scalar_to_f64(value: &ScalarValue) -> Option { match value { ScalarValue::Float64(Some(v)) => Some(*v), ScalarValue::Float32(Some(v)) => Some(*v as f64), + ScalarValue::Float16(Some(v)) => Some(v.to_f64()), ScalarValue::Int64(Some(v)) => Some(*v as f64), ScalarValue::Int32(Some(v)) => Some(*v as f64), ScalarValue::Int16(Some(v)) => Some(*v as f64), From 6696bc8a98ea59983551c60c6c825af47367c9df Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Thu, 20 Nov 2025 
21:19:47 +0530 Subject: [PATCH 5/8] fix failing ci --- datafusion/functions-aggregate/src/percentile_cont.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/functions-aggregate/src/percentile_cont.rs b/datafusion/functions-aggregate/src/percentile_cont.rs index 2ba5978e656e..aad65b5726a0 100644 --- a/datafusion/functions-aggregate/src/percentile_cont.rs +++ b/datafusion/functions-aggregate/src/percentile_cont.rs @@ -398,8 +398,8 @@ fn simplify_percentile_cont_aggregate( } let percentile_value = match extract_percentile_literal(&params.args[1]) { - Some(value) if value >= 0.0 && value <= 1.0 => value, - None => return Ok(original_expr), + Some(value) if (0.0..=1.0).contains(&value) => value, + _ => return Ok(original_expr), }; let is_descending = params From 2251dfa5f84b59a379d831ac8d72306761b60e2c Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Fri, 21 Nov 2025 13:29:04 +0530 Subject: [PATCH 6/8] nit Co-authored-by: Jeffrey Vo --- datafusion/functions-aggregate/src/percentile_cont.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/datafusion/functions-aggregate/src/percentile_cont.rs b/datafusion/functions-aggregate/src/percentile_cont.rs index aad65b5726a0..281978c70da8 100644 --- a/datafusion/functions-aggregate/src/percentile_cont.rs +++ b/datafusion/functions-aggregate/src/percentile_cont.rs @@ -393,9 +393,7 @@ fn simplify_percentile_cont_aggregate( let original_expr = Expr::AggregateFunction(aggregate_function.clone()); let params = &aggregate_function.params; - if params.args.len() != 2 { - return Ok(original_expr); - } + let [value, percentile] = take_function_args("percentile_cont", &params.args)?; let percentile_value = match extract_percentile_literal(&params.args[1]) { Some(value) if (0.0..=1.0).contains(&value) => value, From fc46939603e1b53bed25c032aa08187c577c1001 Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Fri, 21 Nov 2025 13:30:05 +0530 Subject: [PATCH 7/8] minor Co-authored-by: Jeffrey Vo ---
datafusion/functions-aggregate/src/percentile_cont.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/datafusion/functions-aggregate/src/percentile_cont.rs b/datafusion/functions-aggregate/src/percentile_cont.rs index 281978c70da8..ae49ac798304 100644 --- a/datafusion/functions-aggregate/src/percentile_cont.rs +++ b/datafusion/functions-aggregate/src/percentile_cont.rs @@ -412,10 +412,7 @@ fn simplify_percentile_cont_aggregate( }; let value_expr = params.args[0].clone(); - let input_type = match info.get_data_type(&value_expr) { - Ok(data_type) => data_type, - Err(_) => return Ok(original_expr), - }; + let input_type = match info.get_data_type(&value_expr)?; let expected_return_type = match percentile_cont_result_type(&input_type) { Some(data_type) => data_type, From 4a10e04cfcb7d9c62b79d97aa9e206dacccaed4d Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Fri, 21 Nov 2025 16:12:23 +0530 Subject: [PATCH 8/8] removed unit tests and other ergonomica changes with commnets --- .../src/percentile_cont.rs | 207 +++--------------- 1 file changed, 36 insertions(+), 171 deletions(-) diff --git a/datafusion/functions-aggregate/src/percentile_cont.rs b/datafusion/functions-aggregate/src/percentile_cont.rs index ae49ac798304..1ea896d4b36f 100644 --- a/datafusion/functions-aggregate/src/percentile_cont.rs +++ b/datafusion/functions-aggregate/src/percentile_cont.rs @@ -35,8 +35,8 @@ use arrow::array::ArrowNativeTypeOp; use crate::min_max::{max_udaf, min_udaf}; use datafusion_common::{ - assert_eq_or_internal_err, internal_datafusion_err, plan_err, DataFusionError, - Result, ScalarValue, + assert_eq_or_internal_err, internal_datafusion_err, plan_err, + utils::take_function_args, DataFusionError, Result, ScalarValue, }; use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::utils::format_state_name; @@ -377,8 +377,6 @@ impl AggregateUDFImpl for PercentileCont { } } -const PERCENTILE_LITERAL_EPSILON: f64 = 1e-12; - #[derive(Clone, 
Copy)] enum PercentileRewriteTarget { Min, @@ -395,40 +393,54 @@ fn simplify_percentile_cont_aggregate( let [value, percentile] = take_function_args("percentile_cont", ¶ms.args)?; - let percentile_value = match extract_percentile_literal(¶ms.args[1]) { - Some(value) if (0.0..=1.0).contains(&value) => value, - _ => return Ok(original_expr), - }; - let is_descending = params .order_by .first() .map(|sort| !sort.asc) .unwrap_or(false); - let rewrite_target = match classify_rewrite_target(percentile_value, is_descending) { - Some(target) => target, - None => return Ok(original_expr), + let rewrite_target = match extract_percentile_literal(percentile) { + Some(0.0) => { + if is_descending { + PercentileRewriteTarget::Max + } else { + PercentileRewriteTarget::Min + } + } + Some(1.0) => { + if is_descending { + PercentileRewriteTarget::Min + } else { + PercentileRewriteTarget::Max + } + } + _ => return Ok(original_expr), }; - let value_expr = params.args[0].clone(); - let input_type = match info.get_data_type(&value_expr)?; - - let expected_return_type = match percentile_cont_result_type(&input_type) { - Some(data_type) => data_type, - None => return Ok(original_expr), + let input_type = match info.get_data_type(value) { + Ok(data_type) => data_type, + Err(_) => return Ok(original_expr), }; - let udaf = match rewrite_target { - PercentileRewriteTarget::Min => min_udaf(), - PercentileRewriteTarget::Max => max_udaf(), - }; + let expected_return_type = + match percentile_cont_udaf().return_type(std::slice::from_ref(&input_type)) { + Ok(data_type) => data_type, + Err(_) => return Ok(original_expr), + }; - let mut agg_arg = value_expr; + let mut agg_arg = value.clone(); if expected_return_type != input_type { + // min/max return the same type as their input. percentile_cont widens + // integers to Float64 (and preserves float/decimal types), so ensure the + // rewritten aggregate sees an input of the final return type. 
agg_arg = Expr::Cast(Cast::new(Box::new(agg_arg), expected_return_type.clone())); } + let udaf = match rewrite_target { + PercentileRewriteTarget::Min => min_udaf(), + PercentileRewriteTarget::Max => max_udaf(), + }; + let rewritten = Expr::AggregateFunction(AggregateFunction::new_udf( udaf, vec![agg_arg], @@ -440,79 +452,9 @@ fn simplify_percentile_cont_aggregate( Ok(rewritten) } -fn classify_rewrite_target( - percentile_value: f64, - is_descending: bool, -) -> Option { - if nearly_equals_fraction(percentile_value, 0.0) { - Some(if is_descending { - PercentileRewriteTarget::Max - } else { - PercentileRewriteTarget::Min - }) - } else if nearly_equals_fraction(percentile_value, 1.0) { - Some(if is_descending { - PercentileRewriteTarget::Min - } else { - PercentileRewriteTarget::Max - }) - } else { - None - } -} - -fn nearly_equals_fraction(value: f64, target: f64) -> bool { - (value - target).abs() < PERCENTILE_LITERAL_EPSILON -} - -fn percentile_cont_result_type(input_type: &DataType) -> Option { - if !input_type.is_numeric() { - return None; - } - - let result_type = match input_type { - DataType::Float16 | DataType::Float32 | DataType::Float64 => input_type.clone(), - DataType::Decimal32(_, _) - | DataType::Decimal64(_, _) - | DataType::Decimal128(_, _) - | DataType::Decimal256(_, _) => input_type.clone(), - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 => DataType::Float64, - _ => return None, - }; - - Some(result_type) -} - fn extract_percentile_literal(expr: &Expr) -> Option { match expr { - Expr::Literal(value, _) => literal_scalar_to_f64(value), - Expr::Alias(alias) => extract_percentile_literal(alias.expr.as_ref()), - Expr::Cast(cast) => extract_percentile_literal(cast.expr.as_ref()), - Expr::TryCast(cast) => extract_percentile_literal(cast.expr.as_ref()), - _ => None, - } -} - -fn literal_scalar_to_f64(value: &ScalarValue) -> Option { - match 
value { - ScalarValue::Float64(Some(v)) => Some(*v), - ScalarValue::Float32(Some(v)) => Some(*v as f64), - ScalarValue::Float16(Some(v)) => Some(v.to_f64()), - ScalarValue::Int64(Some(v)) => Some(*v as f64), - ScalarValue::Int32(Some(v)) => Some(*v as f64), - ScalarValue::Int16(Some(v)) => Some(*v as f64), - ScalarValue::Int8(Some(v)) => Some(*v as f64), - ScalarValue::UInt64(Some(v)) => Some(*v as f64), - ScalarValue::UInt32(Some(v)) => Some(*v as f64), - ScalarValue::UInt16(Some(v)) => Some(*v as f64), - ScalarValue::UInt8(Some(v)) => Some(*v as f64), + Expr::Literal(ScalarValue::Float64(Some(value)), _) => Some(*value), _ => None, } } @@ -910,80 +852,3 @@ fn calculate_percentile( } } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::min_max::{max, min}; - use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::{DFSchema, Result}; - use datafusion_expr::{ - col, execution_props::ExecutionProps, expr::Cast as ExprCast, expr::Sort, lit, - simplify::SimplifyContext, Expr, - }; - use std::sync::Arc; - - fn run_simplifier(expr: &Expr, schema: Arc) -> Result { - let props = ExecutionProps::new(); - let context = SimplifyContext::new(&props).with_schema(schema); - let simplifier = percentile_cont_udaf() - .simplify() - .expect("simplifier should be available"); - - match expr.clone() { - Expr::AggregateFunction(agg) => simplifier(agg, &context), - _ => panic!("expected aggregate expression"), - } - } - - fn schema_for(field: Field) -> Arc { - Arc::new(DFSchema::try_from(Schema::new(vec![field])).unwrap()) - } - - #[test] - fn simplify_percentile_cont_zero_to_min_with_cast() -> Result<()> { - let schema = schema_for(Field::new("value", DataType::Int32, true)); - let expr = percentile_cont(Sort::new(col("value"), true, true), lit(0_f64)); - let simplified = run_simplifier(&expr, Arc::clone(&schema))?; - - let casted_value = - Expr::Cast(ExprCast::new(Box::new(col("value")), DataType::Float64)); - let expected = min(casted_value); - - 
assert_eq!(simplified, expected); - Ok(()) - } - - #[test] - fn simplify_percentile_cont_zero_desc_to_max() -> Result<()> { - let schema = schema_for(Field::new("value", DataType::Float64, true)); - let expr = percentile_cont(Sort::new(col("value"), false, true), lit(0_f64)); - let simplified = run_simplifier(&expr, schema)?; - - let expected = max(col("value")); - assert_eq!(simplified, expected); - Ok(()) - } - - #[test] - fn simplify_percentile_cont_one_desc_to_min() -> Result<()> { - let schema = schema_for(Field::new("value", DataType::Float64, true)); - let expr = percentile_cont(Sort::new(col("value"), false, true), lit(1_f64)); - let simplified = run_simplifier(&expr, schema)?; - - let expected = min(col("value")); - assert_eq!(simplified, expected); - Ok(()) - } - - #[test] - fn percentile_cont_not_simplified_for_other_percentiles() -> Result<()> { - let schema = schema_for(Field::new("value", DataType::Float64, true)); - let expr = percentile_cont(Sort::new(col("value"), true, true), lit(0.5_f64)); - let expected = expr.clone(); - - let simplified = run_simplifier(&expr, schema)?; - assert_eq!(simplified, expected); - Ok(()) - } -}