diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 9bc842a12af4..5e601cd7abda 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -20,7 +20,8 @@ use datafusion::arrow::datatypes::{ DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit, }; use datafusion::common::{ - not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef, + not_impl_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, + substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef, }; use substrait::proto::expression::literal::IntervalDayToSecond; use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile; @@ -30,8 +31,7 @@ use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::{ aggregate_function, expr::find_df_window_func, Aggregate, BinaryExpr, Case, - EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator, Projection, ScalarUDF, - Values, + EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator, Projection, Values, }; use datafusion::logical_expr::{ @@ -57,7 +57,7 @@ use substrait::proto::{ reference_segment::ReferenceType::StructField, window_function::bound as SubstraitBound, window_function::bound::Kind as BoundKind, window_function::Bound, - MaskExpression, RexType, + window_function::BoundsType, MaskExpression, RexType, }, extensions::simple_extension_declaration::MappingType, function_argument::ArgType, @@ -71,7 +71,6 @@ use substrait::proto::{ use substrait::proto::{FunctionArgument, SortField}; use datafusion::arrow::array::GenericListArray; -use datafusion::common::plan_err; use datafusion::common::scalar::ScalarStructBuilder; use datafusion::logical_expr::expr::{InList, InSubquery, Sort}; use std::collections::HashMap; @@ -89,12 +88,6 @@ use crate::variation_const::{ UNSIGNED_INTEGER_TYPE_VARIATION_REF, }; -enum ScalarFunctionType { - Op(Operator), - Expr(BuiltinExprBuilder), - Udf(Arc), -} - pub fn name_to_op(name: &str) -> Result { match name { "equal" => Ok(Operator::Eq), @@ -128,28 +121,6 @@ pub fn name_to_op(name: &str) -> Result { } } -fn scalar_function_type_from_str( - ctx: &SessionContext, - name: &str, -) -> Result { - let s = ctx.state(); - let name = substrait_fun_name(name); - - if let Some(func) = s.scalar_functions().get(name) { - return Ok(ScalarFunctionType::Udf(func.to_owned())); - } - - if let Ok(op) = name_to_op(name) { - return Ok(ScalarFunctionType::Op(op)); - } - - if let Some(builder) = BuiltinExprBuilder::try_from_name(name) { - return Ok(ScalarFunctionType::Expr(builder)); - } - - not_impl_err!("Unsupported function name: {name:?}") -} - pub fn substrait_fun_name(name: &str) -> &str { let name = match name.rsplit_once(':') { // Since 0.32.0, Substrait requires the function names to be in a compound format @@ -972,7 +943,7 @@ pub async fn from_substrait_rex_vec( } /// Convert Substrait FunctionArguments to DataFusion Exprs -pub async fn from_substriat_func_args( +pub async fn from_substrait_func_args( ctx: &SessionContext, arguments: &Vec, input_schema: &DFSchema, @@ -984,9 +955,7 @@ pub async fn from_substriat_func_args( Some(ArgType::Value(e)) => { from_substrait_rex(ctx, e, input_schema, extensions).await } - _ => { - not_impl_err!("Aggregated function argument non-Value type not supported") - } + _ => not_impl_err!("Function argument non-Value type not supported"), }; args.push(arg_expr?.as_ref().clone()); } @@ -1003,18 +972,8 @@ pub async fn from_substrait_agg_func( order_by: Option>, distinct: bool, ) -> Result> { - let mut args: Vec = vec![]; - for arg in &f.arguments { - let arg_expr = match &arg.arg_type { - Some(ArgType::Value(e)) => { - from_substrait_rex(ctx, e, input_schema, extensions).await - } - _ => { - not_impl_err!("Aggregated function argument non-Value type not supported") - } - }; - args.push(arg_expr?.as_ref().clone()); - } + let args = + from_substrait_func_args(ctx, &f.arguments, input_schema, extensions).await?; let Some(function_name) = extensions.get(&f.function_reference) else { return plan_err!( @@ -1022,14 +981,16 @@ pub async fn from_substrait_agg_func( f.function_reference ); }; - // function_name.split(':').next().unwrap_or(function_name); + let function_name = substrait_fun_name((**function_name).as_str()); // try udaf first, then built-in aggr fn. if let Ok(fun) = ctx.udaf(function_name) { // deal with situation that count(*) got no arguments - if fun.name() == "count" && args.is_empty() { - args.push(Expr::Literal(ScalarValue::Int64(Some(1)))); - } + let args = if fun.name() == "count" && args.is_empty() { + vec![Expr::Literal(ScalarValue::Int64(Some(1)))] + } else { + args + }; Ok(Arc::new(Expr::AggregateFunction( expr::AggregateFunction::new_udf(fun, args, distinct, filter, order_by, None), @@ -1041,7 +1002,7 @@ pub async fn from_substrait_agg_func( ))) } else { not_impl_err!( - "Aggregated function {} is not supported: function anchor = {:?}", + "Aggregate function {} is not supported: function anchor = {:?}", function_name, f.function_reference ) @@ -1145,84 +1106,40 @@ pub async fn from_substrait_rex( }))) } Some(RexType::ScalarFunction(f)) => { - let fn_name = extensions.get(&f.function_reference).ok_or_else(|| { - DataFusionError::NotImplemented(format!( - "Aggregated function not found: function reference = {:?}", + let Some(fn_name) = extensions.get(&f.function_reference) else { + return plan_err!( + "Scalar function not found: function reference = {:?}", f.function_reference - )) - })?; - - // Convert function arguments from Substrait to DataFusion - async fn decode_arguments( - ctx: &SessionContext, - input_schema: &DFSchema, - extensions: &HashMap, - function_args: &[FunctionArgument], - ) -> Result> { - let mut args = Vec::with_capacity(function_args.len()); - for arg in function_args { - let arg_expr = match &arg.arg_type { - Some(ArgType::Value(e)) => { - from_substrait_rex(ctx, e, input_schema, extensions).await - } - _ => not_impl_err!( - "Aggregated function argument non-Value type not supported" - ), - }?; - args.push(arg_expr.as_ref().clone()); - } - Ok(args) - } + ); + }; + let fn_name = substrait_fun_name(fn_name); - let fn_type = scalar_function_type_from_str(ctx, fn_name)?; - match fn_type { - ScalarFunctionType::Udf(fun) => { - let args = decode_arguments( - ctx, - input_schema, - extensions, - f.arguments.as_slice(), - ) + let args = + from_substrait_func_args(ctx, &f.arguments, input_schema, extensions) .await?; - Ok(Arc::new(Expr::ScalarFunction( - expr::ScalarFunction::new_udf(fun, args), - ))) - } - ScalarFunctionType::Op(op) => { - if f.arguments.len() != 2 { - return not_impl_err!( - "Expect two arguments for binary operator {op:?}" - ); - } - let lhs = &f.arguments[0].arg_type; - let rhs = &f.arguments[1].arg_type; - - match (lhs, rhs) { - (Some(ArgType::Value(l)), Some(ArgType::Value(r))) => { - Ok(Arc::new(Expr::BinaryExpr(BinaryExpr { - left: Box::new( - from_substrait_rex(ctx, l, input_schema, extensions) - .await? - .as_ref() - .clone(), - ), - op, - right: Box::new( - from_substrait_rex(ctx, r, input_schema, extensions) - .await? - .as_ref() - .clone(), - ), - }))) - } - (l, r) => not_impl_err!( - "Invalid arguments for binary expression: {l:?} and {r:?}" - ), - } - } - ScalarFunctionType::Expr(builder) => { - builder.build(ctx, f, input_schema, extensions).await + + // try to first match the requested function into registered udfs, then built-in ops + // and finally built-in expressions + if let Some(func) = ctx.state().scalar_functions().get(fn_name) { + Ok(Arc::new(Expr::ScalarFunction( + expr::ScalarFunction::new_udf(func.to_owned(), args), + ))) + } else if let Ok(op) = name_to_op(fn_name) { + if args.len() != 2 { + return not_impl_err!( + "Expect two arguments for binary operator {op:?}" + ); } + + Ok(Arc::new(Expr::BinaryExpr(BinaryExpr { + left: Box::new(args[0].to_owned()), + op, + right: Box::new(args[1].to_owned()), + }))) + } else if let Some(builder) = BuiltinExprBuilder::try_from_name(fn_name) { + builder.build(ctx, f, input_schema, extensions).await + } else { + not_impl_err!("Unsupported function name: {fn_name:?}") } } Some(RexType::Literal(lit)) => { @@ -1247,36 +1164,50 @@ pub async fn from_substrait_rex( None => substrait_err!("Cast expression without output type is not allowed"), }, Some(RexType::WindowFunction(window)) => { - let fun = match extensions.get(&window.function_reference) { - Some(function_name) => { - // check udaf - match ctx.udaf(function_name) { - Ok(udaf) => { - Ok(Some(WindowFunctionDefinition::AggregateUDF(udaf))) - } - Err(_) => Ok(find_df_window_func(function_name)), - } - } - None => not_impl_err!( - "Window function not found: function anchor = {:?}", - &window.function_reference - ), + let Some(fn_name) = extensions.get(&window.function_reference) else { + return plan_err!( + "Window function not found: function reference = {:?}", + window.function_reference + ); }; + let fn_name = substrait_fun_name(fn_name); + + // check udaf first, then built-in functions + let fun = match ctx.udaf(fn_name) { + Ok(udaf) => Ok(WindowFunctionDefinition::AggregateUDF(udaf)), + Err(_) => find_df_window_func(fn_name).ok_or_else(|| { + not_impl_datafusion_err!( + "Window function {} is not supported: function anchor = {:?}", + fn_name, + window.function_reference + ) + }), + }?; + let order_by = from_substrait_sorts(ctx, &window.sorts, input_schema, extensions) .await?; - // Substrait does not encode WindowFrameUnits so we're using a simple logic to determine the units - // If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary - // If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row - // TODO: Consider the cases where window frame is specified in query and is different from default - let units = if order_by.is_empty() { - WindowFrameUnits::Rows - } else { - WindowFrameUnits::Range - }; + + let bound_units = + match BoundsType::try_from(window.bounds_type).map_err(|e| { + plan_datafusion_err!("Invalid bound type {}: {e}", window.bounds_type) + })? { + BoundsType::Rows => WindowFrameUnits::Rows, + BoundsType::Range => WindowFrameUnits::Range, + BoundsType::Unspecified => { + // If the plan does not specify the bounds type, then we use a simple logic to determine the units + // If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary + // If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row + if order_by.is_empty() { + WindowFrameUnits::Rows + } else { + WindowFrameUnits::Range + } + } + }; Ok(Arc::new(Expr::WindowFunction(expr::WindowFunction { - fun: fun?.unwrap(), - args: from_substriat_func_args( + fun, + args: from_substrait_func_args( ctx, &window.arguments, input_schema, @@ -1292,7 +1223,7 @@ pub async fn from_substrait_rex( .await?, order_by, window_frame: datafusion::logical_expr::WindowFrame::new_bounds( - units, + bound_units, from_substrait_bound(&window.lower_bound, true)?, from_substrait_bound(&window.upper_bound, false)?, ), diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs index 94572e098b2c..6492febc938e 100644 --- a/datafusion/substrait/tests/cases/logical_plans.rs +++ b/datafusion/substrait/tests/cases/logical_plans.rs @@ -28,7 +28,7 @@ mod tests { use substrait::proto::Plan; #[tokio::test] - async fn function_compound_signature() -> Result<()> { + async fn scalar_function_compound_signature() -> Result<()> { // DataFusion currently produces Substrait that refers to functions only by their name. // However, the Substrait spec requires that functions be identified by their compound signature. // This test confirms that DataFusion is able to consume plans following the spec, even though @@ -39,7 +39,7 @@ mod tests { // File generated with substrait-java's Isthmus: // ./isthmus-cli/build/graal/isthmus "select not d from data" -c "create table data (d boolean)" - let proto = read_json("tests/testdata/select_not_bool.substrait.json"); + let proto = read_json("tests/testdata/test_plans/select_not_bool.substrait.json"); let plan = from_substrait_plan(&ctx, &proto).await?; @@ -51,13 +51,41 @@ mod tests { Ok(()) } + // Aggregate function compound signature is tested through TPCH plans + + #[tokio::test] + async fn window_function_compound_signature() -> Result<()> { + // DataFusion currently produces Substrait that refers to functions only by their name. + // However, the Substrait spec requires that functions be identified by their compound signature. + // This test confirms that DataFusion is able to consume plans following the spec, even though + // we don't yet produce such plans. + // Once we start producing plans with compound signatures, this test can be replaced by the roundtrip tests. + + let ctx = create_context().await?; + + // File generated with substrait-java's Isthmus: + // ./isthmus-cli/build/graal/isthmus "select sum(d) OVER (PARTITION BY part ORDER BY ord ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) AS lead_expr from data" -c "create table data (d int, part int, ord int)" + let proto = read_json("tests/testdata/test_plans/select_window.substrait.json"); + + let plan = from_substrait_plan(&ctx, &proto).await?; + + assert_eq!( + format!("{:?}", plan), + "Projection: sum(DATA.a) PARTITION BY [DATA.b] ORDER BY [DATA.c ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS LEAD_EXPR\ + \n WindowAggr: windowExpr=[[sum(DATA.a) PARTITION BY [DATA.b] ORDER BY [DATA.c ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]]\ + \n TableScan: DATA projection=[a, b, c, d, e, f]" + ); + Ok(()) + } + #[tokio::test] async fn non_nullable_lists() -> Result<()> { // DataFusion's Substrait consumer treats all lists as nullable, even if the Substrait plan specifies them as non-nullable. // That's because implementing the non-nullability consistently is non-trivial. // This test confirms that reading a plan with non-nullable lists works as expected. let ctx = create_context().await?; - let proto = read_json("tests/testdata/non_nullable_lists.substrait.json"); + let proto = + read_json("tests/testdata/test_plans/non_nullable_lists.substrait.json"); let plan = from_substrait_plan(&ctx, &proto).await?; diff --git a/datafusion/substrait/tests/testdata/non_nullable_lists.substrait.json b/datafusion/substrait/tests/testdata/test_plans/non_nullable_lists.substrait.json similarity index 100% rename from datafusion/substrait/tests/testdata/non_nullable_lists.substrait.json rename to datafusion/substrait/tests/testdata/test_plans/non_nullable_lists.substrait.json diff --git a/datafusion/substrait/tests/testdata/select_not_bool.substrait.json b/datafusion/substrait/tests/testdata/test_plans/select_not_bool.substrait.json similarity index 100% rename from datafusion/substrait/tests/testdata/select_not_bool.substrait.json rename to datafusion/substrait/tests/testdata/test_plans/select_not_bool.substrait.json diff --git a/datafusion/substrait/tests/testdata/test_plans/select_window.substrait.json b/datafusion/substrait/tests/testdata/test_plans/select_window.substrait.json new file mode 100644 index 000000000000..3082c4258f83 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/select_window.substrait.json @@ -0,0 +1,153 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_arithmetic.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "sum:i32" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 3 + ] + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "D", + "PART", + "ORD" + ], + "struct": { + "types": [ + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "DATA" + ] + } + } + }, + "expressions": [ + { + "windowFunction": { + "functionReference": 0, + "partitions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + ], + "sorts": [ + { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + } + ], + "upperBound": { + "unbounded": { + } + }, + "lowerBound": { + "preceding": { + "offset": "1" + } + }, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "args": [], + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + } + ], + "invocation": "AGGREGATION_INVOCATION_ALL", + "options": [], + "boundsType": "BOUNDS_TYPE_ROWS" + } + } + ] + } + }, + "names": [ + "LEAD_EXPR" + ] + } + } + ], + "expectedTypeUrls": [] +}