Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 77 additions & 10 deletions datafusion/substrait/src/physical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
use std::collections::HashMap;
use std::sync::Arc;

use datafusion::arrow::datatypes::Schema;
use datafusion::common::not_impl_err;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::common::{not_impl_err, substrait_err};
use datafusion::datasource::listing::PartitionedFile;
use datafusion::datasource::object_store::ObjectStoreUrl;
use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec};
Expand All @@ -30,7 +30,9 @@ use datafusion::prelude::SessionContext;
use async_recursion::async_recursion;
use chrono::DateTime;
use object_store::ObjectMeta;
use substrait::proto::r#type::{Kind, Nullability};
use substrait::proto::read_rel::local_files::file_or_files::PathType;
use substrait::proto::Type;
use substrait::proto::{
expression::MaskExpression, read_rel::ReadType, rel::RelType, Rel,
};
Expand All @@ -42,17 +44,42 @@ pub async fn from_substrait_rel(
rel: &Rel,
_extensions: &HashMap<u32, &String>,
) -> Result<Arc<dyn ExecutionPlan>> {
let mut base_config;

match &rel.rel_type {
Some(RelType::Read(read)) => {
if read.filter.is_some() || read.best_effort_filter.is_some() {
return not_impl_err!("Read with filter is not supported");
}
if read.base_schema.is_some() {
return not_impl_err!("Read with schema is not supported");
}

if read.advanced_extension.is_some() {
return not_impl_err!("Read with AdvancedExtension is not supported");
}

let Some(schema) = read.base_schema.as_ref() else {
return substrait_err!("Missing base schema in the read");
};

let Some(r#struct) = schema.r#struct.as_ref() else {
return substrait_err!("Missing struct in the schema");
};

match schema
.names
.iter()
.zip(r#struct.types.iter())
.map(|(name, r#type)| to_field(name, r#type))
.collect::<Result<Vec<Field>>>()
{
Ok(fields) => {
base_config = FileScanConfig::new(
ObjectStoreUrl::local_filesystem(),
Arc::new(Schema::new(fields)),
);
}
Err(e) => return Err(e),
};

match &read.as_ref().read_type {
Some(ReadType::LocalFiles(files)) => {
let mut file_groups = vec![];
Expand Down Expand Up @@ -104,11 +131,7 @@ pub async fn from_substrait_rel(
file_groups[part_index].push(partitioned_file)
}

let mut base_config = FileScanConfig::new(
ObjectStoreUrl::local_filesystem(),
Arc::new(Schema::empty()),
)
.with_file_groups(file_groups);
base_config = base_config.with_file_groups(file_groups);

if let Some(MaskExpression { select, .. }) = &read.projection {
if let Some(projection) = &select.as_ref() {
Expand All @@ -132,3 +155,47 @@ pub async fn from_substrait_rel(
_ => not_impl_err!("Unsupported RelType: {:?}", rel.rel_type),
}
}

fn to_field(name: &String, r#type: &Type) -> Result<Field> {
let Some(kind) = r#type.kind.as_ref() else {
return substrait_err!("Missing kind in the type with name {}", name);
};

let mut nullable = false;
let data_type = match kind {
Kind::Bool(boolean) => {
nullable = is_nullable(boolean.nullability);
Ok(DataType::Boolean)
}
Kind::I64(i64) => {
nullable = is_nullable(i64.nullability);
Ok(DataType::Int64)
}
Kind::Fp64(fp64) => {
nullable = is_nullable(fp64.nullability);
Ok(DataType::Float64)
}
Kind::String(string) => {
nullable = is_nullable(string.nullability);
Ok(DataType::Utf8)
}
_ => substrait_err!(
"Unsupported kind: {:?} in the type with name {}",
kind,
name
),
}?;

Ok(Field::new(name, data_type, nullable))
}

fn is_nullable(nullability: i32) -> bool {
let Ok(nullability) = Nullability::try_from(nullability) else {
return true;
};

match nullability {
Nullability::Nullable | Nullability::Unspecified => true,
Nullability::Required => false,
}
}
93 changes: 89 additions & 4 deletions datafusion/substrait/src/physical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@
// specific language governing permissions and limitations
// under the License.

use datafusion::arrow::datatypes::DataType;
use datafusion::datasource::physical_plan::ParquetExec;
use datafusion::error::{DataFusionError, Result};
use datafusion::physical_plan::{displayable, ExecutionPlan};
use std::collections::HashMap;
use substrait::proto::expression::mask_expression::{StructItem, StructSelect};
use substrait::proto::expression::MaskExpression;
use substrait::proto::extensions;
use substrait::proto::r#type::{
Boolean, Fp64, Kind, Nullability, String as SubstraitString, Struct, I64,
};
use substrait::proto::read_rel::local_files::file_or_files::ParquetReadOptions;
use substrait::proto::read_rel::local_files::file_or_files::{FileFormat, PathType};
use substrait::proto::read_rel::local_files::FileOrFiles;
Expand All @@ -29,6 +33,7 @@ use substrait::proto::read_rel::ReadType;
use substrait::proto::rel::RelType;
use substrait::proto::ReadRel;
use substrait::proto::Rel;
use substrait::proto::{extensions, NamedStruct, Type};

/// Convert DataFusion ExecutionPlan to Substrait Rel
pub fn to_substrait_rel(
Expand All @@ -55,15 +60,56 @@ pub fn to_substrait_rel(
}
}

let mut names = vec![];
let mut types = vec![];

for field in base_config.file_schema.fields.iter() {
match to_substrait_type(field.data_type(), field.is_nullable()) {
Ok(t) => {
names.push(field.name().clone());
types.push(t);
}
Err(e) => return Err(e),
}
}

let type_info = Struct {
types,
// FIXME: duckdb doesn't set this field, keep it as default variant 0.
// https://github.com/duckdb/substrait/blob/b6f56643cb11d52de0e32c24a01dfd5947df62be/src/to_substrait.cpp#L1106-L1127
type_variation_reference: 0,
nullability: Nullability::Required.into(),
};

let mut select_struct = None;
if let Some(projection) = base_config.projection.as_ref() {
let struct_items = projection
.iter()
.map(|index| StructItem {
field: *index as i32,
// FIXME: duckdb sets this to None, but it's not clear why.
// https://github.com/duckdb/substrait/blob/b6f56643cb11d52de0e32c24a01dfd5947df62be/src/to_substrait.cpp#L1191
child: None,
})
.collect();

select_struct = Some(StructSelect { struct_items });
}

Ok(Box::new(Rel {
rel_type: Some(RelType::Read(Box::new(ReadRel {
common: None,
base_schema: None,
base_schema: Some(NamedStruct {
names,
r#struct: Some(type_info),
}),
filter: None,
best_effort_filter: None,
projection: Some(MaskExpression {
select: None,
maintain_singular_struct: false,
select: select_struct,
// FIXME: duckdb set this to true, but it's not clear why.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  // fixme: whatever this means

😆

// https://github.com/duckdb/substrait/blob/b6f56643cb11d52de0e32c24a01dfd5947df62be/src/to_substrait.cpp#L1186.
maintain_singular_struct: true,
}),
advanced_extension: None,
read_type: Some(ReadType::LocalFiles(LocalFiles {
Expand All @@ -79,3 +125,42 @@ pub fn to_substrait_rel(
)))
}
}

// see https://github.com/duckdb/substrait/blob/b6f56643cb11d52de0e32c24a01dfd5947df62be/src/to_substrait.cpp#L954-L1094.
fn to_substrait_type(data_type: &DataType, nullable: bool) -> Result<Type> {
let nullability = if nullable {
Nullability::Nullable.into()
} else {
Nullability::Required.into()
};

match data_type {
DataType::Boolean => Ok(Type {
kind: Some(Kind::Bool(Boolean {
type_variation_reference: 0,
nullability,
})),
}),
DataType::Int64 => Ok(Type {
kind: Some(Kind::I64(I64 {
type_variation_reference: 0,
nullability,
})),
}),
DataType::Float64 => Ok(Type {
kind: Some(Kind::Fp64(Fp64 {
type_variation_reference: 0,
nullability,
})),
}),
DataType::Utf8 => Ok(Type {
kind: Some(Kind::String(SubstraitString {
type_variation_reference: 0,
nullability,
})),
}),
_ => Err(DataFusionError::Substrait(format!(
"Logical type {data_type} not implemented as substrait type"
))),
}
}
92 changes: 91 additions & 1 deletion datafusion/substrait/tests/cases/roundtrip_physical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ use std::collections::HashMap;
use std::sync::Arc;

use datafusion::arrow::datatypes::Schema;
use datafusion::dataframe::DataFrame;
use datafusion::datasource::listing::PartitionedFile;
use datafusion::datasource::object_store::ObjectStoreUrl;
use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec};
use datafusion::error::Result;
use datafusion::physical_plan::{displayable, ExecutionPlan};
use datafusion::prelude::SessionContext;
use datafusion::prelude::{ParquetReadOptions, SessionContext};
use datafusion_substrait::physical_plan::{consumer, producer};

use substrait::proto::extensions;
Expand Down Expand Up @@ -71,3 +72,92 @@ async fn parquet_exec() -> Result<()> {

Ok(())
}

#[tokio::test]
async fn simple_select() -> Result<()> {
roundtrip("SELECT a, b FROM data").await
}

#[tokio::test]
#[ignore = "This test is failing because the translation of the substrait plan to the physical plan is not implemented yet"]
async fn simple_select_alltypes() -> Result<()> {
roundtrip_alltypes("SELECT bool_col, int_col FROM alltypes_plain").await
}

#[tokio::test]
async fn wildcard_select() -> Result<()> {
roundtrip("SELECT * FROM data").await
}

#[tokio::test]
#[ignore = "This test is failing because the translation of the substrait plan to the physical plan is not implemented yet"]
async fn wildcard_select_alltypes() -> Result<()> {
roundtrip_alltypes("SELECT * FROM alltypes_plain").await
}

async fn roundtrip(sql: &str) -> Result<()> {
let ctx = create_parquet_context().await?;
let df = ctx.sql(sql).await?;

roundtrip_parquet(df).await?;

Ok(())
}

async fn roundtrip_alltypes(sql: &str) -> Result<()> {
let ctx = create_all_types_context().await?;
let df = ctx.sql(sql).await?;

roundtrip_parquet(df).await?;

Ok(())
}

async fn roundtrip_parquet(df: DataFrame) -> Result<()> {
let physical_plan = df.create_physical_plan().await?;

// Convert the plan into a substrait (protobuf) Rel
let mut extension_info = (vec![], HashMap::new());
let substrait_plan =
producer::to_substrait_rel(physical_plan.as_ref(), &mut extension_info)?;

// Convert the substrait Rel back into a physical plan
let ctx = create_parquet_context().await?;
let physical_plan_roundtrip =
consumer::from_substrait_rel(&ctx, substrait_plan.as_ref(), &HashMap::new())
.await?;

// Compare the original and roundtrip physical plans
let expected = format!("{}", displayable(physical_plan.as_ref()).indent(true));
let actual = format!(
"{}",
displayable(physical_plan_roundtrip.as_ref()).indent(true)
);
assert_eq!(expected, actual);

Ok(())
}

async fn create_parquet_context() -> Result<SessionContext> {
let ctx = SessionContext::new();
let explicit_options = ParquetReadOptions::default();

ctx.register_parquet("data", "tests/testdata/data.parquet", explicit_options)
.await?;

Ok(ctx)
}

async fn create_all_types_context() -> Result<SessionContext> {
let ctx = SessionContext::new();

let testdata = datafusion::test_util::parquet_test_data();
ctx.register_parquet(
"alltypes_plain",
&format!("{testdata}/alltypes_plain.parquet"),
ParquetReadOptions::default(),
)
.await?;

Ok(ctx)
}
Loading