Skip to content

Commit

Permalink
Add plan and function extension support (#27)
Browse files Browse the repository at this point in the history
* Add plan and function extension support

* Removed unwraps
  • Loading branch information
nseekhao committed Oct 24, 2022
1 parent e1b9569 commit 09b2102
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 58 deletions.
25 changes: 25 additions & 0 deletions src/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use datafusion::{
scalar::ScalarValue,
};
use substrait::protobuf::sort_field::{SortKind::*, SortDirection};
use substrait::protobuf::Plan;
use std::collections::HashMap;
use std::sync::Arc;
use substrait::protobuf::{
Expand Down Expand Up @@ -60,6 +61,30 @@ pub fn reference_to_op(reference: u32) -> Result<Operator> {
}
}

/// Convert a Substrait `Plan` into a DataFusion `DataFrame`.
///
/// Only plans holding exactly one relation tree are handled, and that tree
/// must be a plain `Rel`; it is converted with `from_substrait_rel`.
///
/// # Errors
///
/// * `DataFusionError::NotImplemented` if the plan holds zero or several
///   relation trees, or if the single tree is a `RootRel`.
/// * `DataFusionError::Internal` if the relation carries no `rel_type`.
/// * Any error returned by `from_substrait_rel`.
pub async fn from_substrait_plan(ctx: &mut SessionContext, plan: &Plan) -> Result<Arc<DataFrame>> {
    // A slice pattern selects the single relation without a panicking index
    // and lets the empty plan be reported separately from the "too many" case.
    let plan_rel = match plan.relations.as_slice() {
        [only] => only,
        [] => {
            // Same error variant as before; only the message differs, because
            // "more than 1" was misleading for a plan with no relations at all.
            return Err(DataFusionError::NotImplemented(
                "Substrait plan with no relation trees not supported".to_string(),
            ));
        }
        many => {
            return Err(DataFusionError::NotImplemented(format!(
                "Substrait plan with more than 1 relation trees not supported. Number of relation trees: {:?}",
                many.len()
            )));
        }
    };

    match plan_rel.rel_type.as_ref() {
        Some(substrait::protobuf::plan_rel::RelType::Rel(rel)) => {
            from_substrait_rel(ctx, rel).await
        }
        Some(substrait::protobuf::plan_rel::RelType::Root(_)) => Err(
            DataFusionError::NotImplemented("RootRel not supported".to_string()),
        ),
        None => Err(DataFusionError::Internal(
            "Cannot parse plan relation: None".to_string(),
        )),
    }
}

/// Convert Substrait Rel to DataFusion DataFrame
#[async_recursion]
pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel) -> Result<Arc<DataFrame>> {
Expand Down
125 changes: 79 additions & 46 deletions src/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,37 @@ use substrait::protobuf::{
SortKind,
},
function_argument::ArgType,
plan_rel,
read_rel::{NamedTable, ReadType},
rel::RelType,
Expression, FetchRel, FilterRel, FunctionArgument, JoinRel, NamedStruct, ProjectRel, ReadRel, Rel, SortField, SortRel
Expression, FetchRel, FilterRel, FunctionArgument, JoinRel, NamedStruct, ProjectRel, ReadRel, SortField, SortRel,
PlanRel,
Plan, Rel, extensions::{self, simple_extension_declaration::{MappingType, ExtensionFunction}},
};

/// Build a Substrait `Plan` from a DataFusion `LogicalPlan`.
///
/// The logical plan is converted into a single relation tree with
/// `to_substrait_rel`; the extension declarations gathered during that
/// conversion are stored in the plan's `extensions` field. Any conversion
/// error is propagated to the caller.
pub fn to_substrait_plan(plan: &LogicalPlan) -> Result<Box<Plan>> {
    // Filled in by the conversion below as it walks the plan.
    let mut declared: Vec<extensions::SimpleExtensionDeclaration> = Vec::new();

    // Note: only 1 relation tree is currently emitted.
    let root = to_substrait_rel(plan, &mut declared)?;
    let relations = vec![PlanRel {
        rel_type: Some(plan_rel::RelType::Rel(*root)),
    }];

    Ok(Box::new(Plan {
        extension_uris: Vec::new(),
        extensions: declared,
        relations,
        advanced_extensions: None,
        expected_type_urls: Vec::new(),
    }))
}

/// Convert DataFusion LogicalPlan to Substrait Rel
pub fn to_substrait_rel(plan: &LogicalPlan) -> Result<Box<Rel>> {
pub fn to_substrait_rel(plan: &LogicalPlan, extensions: &mut Vec<extensions::SimpleExtensionDeclaration>) -> Result<Box<Rel>> {
match plan {
LogicalPlan::TableScan(scan) => {
let projection = scan.projection.as_ref().map(|p| {
Expand Down Expand Up @@ -69,20 +93,20 @@ pub fn to_substrait_rel(plan: &LogicalPlan) -> Result<Box<Rel>> {
let expressions = p
.expr
.iter()
.map(|e| to_substrait_rex(e, p.input.schema()))
.map(|e| to_substrait_rex(e, p.input.schema(), extensions))
.collect::<Result<Vec<_>>>()?;
Ok(Box::new(Rel {
rel_type: Some(RelType::Project(Box::new(ProjectRel {
common: None,
input: Some(to_substrait_rel(p.input.as_ref())?),
input: Some(to_substrait_rel(p.input.as_ref(), extensions)?),
expressions,
advanced_extension: None,
}))),
}))
}
LogicalPlan::Filter(filter) => {
let input = to_substrait_rel(filter.input.as_ref())?;
let filter_expr = to_substrait_rex(&filter.predicate, filter.input.schema())?;
let input = to_substrait_rel(filter.input.as_ref(), extensions)?;
let filter_expr = to_substrait_rex(&filter.predicate, filter.input.schema(), extensions)?;
Ok(Box::new(Rel {
rel_type: Some(RelType::Filter(Box::new(FilterRel {
common: None,
Expand All @@ -93,7 +117,7 @@ pub fn to_substrait_rel(plan: &LogicalPlan) -> Result<Box<Rel>> {
}))
}
LogicalPlan::Limit(limit) => {
let input = to_substrait_rel(limit.input.as_ref())?;
let input = to_substrait_rel(limit.input.as_ref(), extensions)?;
let limit_fetch = match limit.fetch {
Some(count) => count,
None => 0,
Expand All @@ -109,11 +133,11 @@ pub fn to_substrait_rel(plan: &LogicalPlan) -> Result<Box<Rel>> {
}))
}
LogicalPlan::Sort(sort) => {
let input = to_substrait_rel(sort.input.as_ref())?;
let input = to_substrait_rel(sort.input.as_ref(), extensions)?;
let sort_fields = sort
.expr
.iter()
.map(|e| substrait_sort_field(e, sort.input.schema()))
.map(|e| substrait_sort_field(e, sort.input.schema(), extensions))
.collect::<Result<Vec<_>>>()?;
Ok(Box::new(Rel {
rel_type: Some(RelType::Sort(Box::new(SortRel {
Expand All @@ -125,8 +149,8 @@ pub fn to_substrait_rel(plan: &LogicalPlan) -> Result<Box<Rel>> {
}))
}
LogicalPlan::Join(join) => {
let left = to_substrait_rel(join.left.as_ref())?;
let right = to_substrait_rel(join.right.as_ref())?;
let left = to_substrait_rel(join.left.as_ref(), extensions)?;
let right = to_substrait_rel(join.right.as_ref(), extensions)?;
let join_type = match join.join_type {
JoinType::Inner => 1,
JoinType::Left => 2,
Expand Down Expand Up @@ -169,7 +193,7 @@ pub fn to_substrait_rel(plan: &LogicalPlan) -> Result<Box<Rel>> {
left: Some(left),
right: Some(right),
r#type: join_type,
expression: Some(Box::new(to_substrait_rex(&e, &join.schema)?)),
expression: Some(Box::new(to_substrait_rex(&e, &join.schema, extensions)?)),
post_join_filter: None,
advanced_extension: None,
}))),
Expand All @@ -187,49 +211,58 @@ pub fn to_substrait_rel(plan: &LogicalPlan) -> Result<Box<Rel>> {
}
}

pub fn operator_to_reference(op: Operator) -> u32 {
pub fn operator_to_reference(op: Operator) -> (u32, &'static str) {
match op {
Operator::Eq => 1,
Operator::NotEq => 2,
Operator::Lt => 3,
Operator::LtEq => 4,
Operator::Gt => 5,
Operator::GtEq => 6,
Operator::Plus => 7,
Operator::Minus => 8,
Operator::Multiply => 9,
Operator::Divide => 10,
Operator::Modulo => 11,
Operator::And => 12,
Operator::Or => 13,
Operator::Like => 14,
Operator::NotLike => 15,
Operator::IsDistinctFrom => 16,
Operator::IsNotDistinctFrom => 17,
Operator::RegexMatch => 18,
Operator::RegexIMatch => 19,
Operator::RegexNotMatch => 20,
Operator::RegexNotIMatch => 21,
Operator::BitwiseAnd => 22,
Operator::BitwiseOr => 23,
Operator::StringConcat => 24,
Operator::BitwiseXor => 25,
Operator::BitwiseShiftRight => 26,
Operator::BitwiseShiftLeft => 27,
Operator::Eq => (1, "equal"),
Operator::NotEq => (2, "not_equal"),
Operator::Lt => (3, "lt"),
Operator::LtEq => (4, "lte"),
Operator::Gt => (5, "gt"),
Operator::GtEq => (6, "gte"),
Operator::Plus => (7, "add"),
Operator::Minus => (8, "substract"),
Operator::Multiply => (9, "multiply"),
Operator::Divide => (10, "divide"),
Operator::Modulo => (11, "mod"),
Operator::And => (12, "and"),
Operator::Or => (13, "or"),
Operator::Like => (14, "like"),
Operator::NotLike => (15, "not_like"),
Operator::IsDistinctFrom => (16, "is_distinct_from"),
Operator::IsNotDistinctFrom => (17, "is_not_distinct_from"),
Operator::RegexMatch => (18, "regex_match"),
Operator::RegexIMatch => (19, "regex_imatch"),
Operator::RegexNotMatch => (20, "regex_not_match"),
Operator::RegexNotIMatch => (21, "regex_not_imatch"),
Operator::BitwiseAnd => (22, "bitwise_and"),
Operator::BitwiseOr => (23, "bitwise_or"),
Operator::StringConcat => (24, "str_concat"),
Operator::BitwiseXor => (25, "bitwise_xor"),
Operator::BitwiseShiftRight => (26, "bitwise_shift_right"),
Operator::BitwiseShiftLeft => (27, "bitwise_shift_left"),
}
}

/// Convert DataFusion Expr to Substrait Rex
pub fn to_substrait_rex(expr: &Expr, schema: &DFSchemaRef) -> Result<Expression> {
pub fn to_substrait_rex(expr: &Expr, schema: &DFSchemaRef, extensions: &mut Vec<extensions::SimpleExtensionDeclaration>) -> Result<Expression> {
match expr {
Expr::Column(col) => {
let index = schema.index_of_column(&col)?;
substrait_field_ref(index)
}
Expr::BinaryExpr { left, op, right } => {
let l = to_substrait_rex(left, schema)?;
let r = to_substrait_rex(right, schema)?;
let function_reference: u32 = operator_to_reference(*op);
let l = to_substrait_rex(left, schema, extensions)?;
let r = to_substrait_rex(right, schema, extensions)?;
let (function_reference, function_name) = operator_to_reference(*op);
let extension_function = ExtensionFunction {
extension_uri_reference: extensions.len() as u32,
function_anchor: function_reference,
name: function_name.to_string(),
};
let extension = extensions::SimpleExtensionDeclaration {
mapping_type: Some(MappingType::ExtensionFunction(extension_function)),
};
extensions.push(extension);
Ok(Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
function_reference,
Expand Down Expand Up @@ -282,10 +315,10 @@ pub fn to_substrait_rex(expr: &Expr, schema: &DFSchemaRef) -> Result<Expression>
}
}

fn substrait_sort_field(expr: &Expr, schema: &DFSchemaRef) -> Result<SortField> {
fn substrait_sort_field(expr: &Expr, schema: &DFSchemaRef, extensions: &mut Vec<extensions::SimpleExtensionDeclaration>) -> Result<SortField> {
match expr {
Expr::Sort { expr, asc, nulls_first } => {
let e = to_substrait_rex(expr, schema)?;
let e = to_substrait_rex(expr, schema, extensions)?;
let d = match (asc, nulls_first) {
(true, true) => SortDirection::AscNullsFirst,
(true, false) => SortDirection::AscNullsLast,
Expand Down
6 changes: 3 additions & 3 deletions src/serializer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ use datafusion::error::Result;
use datafusion::prelude::*;

use prost::Message;
use substrait::protobuf::Rel;
use substrait::protobuf::Plan;

use std::fs::OpenOptions;
use std::io::{Write, Read};

pub async fn serialize(sql: &str, ctx: &SessionContext, path: &str) -> Result<()> {
let df = ctx.sql(sql).await?;
let plan = df.to_logical_plan()?;
let proto = producer::to_substrait_rel(&plan)?;
let proto = producer::to_substrait_plan(&plan)?;

let mut protobuf_out = Vec::<u8>::new();
proto.encode(&mut protobuf_out).unwrap();
Expand All @@ -24,7 +24,7 @@ pub async fn serialize(sql: &str, ctx: &SessionContext, path: &str) -> Result<()
Ok(())
}

pub async fn deserialize(path: &str) -> Result<Box<Rel>> {
pub async fn deserialize(path: &str) -> Result<Box<Plan>> {
let mut protobuf_in = Vec::<u8>::new();

let mut file = OpenOptions::new()
Expand Down
14 changes: 7 additions & 7 deletions tests/roundtrip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use datafusion_substrait::producer;
#[cfg(test)]
mod tests {

use crate::{consumer::from_substrait_rel, producer::to_substrait_rel};
use crate::{consumer::from_substrait_plan, producer::to_substrait_plan};
use datafusion::error::Result;
use datafusion::prelude::*;

Expand Down Expand Up @@ -65,8 +65,8 @@ mod tests {
let mut ctx = create_context().await?;
let df = ctx.sql(sql).await?;
let plan = df.to_logical_plan()?;
let proto = to_substrait_rel(&plan)?;
let df = from_substrait_rel(&mut ctx, &proto).await?;
let proto = to_substrait_plan(&plan)?;
let df = from_substrait_plan(&mut ctx, &proto).await?;
let plan2 = df.to_logical_plan()?;
let plan2str = format!("{:?}", plan2);
assert_eq!(expected_plan_str, &plan2str);
Expand All @@ -77,9 +77,9 @@ mod tests {
let mut ctx = create_context().await?;
let df = ctx.sql(sql).await?;
let plan1 = df.to_logical_plan()?;
let proto = to_substrait_rel(&plan1)?;
let proto = to_substrait_plan(&plan1)?;

let df = from_substrait_rel(&mut ctx, &proto).await?;
let df = from_substrait_plan(&mut ctx, &proto).await?;
let plan2 = df.to_logical_plan()?;

// Format plan string and replace all None's with 0
Expand All @@ -94,9 +94,9 @@ mod tests {
let mut ctx = create_context().await?;
let df = ctx.sql(sql).await?;
let plan = df.to_logical_plan()?;
let proto = to_substrait_rel(&plan)?;
let proto = to_substrait_plan(&plan)?;

let df = from_substrait_rel(&mut ctx, &proto).await?;
let df = from_substrait_plan(&mut ctx, &proto).await?;
let plan2 = df.to_logical_plan()?;

let plan1str = format!("{:?}", plan);
Expand Down
4 changes: 2 additions & 2 deletions tests/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#[cfg(test)]
mod tests {

use datafusion_substrait::consumer::from_substrait_rel;
use datafusion_substrait::consumer::from_substrait_plan;
use datafusion_substrait::serializer;

use datafusion::error::Result;
Expand All @@ -24,7 +24,7 @@ mod tests {
// Read substrait plan from file
let proto = serializer::deserialize(path).await?;
// Check plan equality
let df = from_substrait_rel(&mut ctx, &proto).await?;
let df = from_substrait_plan(&mut ctx, &proto).await?;
let plan = df.to_logical_plan()?;
let plan_str_ref = format!("{:?}", plan_ref);
let plan_str = format!("{:?}", plan);
Expand Down

0 comments on commit 09b2102

Please sign in to comment.