diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index 7a0e9888a61c..1bd90fce839d 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -360,7 +360,7 @@ async fn test_fn_approx_median() -> Result<()> { #[tokio::test] async fn test_fn_approx_percentile_cont() -> Result<()> { - let expr = approx_percentile_cont(col("b"), lit(0.5)); + let expr = approx_percentile_cont(col("b"), lit(0.5), None); let expected = [ "+---------------------------------------------+", @@ -381,7 +381,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { None::<&str>, "arg_2".to_string(), )); - let expr = approx_percentile_cont(col("b"), alias_expr); + let expr = approx_percentile_cont(col("b"), alias_expr, None); let df = create_test_table().await?; let expected = [ "+--------------------------------------+", @@ -394,6 +394,21 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { assert_batches_eq!(expected, &batches); + // with number of centroids set + let expr = approx_percentile_cont(col("b"), lit(0.5), Some(lit(2))); + let expected = [ + "+------------------------------------------------------+", + "| approx_percentile_cont(test.b,Float64(0.5),Int32(2)) |", + "+------------------------------------------------------+", + "| 30 |", + "+------------------------------------------------------+", + ]; + + let df = create_test_table().await?; + let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?; + + assert_batches_eq!(expected, &batches); + Ok(()) } diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index af2a26fd05ec..ffa623c13b0b 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -46,13 +46,21 @@ use 
datafusion_physical_expr_common::aggregate::tdigest::{ }; use datafusion_physical_expr_common::utils::limited_convert_logical_expr_to_physical_expr_with_dfschema; -make_udaf_expr_and_func!( - ApproxPercentileCont, - approx_percentile_cont, - expression percentile, - "Computes the approximate percentile continuous of a set of numbers", - approx_percentile_cont_udaf -); +create_func!(ApproxPercentileCont, approx_percentile_cont_udaf); + +/// Computes the approximate percentile continuous of a set of numbers +pub fn approx_percentile_cont( + expression: Expr, + percentile: Expr, + centroids: Option<Expr>, +) -> Expr { + let args = if let Some(centroids) = centroids { + vec![expression, percentile, centroids] + } else { + vec![expression, percentile] + }; + approx_percentile_cont_udaf().call(args) +} pub struct ApproxPercentileCont { signature: Signature, diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index b96398ef217f..66c25f3bf382 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -885,7 +885,8 @@ async fn roundtrip_expr_api() -> Result<()> { stddev_pop(lit(2.2)), approx_distinct(lit(2)), approx_median(lit(2)), - approx_percentile_cont(lit(2), lit(0.5)), + approx_percentile_cont(lit(2), lit(0.5), None), + approx_percentile_cont(lit(2), lit(0.5), Some(lit(50))), approx_percentile_cont_with_weight(lit(2), lit(1), lit(0.5)), grouping(lit(1)), bit_and(lit(2)),