Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions crates/core/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -696,9 +696,10 @@ aggregate_function!(var_pop);
aggregate_function!(approx_distinct);
aggregate_function!(approx_median);

// Code is commented out since grouping is not yet implemented
// https://github.com/apache/datafusion-python/issues/861
// aggregate_function!(grouping);
// The grouping function's physical plan is not implemented, but the
// ResolveGroupingFunction analyzer rule rewrites it before the physical
// planner sees it, so it works correctly at runtime.
aggregate_function!(grouping);

#[pyfunction]
#[pyo3(signature = (sort_expression, percentile, num_centroids=None, filter=None))]
Expand Down Expand Up @@ -736,6 +737,19 @@ pub fn approx_percentile_cont_with_weight(
add_builder_fns_to_aggregate(agg_fn, None, filter, None, None)
}

#[pyfunction]
#[pyo3(signature = (sort_expression, percentile, filter=None))]
pub fn percentile_cont(
    sort_expression: PySortExpr,
    percentile: f64,
    filter: Option<PyExpr>,
) -> PyDataFusionResult<PyExpr> {
    // Build the exact (non-approximate) continuous-percentile aggregate
    // expression, then attach the optional filter via the shared builder helper.
    let expr =
        functions_aggregate::expr_fn::percentile_cont(sort_expression.sort, lit(percentile));
    add_builder_fns_to_aggregate(expr, None, filter, None, None)
}

// We handle last_value explicitly because the signature expects an order_by
// https://github.com/apache/datafusion/issues/12376
#[pyfunction]
Expand Down Expand Up @@ -936,6 +950,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(approx_median))?;
m.add_wrapped(wrap_pyfunction!(approx_percentile_cont))?;
m.add_wrapped(wrap_pyfunction!(approx_percentile_cont_with_weight))?;
m.add_wrapped(wrap_pyfunction!(percentile_cont))?;
m.add_wrapped(wrap_pyfunction!(range))?;
m.add_wrapped(wrap_pyfunction!(array_agg))?;
m.add_wrapped(wrap_pyfunction!(arrow_typeof))?;
Expand Down Expand Up @@ -981,7 +996,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(floor))?;
m.add_wrapped(wrap_pyfunction!(from_unixtime))?;
m.add_wrapped(wrap_pyfunction!(gcd))?;
// m.add_wrapped(wrap_pyfunction!(grouping))?;
m.add_wrapped(wrap_pyfunction!(grouping))?;
m.add_wrapped(wrap_pyfunction!(in_list))?;
m.add_wrapped(wrap_pyfunction!(initcap))?;
m.add_wrapped(wrap_pyfunction!(isnan))?;
Expand Down
77 changes: 77 additions & 0 deletions python/datafusion/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@
"floor",
"from_unixtime",
"gcd",
"grouping",
"in_list",
"initcap",
"isnan",
Expand Down Expand Up @@ -216,6 +217,7 @@
"order_by",
"overlay",
"percent_rank",
"percentile_cont",
"pi",
"pow",
"power",
Expand Down Expand Up @@ -286,6 +288,7 @@
"uuid",
"var",
"var_pop",
"var_population",
"var_samp",
"var_sample",
"when",
Expand Down Expand Up @@ -3523,6 +3526,47 @@ def approx_percentile_cont_with_weight(
)


def percentile_cont(
    sort_expression: Expr | SortExpr,
    percentile: float,
    filter: Expr | None = None,
) -> Expr:
    """Computes the exact percentile of input values using continuous interpolation.

    Unlike :py:func:`approx_percentile_cont`, this function computes the exact
    percentile value rather than an approximation.

    If using the builder functions described in :ref:`aggregation` this function
    ignores the options ``order_by``, ``null_treatment``, and ``distinct``.

    Args:
        sort_expression: Values for which to find the percentile
        percentile: This must be between 0.0 and 1.0, inclusive
        filter: If provided, only compute against rows for which the filter is True

    Examples:
        >>> ctx = dfn.SessionContext()
        >>> df = ctx.from_pydict({"a": [1.0, 2.0, 3.0, 4.0, 5.0]})
        >>> result = df.aggregate(
        ...     [], [dfn.functions.percentile_cont(
        ...         dfn.col("a"), 0.5
        ...     ).alias("v")])
        >>> result.collect_column("v")[0].as_py()
        3.0

        >>> result = df.aggregate(
        ...     [], [dfn.functions.percentile_cont(
        ...         dfn.col("a"), 0.5,
        ...         filter=dfn.col("a") > dfn.lit(1.0),
        ...     ).alias("v")])
        >>> result.collect_column("v")[0].as_py()
        3.5
    """
    sort_expr_raw = sort_or_default(sort_expression)
    filter_raw = filter.expr if filter is not None else None
    return Expr(f.percentile_cont(sort_expr_raw, percentile, filter=filter_raw))


def array_agg(
expression: Expr,
distinct: bool = False,
Expand Down Expand Up @@ -3581,6 +3625,30 @@ def array_agg(
)


def grouping(
    expression: Expr,
    distinct: bool | None = None,
    filter: Expr | None = None,
) -> Expr:
    """Returns 1 if the data is aggregated across the specified column, or 0 otherwise.

    Use with ``GROUPING SETS``, ``CUBE``, or ``ROLLUP`` to tell aggregated rows
    apart from non-aggregated ones. In a plain ``GROUP BY`` (no grouping sets)
    the result is always 0.

    Note: The ``grouping`` aggregate function is rewritten by the query
    optimizer before execution, so it works correctly even though its
    physical plan is not directly implemented.

    Args:
        expression: The column to check grouping status for
        distinct: If True, compute on distinct values only
        filter: If provided, only compute against rows for which the filter is True
    """
    if filter is None:
        filter_raw = None
    else:
        filter_raw = filter.expr
    return Expr(f.grouping(expression.expr, distinct=distinct, filter=filter_raw))


def avg(
expression: Expr,
filter: Expr | None = None,
Expand Down Expand Up @@ -4052,6 +4120,15 @@ def var_pop(expression: Expr, filter: Expr | None = None) -> Expr:
return Expr(f.var_pop(expression.expr, filter=filter_raw))


def var_population(expression: Expr, filter: Expr | None = None) -> Expr:
    """Computes the population variance of the argument.

    Args:
        expression: Values to compute the population variance of
        filter: If provided, only compute against rows for which the filter is True

    See Also:
        This is an alias for :py:func:`var_pop`.
    """
    return var_pop(expression, filter)


def var_samp(expression: Expr, filter: Expr | None = None) -> Expr:
"""Computes the sample variance of the argument.

Expand Down
46 changes: 46 additions & 0 deletions python/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1435,3 +1435,49 @@ def test_coalesce(df):
assert result.column(0) == pa.array(
["Hello", "fallback", "!"], type=pa.string_view()
)


def test_percentile_cont():
    ctx = SessionContext()
    df = ctx.from_pydict({"a": [1.0, 2.0, 3.0, 4.0, 5.0]})
    agg = f.percentile_cont(column("a"), 0.5).alias("v")
    batch = df.aggregate([], [agg]).collect()[0]
    # The median of [1, 2, 3, 4, 5] is exactly 3.0.
    assert batch.column(0)[0].as_py() == 3.0


def test_percentile_cont_with_filter():
    ctx = SessionContext()
    df = ctx.from_pydict({"a": [1.0, 2.0, 3.0, 4.0, 5.0]})
    agg = f.percentile_cont(
        column("a"), 0.5, filter=column("a") > literal(1.0)
    ).alias("v")
    batch = df.aggregate([], [agg]).collect()[0]
    # The filter keeps [2, 3, 4, 5]; their interpolated median is 3.5.
    assert batch.column(0)[0].as_py() == 3.5


def test_grouping():
    ctx = SessionContext()
    df = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 20, 30]})
    # In a simple GROUP BY (no grouping sets), grouping() returns 0 for all rows.
    # Note: grouping() must not be aliased directly in the aggregate expression list
    # due to an upstream DataFusion analyzer limitation (the ResolveGroupingFunction
    # rule doesn't unwrap Alias nodes). Apply aliases via a follow-up select instead.
    result = df.aggregate(
        [column("a")], [f.grouping(column("a")), f.sum(column("b")).alias("s")]
    ).collect()
    # Column index 1 is the grouping() output (index 0 is the "a" group key).
    # Concatenate across record batches since group results may span batches.
    grouping_col = pa.concat_arrays([batch.column(1) for batch in result]).to_pylist()
    assert all(v == 0 for v in grouping_col)


def test_var_population():
    ctx = SessionContext()
    df = ctx.from_pydict({"a": [-1.0, 0.0, 2.0]})
    # var_population is a thin alias for var_pop, so both must agree.
    via_alias = df.aggregate([], [f.var_population(column("a")).alias("v")]).collect()[0]
    via_var_pop = df.aggregate([], [f.var_pop(column("a")).alias("v")]).collect()[0]
    assert abs(via_alias.column(0)[0].as_py() - via_var_pop.column(0)[0].as_py()) < 1e-10
Loading