Skip to content

Commit

Permalink
move schema and logical ordering exprs
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
  • Loading branch information
jayzhan211 committed Mar 7, 2024
1 parent 8592e6b commit 0f8fc24
Show file tree
Hide file tree
Showing 10 changed files with 60 additions and 45 deletions.
4 changes: 2 additions & 2 deletions datafusion-examples/examples/advanced_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ impl AggregateUDFImpl for GeoMeanUdaf {
fn accumulator(
&self,
_arg: &DataType,
_sort_exprs: Vec<Expr>,
_schema: Option<&Schema>,
_sort_exprs: &[Expr],
_schema: &Schema,
) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(GeometricMean::new()))
}
Expand Down
16 changes: 10 additions & 6 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1672,7 +1672,9 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
)?),
None => None,
};
let order_by = match order_by {

let sort_exprs = order_by.clone().unwrap_or(vec![]);
let phy_order_by = match order_by {
Some(e) => Some(
e.iter()
.map(|expr| {
Expand All @@ -1691,7 +1693,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
== NullTreatment::IgnoreNulls;
let (agg_expr, filter, order_by) = match func_def {
AggregateFunctionDefinition::BuiltIn(fun) => {
let ordering_reqs = order_by.clone().unwrap_or(vec![]);
let ordering_reqs = phy_order_by.clone().unwrap_or(vec![]);
let agg_expr = aggregates::create_aggregate_expr(
fun,
*distinct,
Expand All @@ -1701,27 +1703,29 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
name,
ignore_nulls,
)?;
(agg_expr, filter, order_by)
(agg_expr, filter, phy_order_by)
}
AggregateFunctionDefinition::UDF(fun) => {
let ordering_reqs: Vec<PhysicalSortExpr> =
order_by.clone().unwrap_or(vec![]);
phy_order_by.clone().unwrap_or(vec![]);

let agg_expr = udaf::create_aggregate_expr(
fun,
&args,
&sort_exprs,
&ordering_reqs,
physical_input_schema,
name,
)?;
(agg_expr, filter, order_by)
(agg_expr, filter, phy_order_by)
}
AggregateFunctionDefinition::Name(_) => {
return internal_err!(
"Aggregate function name should have been resolved"
)
}
};
Ok((agg_expr, filter, order_by))
Ok((agg_expr, filter, phy_order_by))
}
other => internal_err!("Invalid aggregate expression '{other:?}'"),
}
Expand Down
26 changes: 9 additions & 17 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@ use datafusion::{
prelude::SessionContext,
scalar::ScalarValue,
};
use datafusion_common::{
assert_contains, cast::as_primitive_array, exec_err, Column, DataFusionError,
};
use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err, Column};
use datafusion_expr::{
create_udaf, create_udaf_with_ordering, expr::Sort, AggregateUDFImpl, Expr,
GroupsAccumulator, SimpleAggregateUDF,
Expand Down Expand Up @@ -234,12 +232,9 @@ async fn simple_udaf_order() -> Result<()> {

fn create_accumulator(
data_type: &DataType,
order_by: Vec<Expr>,
schema: Option<&Schema>,
order_by: &[Expr],
schema: &Schema,
) -> Result<Box<dyn Accumulator>> {
// test with ordering so schema is required
let schema = schema.unwrap();

let mut all_sort_orders = vec![];

// Construct PhysicalSortExpr objects from Expr objects:
Expand All @@ -265,16 +260,13 @@ async fn simple_udaf_order() -> Result<()> {

let ordering_req = all_sort_orders;

let ordering_types = ordering_req
let ordering_dtypes = ordering_req
.iter()
.map(|e| e.expr.data_type(schema))
.collect::<Result<Vec<_>>>()?;

let acc = FirstValueAccumulator::try_new(
data_type,
ordering_types.as_slice(),
ordering_req,
)?;
let acc =
FirstValueAccumulator::try_new(data_type, &ordering_dtypes, ordering_req)?;
Ok(Box::new(acc))
}

Expand Down Expand Up @@ -369,7 +361,7 @@ async fn deregister_udaf() -> Result<()> {
vec![DataType::Float64],
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
Arc::new(|_, _, _| Ok(Box::<AvgAccumulator>::default())),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
);

Expand Down Expand Up @@ -791,8 +783,8 @@ impl AggregateUDFImpl for TestGroupsAccumulator {
fn accumulator(
&self,
_arg: &DataType,
_sort_exprs: Vec<Expr>,
_schema: Option<&Schema>,
_sort_exprs: &[Expr],
_schema: &Schema,
) -> Result<Box<dyn Accumulator>> {
// should use groups accumulator
panic!("accumulator shouldn't invoke");
Expand Down
8 changes: 4 additions & 4 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1125,8 +1125,8 @@ impl AggregateUDFImpl for SimpleAggregateUDF {
fn accumulator(
&self,
arg: &DataType,
sort_exprs: Vec<Expr>,
schema: Option<&Schema>,
sort_exprs: &[Expr],
schema: &Schema,
) -> Result<Box<dyn crate::Accumulator>> {
(self.accumulator)(arg, sort_exprs, schema)
}
Expand Down Expand Up @@ -1206,8 +1206,8 @@ impl AggregateUDFImpl for SimpleOrderedAggregateUDF {
fn accumulator(
&self,
arg: &DataType,
sort_exprs: Vec<Expr>,
schema: Option<&Schema>,
sort_exprs: &[Expr],
schema: &Schema,
) -> Result<Box<dyn crate::Accumulator>> {
(self.accumulator)(arg, sort_exprs, schema)
}
Expand Down
4 changes: 1 addition & 3 deletions datafusion/expr/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,7 @@ pub type ReturnTypeFunction =
/// Factory that returns an accumulator for the given aggregate, given
/// its return datatype, the sorting expressions and the schema for ordering.
pub type AccumulatorFactoryFunction = Arc<
dyn Fn(&DataType, Vec<Expr>, Option<&Schema>) -> Result<Box<dyn Accumulator>>
+ Send
+ Sync,
dyn Fn(&DataType, &[Expr], &Schema) -> Result<Box<dyn Accumulator>> + Send + Sync,
>;

/// Factory that creates a PartitionEvaluator for the given window
Expand Down
19 changes: 12 additions & 7 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,14 @@ impl AggregateUDF {
}

/// Return an accumulator for the given aggregate, given its return datatype
pub fn accumulator(&self, return_type: &DataType) -> Result<Box<dyn Accumulator>> {
let sort_exprs = self.inner.sort_exprs();
let schema = self.inner.schema();
pub fn accumulator(
&self,
return_type: &DataType,
sort_exprs: &[Expr],
schema: &Schema,
) -> Result<Box<dyn Accumulator>> {
// let sort_exprs = self.inner.sort_exprs();
// let schema = self.inner.schema();
self.inner.accumulator(return_type, sort_exprs, schema)
}

Expand Down Expand Up @@ -266,8 +271,8 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
fn accumulator(
&self,
arg: &DataType,
sort_exprs: Vec<Expr>,
schema: Option<&Schema>,
sort_exprs: &[Expr],
schema: &Schema,
) -> Result<Box<dyn Accumulator>>;

/// Return the type used to serialize the [`Accumulator`]'s intermediate state.
Expand Down Expand Up @@ -348,8 +353,8 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper {
fn accumulator(
&self,
arg: &DataType,
sort_exprs: Vec<Expr>,
schema: Option<&Schema>,
sort_exprs: &[Expr],
schema: &Schema,
) -> Result<Box<dyn Accumulator>> {
(self.accumulator)(arg, sort_exprs, schema)
}
Expand Down
22 changes: 17 additions & 5 deletions datafusion/physical-plan/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

//! This module contains functions and structs supporting user-defined aggregate functions.
use datafusion_expr::GroupsAccumulator;
use datafusion_expr::{Expr, GroupsAccumulator};
use fmt::Debug;
use std::any::Any;
use std::fmt;
Expand All @@ -37,20 +37,23 @@ use std::sync::Arc;
pub fn create_aggregate_expr(
fun: &AggregateUDF,
input_phy_exprs: &[Arc<dyn PhysicalExpr>],
sort_exprs: &[Expr],
ordering_req: &[PhysicalSortExpr],
input_schema: &Schema,
schema: &Schema,
name: impl Into<String>,
) -> Result<Arc<dyn AggregateExpr>> {
let input_exprs_types = input_phy_exprs
.iter()
.map(|arg| arg.data_type(input_schema))
.map(|arg| arg.data_type(schema))
.collect::<Result<Vec<_>>>()?;

Ok(Arc::new(AggregateFunctionExpr {
fun: fun.clone(),
args: input_phy_exprs.to_vec(),
data_type: fun.return_type(&input_exprs_types)?,
name: name.into(),
schema: schema.clone(),
sort_exprs: sort_exprs.to_vec(),
ordering_req: ordering_req.to_vec(),
}))
}
Expand All @@ -63,6 +66,10 @@ pub struct AggregateFunctionExpr {
/// Output / return type of this aggregate
data_type: DataType,
name: String,
schema: Schema,
// The logical order by expressions
sort_exprs: Vec<Expr>,
// The physical order by expressions
ordering_req: LexOrdering,
}

Expand Down Expand Up @@ -106,11 +113,16 @@ impl AggregateExpr for AggregateFunctionExpr {
}

fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
self.fun.accumulator(&self.data_type)
self.fun
.accumulator(&self.data_type, self.sort_exprs.as_slice(), &self.schema)
}

fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
let accumulator = self.fun.accumulator(&self.data_type)?;
let accumulator = self.fun.accumulator(
&self.data_type,
self.sort_exprs.as_slice(),
&self.schema,
)?;

// Accumulators that have window frame startings different
// than `UNBOUNDED PRECEDING`, such as `1 PRECEDING`, need to
Expand Down
3 changes: 3 additions & 0 deletions datafusion/physical-plan/src/windows/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,13 @@ pub fn create_window_expr(
}
WindowFunctionDefinition::AggregateUDF(fun) => {
// TODO: Ordering not supported for Window UDFs yet
let sort_exprs = &[];
let ordering_req = &[];

let aggregate = udaf::create_aggregate_expr(
fun.as_ref(),
args,
sort_exprs,
ordering_req,
input_schema,
name,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/proto/src/physical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
AggregateFunction::UserDefinedAggrFunction(udaf_name) => {
let agg_udf = registry.udaf(udaf_name)?;
let ordering_req = &[];
udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, ordering_req, &physical_schema, name)
udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, &[], ordering_req, &physical_schema, name)
}
}
}).transpose()?.ok_or_else(|| {
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/tests/cases/roundtrip_physical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> {
&udaf,
&[col("b", &schema)?],
&[],
&[],
&schema,
"example_agg",
)?];
Expand Down

0 comments on commit 0f8fc24

Please sign in to comment.