Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions datafusion/core/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ use arrow::datatypes::{Schema, SchemaRef};
use async_trait::async_trait;
use datafusion_common::{DFSchema, ScalarValue};
use datafusion_expr::expr::{
self, AggregateFunction, Between, BinaryExpr, Cast, GetIndexedField, GroupingSet,
Like, ScalarUDF, TryCast, WindowFunction,
self, AggregateFunction, AggregateUDF, Between, BinaryExpr, Cast, GetIndexedField,
GroupingSet, Like, ScalarUDF, TryCast, WindowFunction,
};
use datafusion_expr::expr_rewriter::unnormalize_cols;
use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
Expand Down Expand Up @@ -199,7 +199,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
args,
..
}) => create_function_physical_name(&fun.to_string(), *distinct, args),
Expr::AggregateUDF { fun, args, filter } => {
Expr::AggregateUDF(AggregateUDF { fun, args, filter }) => {
if filter.is_some() {
return Err(DataFusionError::Execution(
"aggregate expression with filter is not supported".to_string(),
Expand Down Expand Up @@ -1666,7 +1666,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
);
Ok((agg_expr?, filter))
}
Expr::AggregateUDF { fun, args, filter } => {
Expr::AggregateUDF(AggregateUDF { fun, args, filter }) => {
let args = args
.iter()
.map(|e| {
Expand Down
38 changes: 26 additions & 12 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ use crate::aggregate_function;
use crate::built_in_function;
use crate::expr_fn::binary_expr;
use crate::logical_plan::Subquery;
use crate::udaf;
use crate::utils::{expr_to_columns, find_out_reference_exprs};
use crate::window_frame;
use crate::window_function;
use crate::AggregateUDF;
use crate::Operator;
use arrow::datatypes::DataType;
use datafusion_common::Result;
Expand Down Expand Up @@ -159,14 +159,7 @@ pub enum Expr {
/// Represents the call of a window function with arguments.
WindowFunction(WindowFunction),
/// Represents the call of a user-defined aggregate function (UDAF) with arguments.
AggregateUDF {
/// The function
fun: Arc<AggregateUDF>,
/// List of expressions to feed to the functions as arguments
args: Vec<Expr>,
/// Optional filter applied prior to aggregating
filter: Option<Box<Expr>>,
},
AggregateUDF(AggregateUDF),
/// Returns whether the list contains the expr value.
InList {
/// The expression to compare
Expand Down Expand Up @@ -507,6 +500,27 @@ impl WindowFunction {
}
}

/// Aggregate UDF expression: the call of a user-defined aggregate function
/// together with its argument expressions and an optional filter.
///
/// This is the payload carried by the `Expr::AggregateUDF` variant.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct AggregateUDF {
/// The user-defined aggregate function being called
pub fun: Arc<udaf::AggregateUDF>,
/// List of expressions to feed to the function as arguments
pub args: Vec<Expr>,
/// Optional filter applied prior to aggregating
/// (rendered as `FILTER (WHERE ...)` when the expression is formatted)
pub filter: Option<Box<Expr>>,
}

impl AggregateUDF {
    /// Builds an aggregate UDF expression from its three parts.
    ///
    /// * `fun` - the user-defined aggregate function to call
    /// * `args` - the expressions passed to the function as arguments
    /// * `filter` - an optional filter expression, or `None` for no filter
    ///
    /// The values are stored as given; no validation is performed here.
    pub fn new(
        fun: Arc<udaf::AggregateUDF>,
        args: Vec<Expr>,
        filter: Option<Box<Expr>>,
    ) -> Self {
        Self {
            fun,
            args,
            filter,
        }
    }
}

/// Grouping sets
/// See <https://www.postgresql.org/docs/current/queries-table-expressions.html#QUERIES-GROUPING-SETS>
/// for Postgres definition.
Expand Down Expand Up @@ -988,12 +1002,12 @@ impl fmt::Debug for Expr {
}
Ok(())
}
Expr::AggregateUDF {
Expr::AggregateUDF(AggregateUDF {
fun,
ref args,
filter,
..
} => {
}) => {
fmt_function(f, &fun.name, false, args, false)?;
if let Some(fe) = filter {
write!(f, " FILTER (WHERE {fe})")?;
Expand Down Expand Up @@ -1344,7 +1358,7 @@ fn create_name(e: &Expr) -> Result<String> {
Ok(name)
}
}
Expr::AggregateUDF { fun, args, filter } => {
Expr::AggregateUDF(AggregateUDF { fun, args, filter }) => {
let mut names = Vec::with_capacity(args.len());
for e in args {
names.push(create_name(e)?);
Expand Down
6 changes: 3 additions & 3 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

use super::{Between, Expr, Like};
use crate::expr::{
AggregateFunction, BinaryExpr, Cast, GetIndexedField, ScalarFunction, ScalarUDF,
Sort, TryCast, WindowFunction,
AggregateFunction, AggregateUDF, BinaryExpr, Cast, GetIndexedField, ScalarFunction,
ScalarUDF, Sort, TryCast, WindowFunction,
};
use crate::field_util::get_indexed_field;
use crate::type_coercion::binary::get_result_type;
Expand Down Expand Up @@ -123,7 +123,7 @@ impl ExprSchemable for Expr {
.collect::<Result<Vec<_>>>()?;
aggregate_function::return_type(fun, &data_types)
}
Expr::AggregateUDF { fun, args, .. } => {
Expr::AggregateUDF(AggregateUDF { fun, args, .. }) => {
let data_types = args
.iter()
.map(|e| e.get_type(schema))
Expand Down
18 changes: 10 additions & 8 deletions datafusion/expr/src/tree_node/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
//! Tree node implementation for logical expr

use crate::expr::{
AggregateFunction, Between, BinaryExpr, Case, Cast, GetIndexedField, GroupingSet,
Like, ScalarFunction, ScalarUDF, Sort, TryCast, WindowFunction,
AggregateFunction, AggregateUDF, Between, BinaryExpr, Case, Cast, GetIndexedField,
GroupingSet, Like, ScalarFunction, ScalarUDF, Sort, TryCast, WindowFunction,
};
use crate::Expr;
use datafusion_common::tree_node::VisitRecursion;
Expand Down Expand Up @@ -97,7 +97,7 @@ impl TreeNode for Expr {
expr_vec
}
Expr::AggregateFunction(AggregateFunction { args, filter, .. })
| Expr::AggregateUDF { args, filter, .. } => {
| Expr::AggregateUDF ( AggregateUDF{args, filter, .. }) => {
let mut expr_vec = args.clone();

if let Some(f) = filter {
Expand Down Expand Up @@ -313,11 +313,13 @@ impl TreeNode for Expr {
))
}
},
Expr::AggregateUDF { args, fun, filter } => Expr::AggregateUDF {
args: transform_vec(args, &mut transform)?,
fun,
filter: transform_option_box(filter, &mut transform)?,
},
Expr::AggregateUDF(AggregateUDF { args, fun, filter }) => {
Expr::AggregateUDF(AggregateUDF {
args: transform_vec(args, &mut transform)?,
fun,
filter: transform_option_box(filter, &mut transform)?,
})
}
Expr::InList {
expr,
list,
Expand Down
4 changes: 2 additions & 2 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ impl AggregateUDF {
/// creates a logical expression with a call of the UDAF
/// This utility allows using the UDAF without requiring access to the registry.
pub fn call(&self, args: Vec<Expr>) -> Expr {
Expr::AggregateUDF {
Expr::AggregateUDF(crate::expr::AggregateUDF {
fun: Arc::new(self.clone()),
args,
filter: None,
}
})
}
}
29 changes: 13 additions & 16 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -404,17 +404,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
));
Ok(expr)
}
Expr::AggregateUDF { fun, args, filter } => {
Expr::AggregateUDF(expr::AggregateUDF { fun, args, filter }) => {
let new_expr = coerce_arguments_for_signature(
args.as_slice(),
&self.schema,
&fun.signature,
)?;
let expr = Expr::AggregateUDF {
fun,
args: new_expr,
filter,
};
let expr =
Expr::AggregateUDF(expr::AggregateUDF::new(fun, new_expr, filter));
Ok(expr)
}
Expr::WindowFunction(WindowFunction {
Expand Down Expand Up @@ -883,11 +880,11 @@ mod test {
}),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
);
let udaf = Expr::AggregateUDF {
fun: Arc::new(my_avg),
args: vec![lit(10i64)],
filter: None,
};
let udaf = Expr::AggregateUDF(expr::AggregateUDF::new(
Arc::new(my_avg),
vec![lit(10i64)],
None,
));
let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty)?);
let expected = "Projection: MY_AVG(CAST(Int64(10) AS Float64))\n EmptyRelation";
assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)
Expand All @@ -913,11 +910,11 @@ mod test {
&accumulator,
&state_type,
);
let udaf = Expr::AggregateUDF {
fun: Arc::new(my_avg),
args: vec![lit("10")],
filter: None,
};
let udaf = Expr::AggregateUDF(expr::AggregateUDF::new(
Arc::new(my_avg),
vec![lit("10")],
None,
));
let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty)?);
let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, "")
.err()
Expand Down
23 changes: 12 additions & 11 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ use std::collections::{BTreeSet, HashMap};
use std::sync::Arc;

use arrow::datatypes::DataType;

use datafusion_common::tree_node::{
RewriteRecursion, TreeNode, TreeNodeRewriter, TreeNodeVisitor, VisitRecursion,
};
Expand Down Expand Up @@ -870,16 +869,18 @@ mod test {
let accumulator: AccumulatorFunctionImplementation =
Arc::new(|_| unimplemented!());
let state_type: StateTypeFunction = Arc::new(|_| unimplemented!());
let udf_agg = |inner: Expr| Expr::AggregateUDF {
fun: Arc::new(AggregateUDF::new(
"my_agg",
&Signature::exact(vec![DataType::UInt32], Volatility::Stable),
&return_type,
&accumulator,
&state_type,
)),
args: vec![inner],
filter: None,
let udf_agg = |inner: Expr| {
Expr::AggregateUDF(datafusion_expr::expr::AggregateUDF::new(
Arc::new(AggregateUDF::new(
"my_agg",
&Signature::exact(vec![DataType::UInt32], Volatility::Stable),
&return_type,
&accumulator,
&state_type,
)),
vec![inner],
None,
))
};

// test: common aggregates
Expand Down
12 changes: 5 additions & 7 deletions datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1377,16 +1377,14 @@ pub fn parse_expr(
ExprType::AggregateUdfExpr(pb) => {
let agg_fn = registry.udaf(pb.fun_name.as_str())?;

Ok(Expr::AggregateUDF {
fun: agg_fn,
args: pb
.args
Ok(Expr::AggregateUDF(expr::AggregateUDF::new(
agg_fn,
pb.args
.iter()
.map(|expr| parse_expr(expr, registry))
.collect::<Result<Vec<_>, Error>>()?,
filter: parse_optional_expr(pb.filter.as_deref(), registry)?
.map(Box::new),
})
parse_optional_expr(pb.filter.as_deref(), registry)?.map(Box::new),
)))
}

ExprType::GroupingSet(GroupingSetNode { expr }) => {
Expand Down
10 changes: 5 additions & 5 deletions datafusion/proto/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2602,11 +2602,11 @@ mod roundtrip_tests {
Arc::new(vec![DataType::Float64, DataType::UInt32]),
);

let test_expr = Expr::AggregateUDF {
fun: Arc::new(dummy_agg.clone()),
args: vec![lit(1.0_f64)],
filter: Some(Box::new(lit(true))),
};
let test_expr = Expr::AggregateUDF(expr::AggregateUDF::new(
Arc::new(dummy_agg.clone()),
vec![lit(1.0_f64)],
Some(Box::new(lit(true))),
));

let ctx = SessionContext::new();
ctx.register_udaf(dummy_agg);
Expand Down
2 changes: 1 addition & 1 deletion datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
.collect::<Result<Vec<_>, Error>>()?,
})),
},
Expr::AggregateUDF { fun, args, filter } => Self {
Expr::AggregateUDF(expr::AggregateUDF { fun, args, filter }) => Self {
expr_type: Some(ExprType::AggregateUdfExpr(Box::new(
protobuf::AggregateUdfExprNode {
fun_name: fun.name.clone(),
Expand Down
6 changes: 1 addition & 5 deletions datafusion/sql/src/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
if let Some(fm) = self.schema_provider.get_aggregate_meta(&name) {
let args =
self.function_args_to_expr(function.args, schema, planner_context)?;
return Ok(Expr::AggregateUDF {
fun: fm,
args,
filter: None,
});
return Ok(Expr::AggregateUDF(expr::AggregateUDF::new(fm, args, None)));
}

// Special case arrow_cast (as its type is dependent on its argument value)
Expand Down
21 changes: 11 additions & 10 deletions datafusion/sql/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ use sqlparser::ast::Ident;

use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::expr::{
AggregateFunction, Between, BinaryExpr, Case, GetIndexedField, GroupingSet, Like,
ScalarFunction, ScalarUDF, WindowFunction,
AggregateFunction, AggregateUDF, Between, BinaryExpr, Case, GetIndexedField,
GroupingSet, Like, ScalarFunction, ScalarUDF, WindowFunction,
};
use datafusion_expr::expr::{Cast, Sort};
use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs};
Expand Down Expand Up @@ -197,14 +197,15 @@ where
.collect::<Result<Vec<_>>>()?,
window_frame.clone(),
))),
Expr::AggregateUDF { fun, args, filter } => Ok(Expr::AggregateUDF {
fun: fun.clone(),
args: args
.iter()
.map(|e| clone_with_replacement(e, replacement_fn))
.collect::<Result<Vec<Expr>>>()?,
filter: filter.clone(),
}),
Expr::AggregateUDF(AggregateUDF { fun, args, filter }) => {
Ok(Expr::AggregateUDF(AggregateUDF::new(
fun.clone(),
args.iter()
.map(|e| clone_with_replacement(e, replacement_fn))
.collect::<Result<Vec<Expr>>>()?,
filter.clone(),
)))
}
Expr::Alias(nested_expr, alias_name) => Ok(Expr::Alias(
Box::new(clone_with_replacement(nested_expr, replacement_fn)?),
alias_name.clone(),
Expand Down