diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index b7cacf131d24..767c4a39375a 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -19,10 +19,11 @@ use async_recursion::async_recursion; use datafusion::arrow::datatypes::DataType; use datafusion::common::{DFField, DFSchema, DFSchemaRef}; use datafusion::logical_expr::{ - aggregate_function, BinaryExpr, Case, Expr, LogicalPlan, Operator, + aggregate_function, window_function::find_df_window_func, BinaryExpr, Case, Expr, + LogicalPlan, Operator, }; use datafusion::logical_expr::{build_join_schema, LogicalPlanBuilder}; -use datafusion::logical_expr::{expr, Cast}; +use datafusion::logical_expr::{expr, Cast, WindowFrameBound, WindowFrameUnits}; use datafusion::prelude::JoinType; use datafusion::sql::TableReference; use datafusion::{ @@ -35,7 +36,10 @@ use substrait::proto::{ aggregate_function::AggregationInvocation, expression::{ field_reference::ReferenceType::DirectReference, literal::LiteralType, - reference_segment::ReferenceType::StructField, MaskExpression, RexType, + reference_segment::ReferenceType::StructField, + window_function::bound as SubstraitBound, + window_function::bound::Kind as BoundKind, window_function::Bound, + MaskExpression, RexType, }, extensions::simple_extension_declaration::MappingType, function_argument::ArgType, @@ -45,6 +49,7 @@ use substrait::proto::{ sort_field::{SortDirection, SortKind::*}, AggregateFunction, Expression, Plan, Rel, Type, }; +use substrait::proto::{FunctionArgument, SortField}; use datafusion::logical_expr::expr::Sort; use std::collections::HashMap; @@ -139,13 +144,25 @@ pub async fn from_substrait_rel( match &rel.rel_type { Some(RelType::Project(p)) => { if let Some(input) = p.input.as_ref() { - let input = LogicalPlanBuilder::from( + let mut input = LogicalPlanBuilder::from( from_substrait_rel(ctx, input, extensions).await?, ); let mut exprs: Vec = vec![]; for e in &p.expressions { - let x = from_substrait_rex(e, input.schema(), extensions).await?; - exprs.push(x.as_ref().clone()); + let x = + from_substrait_rex(e, input.clone().schema(), extensions).await?; + // if the expression is WindowFunction, wrap in a Window relation + // before returning and do not add to list of this Projection's expression list + // otherwise, add expression to the Projection's expression list + match &*x { + Expr::WindowFunction(_) => { + input = input.window(vec![x.as_ref().clone()])?; + exprs.push(x.as_ref().clone()); + } + _ => { + exprs.push(x.as_ref().clone()); + } + } } input.project(exprs)?.build() } else { @@ -193,45 +210,8 @@ pub async fn from_substrait_rel( let input = LogicalPlanBuilder::from( from_substrait_rel(ctx, input, extensions).await?, ); - let mut sorts: Vec = vec![]; - for s in &sort.sorts { - let expr = from_substrait_rex( - s.expr.as_ref().unwrap(), - input.schema(), - extensions, - ) - .await?; - let asc_nullfirst = match &s.sort_kind { - Some(k) => match k { - Direction(d) => { - let direction : SortDirection = unsafe { - ::std::mem::transmute(*d) - }; - match direction { - SortDirection::AscNullsFirst => Ok((true, true)), - SortDirection::AscNullsLast => Ok((true, false)), - SortDirection::DescNullsFirst => Ok((false, true)), - SortDirection::DescNullsLast => Ok((false, false)), - SortDirection::Clustered => - Err(DataFusionError::NotImplemented("Sort with direction clustered is not yet supported".to_string())) - , - SortDirection::Unspecified => - Err(DataFusionError::NotImplemented("Unspecified sort direction is invalid".to_string())) - } - } - ComparisonFunctionReference(_) => { - Err(DataFusionError::NotImplemented("Sort using comparison function reference is not supported".to_string())) - }, - }, - None => Err(DataFusionError::NotImplemented("Sort without sort kind is invalid".to_string())) - }; - let (asc, nulls_first) = asc_nullfirst.unwrap(); - sorts.push(Expr::Sort(Sort { - expr: Box::new(expr.as_ref().clone()), - asc, - nulls_first, - })); - } + let sorts = + from_substrait_sorts(&sort.sorts, input.schema(), extensions).await?; input.sort(sorts)?.build() } else { Err(DataFusionError::NotImplemented( @@ -452,6 +432,90 @@ fn from_substrait_jointype(join_type: i32) -> Result { } } +/// Convert Substrait Sorts to DataFusion Exprs +pub async fn from_substrait_sorts( + substrait_sorts: &Vec, + input_schema: &DFSchema, + extensions: &HashMap, +) -> Result> { + let mut sorts: Vec = vec![]; + for s in substrait_sorts { + let expr = from_substrait_rex(s.expr.as_ref().unwrap(), input_schema, extensions) + .await?; + let asc_nullfirst = match &s.sort_kind { + Some(k) => match k { + Direction(d) => { + let direction: SortDirection = unsafe { ::std::mem::transmute(*d) }; + match direction { + SortDirection::AscNullsFirst => Ok((true, true)), + SortDirection::AscNullsLast => Ok((true, false)), + SortDirection::DescNullsFirst => Ok((false, true)), + SortDirection::DescNullsLast => Ok((false, false)), + SortDirection::Clustered => Err(DataFusionError::NotImplemented( + "Sort with direction clustered is not yet supported" + .to_string(), + )), + SortDirection::Unspecified => { + Err(DataFusionError::NotImplemented( + "Unspecified sort direction is invalid".to_string(), + )) + } + } + } + ComparisonFunctionReference(_) => Err(DataFusionError::NotImplemented( + "Sort using comparison function reference is not supported" + .to_string(), + )), + }, + None => Err(DataFusionError::NotImplemented( + "Sort without sort kind is invalid".to_string(), + )), + }; + let (asc, nulls_first) = asc_nullfirst.unwrap(); + sorts.push(Expr::Sort(Sort { + expr: Box::new(expr.as_ref().clone()), + asc, + nulls_first, + })); + } + Ok(sorts) +} + +/// Convert Substrait Expressions to DataFusion Exprs +pub async fn from_substrait_rex_vec( + exprs: &Vec, + input_schema: &DFSchema, + extensions: &HashMap, +) -> Result> { + let mut expressions: Vec = vec![]; + for expr in exprs { + let expression = from_substrait_rex(expr, input_schema, extensions).await?; + expressions.push(expression.as_ref().clone()); + } + Ok(expressions) +} + +/// Convert Substrait FunctionArguments to DataFusion Exprs +pub async fn from_substriat_func_args( + arguments: &Vec, + input_schema: &DFSchema, + extensions: &HashMap, +) -> Result> { + let mut args: Vec = vec![]; + for arg in arguments { + let arg_expr = match &arg.arg_type { + Some(ArgType::Value(e)) => { + from_substrait_rex(e, input_schema, extensions).await + } + _ => Err(DataFusionError::NotImplemented( + "Aggregated function argument non-Value type not supported".to_string(), + )), + }; + args.push(arg_expr?.as_ref().clone()); + } + Ok(args) +} + /// Convert Substrait AggregateFunction to DataFusion Expr pub async fn from_substrait_agg_func( f: &AggregateFunction, @@ -740,6 +804,47 @@ pub async fn from_substrait_rex( "Cast experssion without output type is not allowed".to_string(), )), }, + Some(RexType::WindowFunction(window)) => { + let fun = match extensions.get(&window.function_reference) { + Some(function_name) => Ok(find_df_window_func(function_name)), + None => Err(DataFusionError::NotImplemented(format!( + "Window function not found: function anchor = {:?}", + &window.function_reference + ))), + }; + let order_by = + from_substrait_sorts(&window.sorts, input_schema, extensions).await?; + // Substrait does not encode WindowFrameUnits so we're using a simple logic to determine the units + // If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary + // If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row + // TODO: Consider the cases where window frame is specified in query and is different from default + let units = if order_by.is_empty() { + WindowFrameUnits::Rows + } else { + WindowFrameUnits::Range + }; + Ok(Arc::new(Expr::WindowFunction(expr::WindowFunction { + fun: fun?.unwrap(), + args: from_substriat_func_args( + &window.arguments, + input_schema, + extensions, + ) + .await?, + partition_by: from_substrait_rex_vec( + &window.partitions, + input_schema, + extensions, + ) + .await?, + order_by, + window_frame: datafusion::logical_expr::WindowFrame { + units, + start_bound: from_substrait_bound(&window.lower_bound, true)?, + end_bound: from_substrait_bound(&window.upper_bound, false)?, + }, + }))) + } _ => Err(DataFusionError::NotImplemented( "unsupported rex_type".to_string(), )), @@ -767,6 +872,44 @@ fn from_substrait_type(dt: &substrait::proto::Type) -> Result { } } +fn from_substrait_bound( + bound: &Option, + is_lower: bool, +) -> Result { + match bound { + Some(b) => match &b.kind { + Some(k) => match k { + BoundKind::CurrentRow(SubstraitBound::CurrentRow {}) => { + Ok(WindowFrameBound::CurrentRow) + } + BoundKind::Preceding(SubstraitBound::Preceding { offset }) => Ok( + WindowFrameBound::Preceding(ScalarValue::Int64(Some(*offset))), + ), + BoundKind::Following(SubstraitBound::Following { offset }) => Ok( + WindowFrameBound::Following(ScalarValue::Int64(Some(*offset))), + ), + BoundKind::Unbounded(SubstraitBound::Unbounded {}) => { + if is_lower { + Ok(WindowFrameBound::Preceding(ScalarValue::Null)) + } else { + Ok(WindowFrameBound::Following(ScalarValue::Null)) + } + } + }, + None => Err(DataFusionError::Substrait( + "WindowFunction missing Substrait Bound kind".to_string(), + )), + }, + None => { + if is_lower { + Ok(WindowFrameBound::Preceding(ScalarValue::Null)) + } else { + Ok(WindowFrameBound::Following(ScalarValue::Null)) + } + } + } +} + fn from_substrait_null(null_type: &Type) -> Result { if let Some(kind) = &null_type.kind { match kind { diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 1de26a343372..ecb322edb70e 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -20,6 +20,7 @@ use std::{collections::HashMap, mem, sync::Arc}; use datafusion::{ arrow::datatypes::DataType, error::{DataFusionError, Result}, + logical_expr::{WindowFrame, WindowFrameBound}, prelude::JoinType, scalar::ScalarValue, }; @@ -27,7 +28,7 @@ use datafusion::{ use datafusion::common::DFSchemaRef; #[allow(unused_imports)] use datafusion::logical_expr::aggregate_function; -use datafusion::logical_expr::expr::{BinaryExpr, Case, Cast, Sort}; +use datafusion::logical_expr::expr::{BinaryExpr, Case, Cast, Sort, WindowFunction}; use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; use datafusion::prelude::{binary_expr, Expr}; use substrait::proto::{ @@ -38,8 +39,12 @@ use substrait::proto::{ if_then::IfClause, literal::{Decimal, LiteralType}, mask_expression::{StructItem, StructSelect}, - reference_segment, FieldReference, IfThen, Literal, MaskExpression, - ReferenceSegment, RexType, ScalarFunction, + reference_segment, + window_function::bound as SubstraitBound, + window_function::bound::Kind as BoundKind, + window_function::Bound, + FieldReference, IfThen, Literal, MaskExpression, ReferenceSegment, RexType, + ScalarFunction, WindowFunction as SubstraitWindowFunction, }, extensions::{ self, @@ -301,6 +306,42 @@ pub fn to_substrait_rel( // since there is no corresponding relation type in Substrait to_substrait_rel(alias.input.as_ref(), extension_info) } + LogicalPlan::Window(window) => { + let input = to_substrait_rel(window.input.as_ref(), extension_info)?; + // If the input is a Project relation, we can just append the WindowFunction expressions + // before returning + // Otherwise, wrap the input in a Project relation before appending the WindowFunction + // expressions + let mut project_rel: Box = match &input.as_ref().rel_type { + Some(RelType::Project(p)) => Box::new(*p.clone()), + _ => { + // Create Projection with field referencing all output fields in the input relation + let expressions = (0..window.input.schema().fields().len()) + .map(substrait_field_ref) + .collect::>>()?; + Box::new(ProjectRel { + common: None, + input: Some(input), + expressions, + advanced_extension: None, + }) + } + }; + // Parse WindowFunction expression + let mut window_exprs = vec![]; + for expr in &window.window_expr { + window_exprs.push(to_substrait_rex( + expr, + window.input.schema(), + extension_info, + )?); + } + // Append parsed WindowFunction expressions + project_rel.expressions.extend(window_exprs); + Ok(Box::new(Rel { + rel_type: Some(RelType::Project(project_rel)), + })) + } _ => Err(DataFusionError::NotImplemented(format!( "Unsupported operator: {plan:?}" ))), @@ -636,6 +677,47 @@ pub fn to_substrait_rex( }) } Expr::Alias(expr, _alias) => to_substrait_rex(expr, schema, extension_info), + Expr::WindowFunction(WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + }) => { + // function reference + let function_name = fun.to_string().to_lowercase(); + let function_anchor = _register_function(function_name, extension_info); + // arguments + let mut arguments: Vec = vec![]; + for arg in args { + arguments.push(FunctionArgument { + arg_type: Some(ArgType::Value(to_substrait_rex( + arg, + schema, + extension_info, + )?)), + }); + } + // partition by expressions + let partition_by = partition_by + .iter() + .map(|e| to_substrait_rex(e, schema, extension_info)) + .collect::>>()?; + // order by expressions + let order_by = order_by + .iter() + .map(|e| substrait_sort_field(e, schema, extension_info)) + .collect::>>()?; + // window frame + let bounds = to_substrait_bounds(window_frame)?; + Ok(make_substrait_window_function( + function_anchor, + arguments, + partition_by, + order_by, + bounds, + )) + } _ => Err(DataFusionError::NotImplemented(format!( "Unsupported expression: {expr:?}" ))), @@ -693,6 +775,136 @@ fn to_substrait_type(dt: &DataType) -> Result { } } +#[allow(deprecated)] +fn make_substrait_window_function( + function_reference: u32, + arguments: Vec, + partitions: Vec, + sorts: Vec, + bounds: (Bound, Bound), +) -> Expression { + Expression { + rex_type: Some(RexType::WindowFunction(SubstraitWindowFunction { + function_reference, + arguments, + partitions, + sorts, + options: vec![], + output_type: None, + phase: 0, // default to AGGREGATION_PHASE_UNSPECIFIED + invocation: 0, // TODO: fix + lower_bound: Some(bounds.0), + upper_bound: Some(bounds.1), + args: vec![], + })), + } +} + +fn to_substrait_bound(bound: &WindowFrameBound) -> Bound { + match bound { + WindowFrameBound::CurrentRow => Bound { + kind: Some(BoundKind::CurrentRow(SubstraitBound::CurrentRow {})), + }, + WindowFrameBound::Preceding(s) => match s { + ScalarValue::UInt8(Some(v)) => Bound { + kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { + offset: *v as i64, + })), + }, + ScalarValue::UInt16(Some(v)) => Bound { + kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { + offset: *v as i64, + })), + }, + ScalarValue::UInt32(Some(v)) => Bound { + kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { + offset: *v as i64, + })), + }, + ScalarValue::UInt64(Some(v)) => Bound { + kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { + offset: *v as i64, + })), + }, + ScalarValue::Int8(Some(v)) => Bound { + kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { + offset: *v as i64, + })), + }, + ScalarValue::Int16(Some(v)) => Bound { + kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { + offset: *v as i64, + })), + }, + ScalarValue::Int32(Some(v)) => Bound { + kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { + offset: *v as i64, + })), + }, + ScalarValue::Int64(Some(v)) => Bound { + kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { + offset: *v, + })), + }, + _ => Bound { + kind: Some(BoundKind::Unbounded(SubstraitBound::Unbounded {})), + }, + }, + WindowFrameBound::Following(s) => match s { + ScalarValue::UInt8(Some(v)) => Bound { + kind: Some(BoundKind::Following(SubstraitBound::Following { + offset: *v as i64, + })), + }, + ScalarValue::UInt16(Some(v)) => Bound { + kind: Some(BoundKind::Following(SubstraitBound::Following { + offset: *v as i64, + })), + }, + ScalarValue::UInt32(Some(v)) => Bound { + kind: Some(BoundKind::Following(SubstraitBound::Following { + offset: *v as i64, + })), + }, + ScalarValue::UInt64(Some(v)) => Bound { + kind: Some(BoundKind::Following(SubstraitBound::Following { + offset: *v as i64, + })), + }, + ScalarValue::Int8(Some(v)) => Bound { + kind: Some(BoundKind::Following(SubstraitBound::Following { + offset: *v as i64, + })), + }, + ScalarValue::Int16(Some(v)) => Bound { + kind: Some(BoundKind::Following(SubstraitBound::Following { + offset: *v as i64, + })), + }, + ScalarValue::Int32(Some(v)) => Bound { + kind: Some(BoundKind::Following(SubstraitBound::Following { + offset: *v as i64, + })), + }, + ScalarValue::Int64(Some(v)) => Bound { + kind: Some(BoundKind::Following(SubstraitBound::Following { + offset: *v, + })), + }, + _ => Bound { + kind: Some(BoundKind::Unbounded(SubstraitBound::Unbounded {})), + }, + }, + } +} + +fn to_substrait_bounds(window_frame: &WindowFrame) -> Result<(Bound, Bound)> { + Ok(( + to_substrait_bound(&window_frame.start_bound), + to_substrait_bound(&window_frame.end_bound), + )) +} + fn try_to_substrait_null(v: &ScalarValue) -> Result { let default_type_ref = 0; let default_nullability = r#type::Nullability::Nullable as i32; diff --git a/datafusion/substrait/tests/roundtrip_logical_plan.rs b/datafusion/substrait/tests/roundtrip_logical_plan.rs index 9aa430bb09dd..936c4670b37d 100644 --- a/datafusion/substrait/tests/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/roundtrip_logical_plan.rs @@ -250,6 +250,11 @@ mod tests { .await } + #[tokio::test] + async fn simple_window_function() -> Result<()> { + roundtrip("SELECT RANK() OVER (PARTITION BY a ORDER BY b), d, SUM(b) OVER (PARTITION BY a) FROM data;").await + } + async fn assert_expected_plan(sql: &str, expected_plan_str: &str) -> Result<()> { let mut ctx = create_context().await?; let df = ctx.sql(sql).await?;