Skip to content

Commit

Permalink
Add support for Substrait VirtualTables
Browse files Browse the repository at this point in the history
Adds support for Substrait's VirtualTables, ie. tables with data baked-in into the Substrait plan instead of being read from a source.

Adds conversion in both ways (Substrait -> DataFusion and DataFusion -> Substrait)
and a roundtrip test.
  • Loading branch information
Blizzara committed May 24, 2024
1 parent 8bedecc commit 02e7bc0
Show file tree
Hide file tree
Showing 3 changed files with 228 additions and 28 deletions.
143 changes: 119 additions & 24 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@
// under the License.

use async_recursion::async_recursion;
use datafusion::arrow::datatypes::{DataType, Field, TimeUnit};
use datafusion::arrow::datatypes::{DataType, Field, Fields, Schema, TimeUnit};
use datafusion::common::{
not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef,
};

use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::{
aggregate_function, expr::find_df_window_func, BinaryExpr, Case, Expr, LogicalPlan,
Operator, ScalarUDF,
Operator, ScalarUDF, Values,
};
use datafusion::logical_expr::{
expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning,
Expand Down Expand Up @@ -507,7 +507,52 @@ pub async fn from_substrait_rel(
_ => Ok(t),
}
}
_ => not_impl_err!("Only NamedTable reads are supported"),
Some(ReadType::VirtualTable(vt)) => {
let base_schema = read.base_schema.as_ref().ok_or_else(|| {
substrait_datafusion_err!("No base schema provided for Virtual Table")
})?;

let fields = from_substrait_struct(
base_schema.r#struct.as_ref().ok_or_else(|| {
substrait_datafusion_err!("Named struct must contain a struct")
})?,
&base_schema.names,
&mut 0,
);
let schema = DFSchemaRef::new(DFSchema::try_from(Schema::new(fields?))?);

let values = vt
.values
.iter()
.map(|row| {
let mut name_idx = 0;
let lits = row
.fields
.iter()
.map(|lit| {
name_idx += 1; // top-level names are provided through schema
Ok(Expr::Literal(from_substrait_literal_with_names(
lit,
&base_schema.names,
&mut name_idx,
)?))
})
.collect::<Result<_>>()?;
if name_idx != base_schema.names.len() {
Err(substrait_datafusion_err!(
"Names list must match exactly to nested ®schema, but found {} uses for {} names",
name_idx,
base_schema.names.len()
))
} else {
Ok(lits)
}
})
.collect::<Result<_>>()?;

Ok(LogicalPlan::Values(Values { schema, values }))
}
_ => not_impl_err!("Only NamedTable and VirtualTable reads are supported"),
},
Some(RelType::Set(set)) => match set_rel::SetOp::try_from(set.op) {
Ok(set_op) => match set_op {
Expand Down Expand Up @@ -1060,7 +1105,15 @@ pub async fn from_substrait_rex(
}
}

pub(crate) fn from_substrait_type(dt: &substrait::proto::Type) -> Result<DataType> {
pub(crate) fn from_substrait_type(dt: &Type) -> Result<DataType> {
from_substrait_type_with_names(dt, &vec![], &mut 0)
}

fn from_substrait_type_with_names(
dt: &Type,
dfs_names: &Vec<String>,
name_idx: &mut usize,
) -> Result<DataType> {
match &dt.kind {
Some(s_kind) => match s_kind {
r#type::Kind::Bool(_) => Ok(DataType::Boolean),
Expand Down Expand Up @@ -1162,24 +1215,50 @@ pub(crate) fn from_substrait_type(dt: &substrait::proto::Type) -> Result<DataTyp
"Unsupported Substrait type variation {v} of type {s_kind:?}"
),
},
r#type::Kind::Struct(s) => {
let mut fields = vec![];
for (i, f) in s.types.iter().enumerate() {
let field = Field::new(
&format!("c{i}"),
from_substrait_type(f)?,
is_substrait_type_nullable(f)?,
);
fields.push(field);
}
Ok(DataType::Struct(fields.into()))
}
r#type::Kind::Struct(s) => Ok(DataType::Struct(from_substrait_struct(
s, dfs_names, name_idx,
)?)),
_ => not_impl_err!("Unsupported Substrait type: {s_kind:?}"),
},
_ => not_impl_err!("`None` Substrait kind is not supported"),
}
}

fn from_substrait_struct(
s: &r#type::Struct,
dfs_names: &Vec<String>,
name_idx: &mut usize,
) -> Result<Fields> {
let mut fields = vec![];
for (i, f) in s.types.iter().enumerate() {
let field = Field::new(
next_struct_field_name(i, dfs_names, name_idx)?,
from_substrait_type_with_names(f, dfs_names, name_idx)?,
is_substrait_type_nullable(f)?,
);
fields.push(field);
}
Ok(fields.into())
}

fn next_struct_field_name(
i: usize,
dfs_names: &Vec<String>,
name_idx: &mut usize,
) -> Result<String> {
if dfs_names.is_empty() {
// If names are not given, create dummy names
// c0, c1, ... align with e.g. SqlToRel::create_named_struct
Ok(format!("c{i}"))
} else {
let name = dfs_names.get(*name_idx).cloned().ok_or_else(|| {
substrait_datafusion_err!("Named schema must contain names for all fields")
})?;
*name_idx += 1;
Ok(name)
}
}

fn is_substrait_type_nullable(dtype: &Type) -> Result<bool> {
fn is_nullable(nullability: i32) -> bool {
nullability != substrait::proto::r#type::Nullability::Required as i32
Expand Down Expand Up @@ -1258,6 +1337,14 @@ fn from_substrait_bound(
}

pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
from_substrait_literal_with_names(lit, &vec![], &mut 0)
}

fn from_substrait_literal_with_names(
lit: &Literal,
dfs_names: &Vec<String>,
name_idx: &mut usize,
) -> Result<ScalarValue> {
let scalar_value = match &lit.literal_type {
Some(LiteralType::Boolean(b)) => ScalarValue::Boolean(Some(*b)),
Some(LiteralType::I8(n)) => match lit.type_variation_reference {
Expand Down Expand Up @@ -1377,23 +1464,27 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
Some(LiteralType::Struct(s)) => {
let mut builder = ScalarStructBuilder::new();
for (i, field) in s.fields.iter().enumerate() {
let sv = from_substrait_literal(field)?;
// c0, c1, ... align with e.g. SqlToRel::create_named_struct
builder = builder.with_scalar(
Field::new(&format!("c{i}"), sv.data_type(), field.nullable),
sv,
);
let name = next_struct_field_name(i, dfs_names, name_idx)?;
let sv = from_substrait_literal_with_names(field, dfs_names, name_idx)?;
builder = builder
.with_scalar(Field::new(name, sv.data_type(), field.nullable), sv);
}
builder.build()?
}
Some(LiteralType::Null(ntype)) => from_substrait_null(ntype)?,
Some(LiteralType::Null(ntype)) => {
from_substrait_null_with_names(ntype, dfs_names, name_idx)?
}
_ => return not_impl_err!("Unsupported literal_type: {:?}", lit.literal_type),
};

Ok(scalar_value)
}

fn from_substrait_null(null_type: &Type) -> Result<ScalarValue> {
fn from_substrait_null_with_names(
null_type: &Type,
dfs_names: &Vec<String>,
name_idx: &mut usize,
) -> Result<ScalarValue> {
if let Some(kind) = &null_type.kind {
match kind {
r#type::Kind::Bool(_) => Ok(ScalarValue::Boolean(None)),
Expand Down Expand Up @@ -1486,6 +1577,10 @@ fn from_substrait_null(null_type: &Type) -> Result<ScalarValue> {
),
}
}
r#type::Kind::Struct(s) => {
let fields = from_substrait_struct(s, dfs_names, name_idx)?;
Ok(ScalarStructBuilder::new_null(fields))
}
_ => not_impl_err!("Unsupported Substrait type for null: {kind:?}"),
}
} else {
Expand Down
105 changes: 101 additions & 4 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use itertools::Itertools;
use std::collections::HashMap;
use std::ops::Deref;
use std::sync::Arc;
Expand All @@ -31,7 +32,9 @@ use datafusion::{
};

use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait};
use datafusion::common::{exec_err, internal_err, not_impl_err};
use datafusion::common::{
exec_err, internal_err, not_impl_err, plan_err, substrait_datafusion_err,
};
use datafusion::common::{substrait_err, DFSchemaRef};
#[allow(unused_imports)]
use datafusion::logical_expr::aggregate_function;
Expand All @@ -46,6 +49,7 @@ use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields};
use substrait::proto::expression::literal::{List, Struct};
use substrait::proto::expression::subquery::InPredicate;
use substrait::proto::expression::window_function::BoundsType;
use substrait::proto::read_rel::VirtualTable;
use substrait::proto::{CrossRel, ExchangeRel};
use substrait::{
proto::{
Expand Down Expand Up @@ -167,6 +171,98 @@ pub fn to_substrait_rel(
}))),
}))
}
LogicalPlan::Values(v) => {
fn field_names_dfs(dtype: &DataType) -> Result<Vec<String>> {
// Substrait wants a list of all field names, including nested fields from structs,
// also from within lists and maps. However, it does not want the list and map field names
// themselves - only structs are considered to have useful names.
match dtype {
DataType::Struct(fields) => {
let mut names = Vec::new();
for field in fields {
names.push(field.name().to_string());
names.extend(field_names_dfs(field.data_type())?);
}
Ok(names)
}
DataType::List(l) => field_names_dfs(l.data_type()),
DataType::Map(m, _) => match m.data_type() {
DataType::Struct(key_and_value) if key_and_value.len() == 2 => {
let key_names = field_names_dfs(
key_and_value.first().unwrap().data_type(),
)?;
let value_names = field_names_dfs(
key_and_value.last().unwrap().data_type(),
)?;
Ok([key_names, value_names].concat())
}
_ => plan_err!(
"Map fields must contain a Struct with exactly 2 fields"
),
},
_ => Ok(Vec::new()),
}
}
let names = v
.schema
.fields()
.iter()
.map(|f| {
let mut names = vec![f.name().to_string()];
names.extend(field_names_dfs(f.data_type())?);
Ok(names)
})
.flatten_ok()
.collect::<Result<_>>()?;

let field_types = r#type::Struct {
types: v
.schema
.fields()
.iter()
.map(|f| to_substrait_type(f.data_type(), f.is_nullable()))
.collect::<Result<_>>()?,
type_variation_reference: DEFAULT_TYPE_REF,
nullability: r#type::Nullability::Unspecified as i32,
};
let values = v
.values
.iter()
.map(|row| {
let fields = row
.iter()
.map(|v| match v {
Expr::Literal(sv) => to_substrait_literal(sv),
Expr::Alias(alias) => match alias.expr.as_ref() {
// The schema gives us the names, so we can skip aliases
Expr::Literal(sv) => to_substrait_literal(sv),
_ => Err(substrait_datafusion_err!(
"Only literal types can be aliased in Virtual Tables, got: {}", alias.expr.variant_name()
)),
},
_ => Err(substrait_datafusion_err!(
"Only literal types and aliases are supported in Virtual Tables, got: {}", v.variant_name()
)),
})
.collect::<Result<_>>()?;
Ok(Struct { fields })
})
.collect::<Result<_>>()?;
Ok(Box::new(Rel {
rel_type: Some(RelType::Read(Box::new(ReadRel {
common: None,
base_schema: Some(NamedStruct {
names,
r#struct: Some(field_types),
}),
filter: None,
best_effort_filter: None,
projection: None,
advanced_extension: None,
read_type: Some(ReadType::VirtualTable(VirtualTable { values })),
}))),
}))
}
LogicalPlan::Projection(p) => {
let expressions = p
.expr
Expand Down Expand Up @@ -1996,11 +2092,12 @@ mod test {
let c2 = Field::new("c2", DataType::Utf8, true);
round_trip_literal(
ScalarStructBuilder::new()
.with_scalar(c0, ScalarValue::Boolean(Some(true)))
.with_scalar(c1, ScalarValue::Int32(Some(1)))
.with_scalar(c2, ScalarValue::Utf8(None))
.with_scalar(c0.to_owned(), ScalarValue::Boolean(Some(true)))
.with_scalar(c1.to_owned(), ScalarValue::Int32(Some(1)))
.with_scalar(c2.to_owned(), ScalarValue::Utf8(None))
.build()?,
)?;
round_trip_literal(ScalarStructBuilder::new_null(vec![c0, c1, c2]))?;

Ok(())
}
Expand Down
8 changes: 8 additions & 0 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,14 @@ async fn roundtrip_literal_struct() -> Result<()> {
.await
}

#[tokio::test]
async fn roundtrip_values() -> Result<()> {
assert_expected_plan(
"VALUES (1, 'a', [[-213.1, NULL, 5.5, 2.0, 1.0], []], STRUCT(true, 1, CAST(NULL AS STRING))), (NULL, NULL, NULL, NULL)",
"Values: (Int64(1), Utf8(\"a\"), List([[-213.1, , 5.5, 2.0, 1.0], []]), Struct({c0:true,c1:1,c2:})), (Int64(NULL), Utf8(NULL), List(), Struct({c0:,c1:,c2:}))")
.await
}

/// Construct a plan that cast columns. Only those SQL types are supported for now.
#[tokio::test]
async fn new_test_grammar() -> Result<()> {
Expand Down

0 comments on commit 02e7bc0

Please sign in to comment.