Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion datafusion/substrait/src/logical_plan/consumer/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

use super::utils::{from_substrait_precision, next_struct_field_name, DEFAULT_TIMEZONE};
use super::SubstraitConsumer;
use crate::variation_const::FLOAT_16_TYPE_NAME;
#[allow(deprecated)]
use crate::variation_const::{
DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF,
Expand All @@ -33,6 +32,7 @@ use crate::variation_const::{
TIME_64_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF,
VIEW_CONTAINER_TYPE_VARIATION_REF,
};
use crate::variation_const::{FLOAT_16_TYPE_NAME, NULL_TYPE_NAME};
use datafusion::arrow::datatypes::{
DataType, Field, Fields, IntervalUnit, Schema, TimeUnit,
};
Expand Down Expand Up @@ -253,6 +253,7 @@ pub fn from_substrait_type(
// Kept for backwards compatibility, producers should use IntervalCompound instead
INTERVAL_MONTH_DAY_NANO_TYPE_NAME => Ok(DataType::Interval(IntervalUnit::MonthDayNano)),
FLOAT_16_TYPE_NAME => Ok(DataType::Float16),
NULL_TYPE_NAME => Ok(DataType::Null),
_ => not_impl_err!(
"Unsupported Substrait user defined type with ref {} and variation {}",
u.type_reference,
Expand Down
29 changes: 23 additions & 6 deletions datafusion/substrait/src/logical_plan/producer/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ use crate::variation_const::{
DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_INTERVAL_DAY_TYPE_VARIATION_REF,
DEFAULT_MAP_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF,
DICTIONARY_MAP_TYPE_VARIATION_REF, DURATION_INTERVAL_DAY_TYPE_VARIATION_REF,
FLOAT_16_TYPE_NAME, LARGE_CONTAINER_TYPE_VARIATION_REF, TIME_32_TYPE_VARIATION_REF,
TIME_64_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF,
VIEW_CONTAINER_TYPE_VARIATION_REF,
FLOAT_16_TYPE_NAME, LARGE_CONTAINER_TYPE_VARIATION_REF, NULL_TYPE_NAME,
TIME_32_TYPE_VARIATION_REF, TIME_64_TYPE_VARIATION_REF,
UNSIGNED_INTEGER_TYPE_VARIATION_REF, VIEW_CONTAINER_TYPE_VARIATION_REF,
};
use datafusion::arrow::datatypes::{DataType, IntervalUnit};
use datafusion::common::{internal_err, not_impl_err, plan_err, DFSchemaRef};
use datafusion::common::{not_impl_err, plan_err, DFSchemaRef};
use substrait::proto::{r#type, NamedStruct};

pub(crate) fn to_substrait_type(
Expand All @@ -42,7 +42,17 @@ pub(crate) fn to_substrait_type(
r#type::Nullability::Required as i32
};
match dt {
DataType::Null => internal_err!("Null cast is not valid"),
DataType::Null => {
let type_anchor = producer.register_type(NULL_TYPE_NAME.to_string());
Ok(substrait::proto::Type {
kind: Some(r#type::Kind::UserDefined(r#type::UserDefined {
type_reference: type_anchor,
type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
nullability,
type_parameters: vec![],
})),
})
}
DataType::Boolean => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::Bool(r#type::Boolean {
type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
Expand Down Expand Up @@ -377,6 +387,7 @@ mod tests {
use crate::logical_plan::consumer::tests::test_consumer;
use crate::logical_plan::consumer::{
from_substrait_named_struct, from_substrait_type_without_names,
DefaultSubstraitConsumer,
};
use crate::logical_plan::producer::DefaultSubstraitProducer;
use datafusion::arrow::datatypes::{Field, Fields, Schema, TimeUnit};
Expand All @@ -386,6 +397,7 @@ mod tests {

#[test]
fn round_trip_types() -> Result<()> {
round_trip_type(DataType::Null)?;
round_trip_type(DataType::Boolean)?;
round_trip_type(DataType::Int8)?;
round_trip_type(DataType::UInt8)?;
Expand Down Expand Up @@ -474,7 +486,12 @@ mod tests {
// DataFusion treats nullability as a property of the field rather than of the type,
// so it doesn't matter whether we set nullability to true or false here.
let substrait = to_substrait_type(&mut producer, &dt, true)?;
let consumer = test_consumer();

// Get the extensions from the producer so the consumer can look up
// any registered user-defined types (like "null" or "fp16")
let extensions = producer.get_extensions();
let consumer = DefaultSubstraitConsumer::new(&extensions, &state);

let roundtrip_dt = from_substrait_type_without_names(&consumer, &substrait)?;
assert_eq!(dt, roundtrip_dt);
Ok(())
Expand Down
5 changes: 5 additions & 0 deletions datafusion/substrait/src/variation_const.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,8 @@ pub const INTERVAL_MONTH_DAY_NANO_TYPE_NAME: &str = "interval-month-day-nano";

/// Name of the Substrait user-defined type that the consumer maps to Arrow's `DataType::Float16`.
///
/// Defined in <https://github.com/apache/arrow/blame/main/format/substrait/extension_types.yaml>
pub const FLOAT_16_TYPE_NAME: &str = "fp16";

/// Name of the Substrait user-defined type that represents Arrow's [`DataType::Null`].
///
/// The producer registers this name when converting a `DataType::Null` to a Substrait
/// type, and the consumer maps a user-defined type with this name back to
/// `DataType::Null`.
///
/// [`DataType::Null`]: datafusion::arrow::datatypes::DataType::Null
pub const NULL_TYPE_NAME: &str = "null";