Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 188 additions & 45 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ use async_recursion::async_recursion;
use datafusion::arrow::datatypes::DataType;
use datafusion::common::{DFField, DFSchema, DFSchemaRef};
use datafusion::logical_expr::{
aggregate_function, BinaryExpr, Case, Expr, LogicalPlan, Operator,
aggregate_function, window_function::find_df_window_func, BinaryExpr, Case, Expr,
LogicalPlan, Operator,
};
use datafusion::logical_expr::{build_join_schema, LogicalPlanBuilder};
use datafusion::logical_expr::{expr, Cast};
use datafusion::logical_expr::{expr, Cast, WindowFrameBound, WindowFrameUnits};
use datafusion::prelude::JoinType;
use datafusion::sql::TableReference;
use datafusion::{
Expand All @@ -35,7 +36,10 @@ use substrait::proto::{
aggregate_function::AggregationInvocation,
expression::{
field_reference::ReferenceType::DirectReference, literal::LiteralType,
reference_segment::ReferenceType::StructField, MaskExpression, RexType,
reference_segment::ReferenceType::StructField,
window_function::bound as SubstraitBound,
window_function::bound::Kind as BoundKind, window_function::Bound,
MaskExpression, RexType,
},
extensions::simple_extension_declaration::MappingType,
function_argument::ArgType,
Expand All @@ -45,6 +49,7 @@ use substrait::proto::{
sort_field::{SortDirection, SortKind::*},
AggregateFunction, Expression, Plan, Rel, Type,
};
use substrait::proto::{FunctionArgument, SortField};

use datafusion::logical_expr::expr::Sort;
use std::collections::HashMap;
Expand Down Expand Up @@ -139,13 +144,25 @@ pub async fn from_substrait_rel(
match &rel.rel_type {
Some(RelType::Project(p)) => {
if let Some(input) = p.input.as_ref() {
let input = LogicalPlanBuilder::from(
let mut input = LogicalPlanBuilder::from(
from_substrait_rel(ctx, input, extensions).await?,
);
let mut exprs: Vec<Expr> = vec![];
for e in &p.expressions {
let x = from_substrait_rex(e, input.schema(), extensions).await?;
exprs.push(x.as_ref().clone());
let x =
from_substrait_rex(e, input.clone().schema(), extensions).await?;
// if the expression is WindowFunction, wrap in a Window relation
// before returning and do not add to list of this Projection's expression list
// otherwise, add expression to the Projection's expression list
match &*x {
Expr::WindowFunction(_) => {
input = input.window(vec![x.as_ref().clone()])?;
exprs.push(x.as_ref().clone());
}
_ => {
exprs.push(x.as_ref().clone());
}
}
}
input.project(exprs)?.build()
} else {
Expand Down Expand Up @@ -193,45 +210,8 @@ pub async fn from_substrait_rel(
let input = LogicalPlanBuilder::from(
from_substrait_rel(ctx, input, extensions).await?,
);
let mut sorts: Vec<Expr> = vec![];
for s in &sort.sorts {
let expr = from_substrait_rex(
s.expr.as_ref().unwrap(),
input.schema(),
extensions,
)
.await?;
let asc_nullfirst = match &s.sort_kind {
Some(k) => match k {
Direction(d) => {
let direction : SortDirection = unsafe {
::std::mem::transmute(*d)
};
match direction {
SortDirection::AscNullsFirst => Ok((true, true)),
SortDirection::AscNullsLast => Ok((true, false)),
SortDirection::DescNullsFirst => Ok((false, true)),
SortDirection::DescNullsLast => Ok((false, false)),
SortDirection::Clustered =>
Err(DataFusionError::NotImplemented("Sort with direction clustered is not yet supported".to_string()))
,
SortDirection::Unspecified =>
Err(DataFusionError::NotImplemented("Unspecified sort direction is invalid".to_string()))
}
}
ComparisonFunctionReference(_) => {
Err(DataFusionError::NotImplemented("Sort using comparison function reference is not supported".to_string()))
},
},
None => Err(DataFusionError::NotImplemented("Sort without sort kind is invalid".to_string()))
};
let (asc, nulls_first) = asc_nullfirst.unwrap();
sorts.push(Expr::Sort(Sort {
expr: Box::new(expr.as_ref().clone()),
asc,
nulls_first,
}));
}
let sorts =
from_substrait_sorts(&sort.sorts, input.schema(), extensions).await?;
input.sort(sorts)?.build()
} else {
Err(DataFusionError::NotImplemented(
Expand Down Expand Up @@ -452,6 +432,90 @@ fn from_substrait_jointype(join_type: i32) -> Result<JoinType> {
}
}

/// Convert Substrait Sorts to DataFusion Exprs
///
/// Each `SortField` is converted, in order, into an `Expr::Sort` whose
/// `asc` / `nulls_first` flags are derived from the field's sort direction.
///
/// # Errors
///
/// * `DataFusionError::Substrait` if a sort field has no expression, or if
///   its direction value does not map to a known `SortDirection`.
/// * `DataFusionError::NotImplemented` for clustered or unspecified
///   directions, for comparison-function-reference sorts, and for sort
///   fields without a sort kind.
/// * Any error returned by `from_substrait_rex` for the sort expression.
pub async fn from_substrait_sorts(
    substrait_sorts: &Vec<SortField>,
    input_schema: &DFSchema,
    extensions: &HashMap<u32, &String>,
) -> Result<Vec<Expr>> {
    let mut sorts: Vec<Expr> = Vec::with_capacity(substrait_sorts.len());
    for s in substrait_sorts {
        // A missing expression is malformed input, not a bug: report it
        // instead of panicking.
        let sort_expr = s.expr.as_ref().ok_or_else(|| {
            DataFusionError::Substrait("Sort field is missing its expression".to_string())
        })?;
        let expr = from_substrait_rex(sort_expr, input_schema, extensions).await?;
        let (asc, nulls_first) = match &s.sort_kind {
            // The direction arrives as a raw i32 from the wire; decode it with
            // the checked conversion so unknown values become an error rather
            // than undefined behaviour.
            Some(Direction(d)) => match SortDirection::from_i32(*d) {
                Some(SortDirection::AscNullsFirst) => (true, true),
                Some(SortDirection::AscNullsLast) => (true, false),
                Some(SortDirection::DescNullsFirst) => (false, true),
                Some(SortDirection::DescNullsLast) => (false, false),
                Some(SortDirection::Clustered) => {
                    return Err(DataFusionError::NotImplemented(
                        "Sort with direction clustered is not yet supported".to_string(),
                    ))
                }
                Some(SortDirection::Unspecified) => {
                    return Err(DataFusionError::NotImplemented(
                        "Unspecified sort direction is invalid".to_string(),
                    ))
                }
                None => {
                    return Err(DataFusionError::Substrait(format!(
                        "Invalid sort direction value: {}",
                        d
                    )))
                }
            },
            Some(ComparisonFunctionReference(_)) => {
                return Err(DataFusionError::NotImplemented(
                    "Sort using comparison function reference is not supported".to_string(),
                ))
            }
            None => {
                return Err(DataFusionError::NotImplemented(
                    "Sort without sort kind is invalid".to_string(),
                ))
            }
        };
        sorts.push(Expr::Sort(Sort {
            expr: Box::new(expr.as_ref().clone()),
            asc,
            nulls_first,
        }));
    }
    Ok(sorts)
}

/// Convert a list of Substrait `Expression`s into DataFusion `Expr`s.
///
/// Expressions are converted one at a time, in input order, against
/// `input_schema`. The first conversion failure is returned immediately.
pub async fn from_substrait_rex_vec(
    exprs: &Vec<Expression>,
    input_schema: &DFSchema,
    extensions: &HashMap<u32, &String>,
) -> Result<Vec<Expr>> {
    let mut converted: Vec<Expr> = Vec::with_capacity(exprs.len());
    for substrait_expr in exprs.iter() {
        let df_expr = from_substrait_rex(substrait_expr, input_schema, extensions).await?;
        // `from_substrait_rex` hands back a shared pointer; store an owned copy.
        converted.push((*df_expr).clone());
    }
    Ok(converted)
}

/// Convert Substrait `FunctionArgument`s into DataFusion `Expr`s.
///
/// Only `ArgType::Value` arguments are supported; any other argument kind
/// (or a missing one) yields `DataFusionError::NotImplemented`. Arguments are
/// converted in order and the first failure is returned.
///
/// The `substriat` spelling in the function name is kept unchanged because
/// callers refer to it by that name.
pub async fn from_substriat_func_args(
    arguments: &Vec<FunctionArgument>,
    input_schema: &DFSchema,
    extensions: &HashMap<u32, &String>,
) -> Result<Vec<Expr>> {
    let mut converted: Vec<Expr> = Vec::with_capacity(arguments.len());
    for argument in arguments.iter() {
        // Reject unsupported argument kinds before attempting any conversion.
        let value = match &argument.arg_type {
            Some(ArgType::Value(value)) => value,
            _ => {
                return Err(DataFusionError::NotImplemented(
                    "Aggregated function argument non-Value type not supported".to_string(),
                ))
            }
        };
        let df_expr = from_substrait_rex(value, input_schema, extensions).await?;
        converted.push((*df_expr).clone());
    }
    Ok(converted)
}

/// Convert Substrait AggregateFunction to DataFusion Expr
pub async fn from_substrait_agg_func(
f: &AggregateFunction,
Expand Down Expand Up @@ -740,6 +804,47 @@ pub async fn from_substrait_rex(
"Cast experssion without output type is not allowed".to_string(),
)),
},
Some(RexType::WindowFunction(window)) => {
let fun = match extensions.get(&window.function_reference) {
Some(function_name) => Ok(find_df_window_func(function_name)),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we return an error if find_df_window_func returns None here? We unwrap this later on currently, which could fail?

None => Err(DataFusionError::NotImplemented(format!(
"Window function not found: function anchor = {:?}",
&window.function_reference
))),
};
let order_by =
from_substrait_sorts(&window.sorts, input_schema, extensions).await?;
// Substrait does not encode WindowFrameUnits so we're using a simple logic to determine the units
// If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary
// If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row
// TODO: Consider the cases where window frame is specified in query and is different from default
let units = if order_by.is_empty() {
WindowFrameUnits::Rows
} else {
WindowFrameUnits::Range
};
Ok(Arc::new(Expr::WindowFunction(expr::WindowFunction {
fun: fun?.unwrap(),
args: from_substriat_func_args(
&window.arguments,
input_schema,
extensions,
)
.await?,
partition_by: from_substrait_rex_vec(
&window.partitions,
input_schema,
extensions,
)
.await?,
order_by,
window_frame: datafusion::logical_expr::WindowFrame {
units,
start_bound: from_substrait_bound(&window.lower_bound, true)?,
end_bound: from_substrait_bound(&window.upper_bound, false)?,
},
})))
}
_ => Err(DataFusionError::NotImplemented(
"unsupported rex_type".to_string(),
)),
Expand Down Expand Up @@ -767,6 +872,44 @@ fn from_substrait_type(dt: &substrait::proto::Type) -> Result<DataType> {
}
}

/// Convert an optional Substrait window `Bound` into a DataFusion
/// `WindowFrameBound`.
///
/// `is_lower` selects which side of the frame is being converted: an
/// unbounded (or entirely absent) bound becomes `Preceding(Null)` for the
/// lower side and `Following(Null)` for the upper side.
///
/// Returns `DataFusionError::Substrait` when a bound is present but carries
/// no kind.
fn from_substrait_bound(
    bound: &Option<Bound>,
    is_lower: bool,
) -> Result<WindowFrameBound> {
    // Shared by the "no bound given" and explicit `Unbounded` cases.
    let unbounded = || {
        if is_lower {
            WindowFrameBound::Preceding(ScalarValue::Null)
        } else {
            WindowFrameBound::Following(ScalarValue::Null)
        }
    };

    let kind = match bound {
        None => return Ok(unbounded()),
        Some(b) => b.kind.as_ref(),
    };

    match kind {
        None => Err(DataFusionError::Substrait(
            "WindowFunction missing Substrait Bound kind".to_string(),
        )),
        Some(BoundKind::CurrentRow(SubstraitBound::CurrentRow {})) => {
            Ok(WindowFrameBound::CurrentRow)
        }
        Some(BoundKind::Preceding(SubstraitBound::Preceding { offset })) => Ok(
            WindowFrameBound::Preceding(ScalarValue::Int64(Some(*offset))),
        ),
        Some(BoundKind::Following(SubstraitBound::Following { offset })) => Ok(
            WindowFrameBound::Following(ScalarValue::Int64(Some(*offset))),
        ),
        Some(BoundKind::Unbounded(SubstraitBound::Unbounded {})) => Ok(unbounded()),
    }
}

fn from_substrait_null(null_type: &Type) -> Result<ScalarValue> {
if let Some(kind) = &null_type.kind {
match kind {
Expand Down
Loading