diff --git a/datafusion/substrait/src/lib.rs b/datafusion/substrait/src/lib.rs index 34c9c46edf0c..432553ec7903 100644 --- a/datafusion/substrait/src/lib.rs +++ b/datafusion/substrait/src/lib.rs @@ -18,6 +18,7 @@ pub mod logical_plan; pub mod physical_plan; pub mod serializer; +pub mod variation_const; // Re-export substrait crate pub use substrait; diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 767c4a39375a..607012bfd629 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -16,7 +16,7 @@ // under the License. use async_recursion::async_recursion; -use datafusion::arrow::datatypes::DataType; +use datafusion::arrow::datatypes::{DataType, Field, TimeUnit}; use datafusion::common::{DFField, DFSchema, DFSchemaRef}; use datafusion::logical_expr::{ aggregate_function, window_function::find_df_window_func, BinaryExpr, Case, Expr, @@ -32,6 +32,7 @@ use datafusion::{ prelude::{Column, SessionContext}, scalar::ScalarValue, }; +use substrait::proto::expression::Literal; use substrait::proto::{ aggregate_function::AggregationInvocation, expression::{ @@ -56,6 +57,13 @@ use std::collections::HashMap; use std::str::FromStr; use std::sync::Arc; +use crate::variation_const::{ + DATE_32_TYPE_REF, DATE_64_TYPE_REF, DECIMAL_128_TYPE_REF, DECIMAL_256_TYPE_REF, + DEFAULT_CONTAINER_TYPE_REF, DEFAULT_TYPE_REF, LARGE_CONTAINER_TYPE_REF, + TIMESTAMP_MICRO_TYPE_REF, TIMESTAMP_MILLI_TYPE_REF, TIMESTAMP_NANO_TYPE_REF, + TIMESTAMP_SECOND_TYPE_REF, UNSIGNED_INTEGER_TYPE_REF, +}; + pub fn name_to_op(name: &str) -> Result { match name { "equal" => Ok(Operator::Eq), @@ -682,109 +690,8 @@ pub async fn from_substrait_rex( } } Some(RexType::Literal(lit)) => { - match &lit.literal_type { - Some(LiteralType::I8(n)) => { - if lit.type_variation_reference == 0 { - Ok(Arc::new(Expr::Literal(ScalarValue::Int8(Some(*n as i8))))) - } else if 
lit.type_variation_reference == 1 { - Ok(Arc::new(Expr::Literal(ScalarValue::UInt8(Some(*n as u8))))) - } else { - Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {}", - lit.type_variation_reference - ))) - } - } - Some(LiteralType::I16(n)) => { - if lit.type_variation_reference == 0 { - Ok(Arc::new(Expr::Literal(ScalarValue::Int16(Some(*n as i16))))) - } else if lit.type_variation_reference == 1 { - Ok(Arc::new(Expr::Literal(ScalarValue::UInt16(Some( - *n as u16, - ))))) - } else { - Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {}", - lit.type_variation_reference - ))) - } - } - Some(LiteralType::I32(n)) => { - if lit.type_variation_reference == 0 { - Ok(Arc::new(Expr::Literal(ScalarValue::Int32(Some(*n))))) - } else if lit.type_variation_reference == 1 { - Ok(Arc::new(Expr::Literal(ScalarValue::UInt32(Some(unsafe { - std::mem::transmute_copy::(n) - }))))) - } else { - Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {}", - lit.type_variation_reference - ))) - } - } - Some(LiteralType::I64(n)) => { - if lit.type_variation_reference == 0 { - Ok(Arc::new(Expr::Literal(ScalarValue::Int64(Some(*n))))) - } else if lit.type_variation_reference == 1 { - Ok(Arc::new(Expr::Literal(ScalarValue::UInt64(Some(unsafe { - std::mem::transmute_copy::(n) - }))))) - } else { - Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {}", - lit.type_variation_reference - ))) - } - } - Some(LiteralType::Boolean(b)) => { - Ok(Arc::new(Expr::Literal(ScalarValue::Boolean(Some(*b))))) - } - Some(LiteralType::Date(d)) => { - Ok(Arc::new(Expr::Literal(ScalarValue::Date32(Some(*d))))) - } - Some(LiteralType::Fp32(f)) => { - Ok(Arc::new(Expr::Literal(ScalarValue::Float32(Some(*f))))) - } - Some(LiteralType::Fp64(f)) => { - Ok(Arc::new(Expr::Literal(ScalarValue::Float64(Some(*f))))) - } - Some(LiteralType::Decimal(d)) => { - let value: [u8; 16] = d.value.clone().try_into().or(Err( - 
DataFusionError::Substrait( - "Failed to parse decimal value".to_string(), - ), - ))?; - let p = d.precision.try_into().map_err(|e| { - DataFusionError::Substrait(format!( - "Failed to parse decimal precision: {e}" - )) - })?; - let s = d.scale.try_into().map_err(|e| { - DataFusionError::Substrait(format!( - "Failed to parse decimal scale: {e}" - )) - })?; - Ok(Arc::new(Expr::Literal(ScalarValue::Decimal128( - Some(std::primitive::i128::from_le_bytes(value)), - p, - s, - )))) - } - Some(LiteralType::String(s)) => { - Ok(Arc::new(Expr::Literal(ScalarValue::Utf8(Some(s.clone()))))) - } - Some(LiteralType::Binary(b)) => Ok(Arc::new(Expr::Literal( - ScalarValue::Binary(Some(b.clone())), - ))), - Some(LiteralType::Null(ntype)) => { - Ok(Arc::new(Expr::Literal(from_substrait_null(ntype)?))) - } - _ => Err(DataFusionError::NotImplemented(format!( - "Unsupported literal_type: {:?}", - lit.literal_type - ))), - } + let scalar_value = from_substrait_literal(lit)?; + Ok(Arc::new(Expr::Literal(scalar_value))) } Some(RexType::Cast(cast)) => match cast.as_ref().r#type.as_ref() { Some(output_type) => Ok(Arc::new(Expr::Cast(Cast::new( @@ -855,13 +762,104 @@ fn from_substrait_type(dt: &substrait::proto::Type) -> Result { match &dt.kind { Some(s_kind) => match s_kind { r#type::Kind::Bool(_) => Ok(DataType::Boolean), - r#type::Kind::I8(_) => Ok(DataType::Int8), - r#type::Kind::I16(_) => Ok(DataType::Int16), - r#type::Kind::I32(_) => Ok(DataType::Int32), - r#type::Kind::I64(_) => Ok(DataType::Int64), - r#type::Kind::Decimal(d) => { - Ok(DataType::Decimal128(d.precision as u8, d.scale as i8)) + r#type::Kind::I8(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_REF => Ok(DataType::Int8), + UNSIGNED_INTEGER_TYPE_REF => Ok(DataType::UInt8), + v => Err(DataFusionError::NotImplemented(format!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ))), + }, + r#type::Kind::I16(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_REF => 
Ok(DataType::Int16), + UNSIGNED_INTEGER_TYPE_REF => Ok(DataType::UInt16), + v => Err(DataFusionError::NotImplemented(format!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ))), + }, + r#type::Kind::I32(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_REF => Ok(DataType::Int32), + UNSIGNED_INTEGER_TYPE_REF => Ok(DataType::UInt32), + v => Err(DataFusionError::NotImplemented(format!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ))), + }, + r#type::Kind::I64(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_REF => Ok(DataType::Int64), + UNSIGNED_INTEGER_TYPE_REF => Ok(DataType::UInt64), + v => Err(DataFusionError::NotImplemented(format!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ))), + }, + r#type::Kind::Fp32(_) => Ok(DataType::Float32), + r#type::Kind::Fp64(_) => Ok(DataType::Float64), + r#type::Kind::Timestamp(ts) => match ts.type_variation_reference { + TIMESTAMP_SECOND_TYPE_REF => { + Ok(DataType::Timestamp(TimeUnit::Second, None)) + } + TIMESTAMP_MILLI_TYPE_REF => { + Ok(DataType::Timestamp(TimeUnit::Millisecond, None)) + } + TIMESTAMP_MICRO_TYPE_REF => { + Ok(DataType::Timestamp(TimeUnit::Microsecond, None)) + } + TIMESTAMP_NANO_TYPE_REF => { + Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) + } + v => Err(DataFusionError::NotImplemented(format!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ))), + }, + r#type::Kind::Date(date) => match date.type_variation_reference { + DATE_32_TYPE_REF => Ok(DataType::Date32), + DATE_64_TYPE_REF => Ok(DataType::Date64), + v => Err(DataFusionError::NotImplemented(format!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ))), + }, + r#type::Kind::Binary(binary) => match binary.type_variation_reference { + DEFAULT_CONTAINER_TYPE_REF => Ok(DataType::Binary), + LARGE_CONTAINER_TYPE_REF => Ok(DataType::LargeBinary), + v => Err(DataFusionError::NotImplemented(format!( + "Unsupported Substrait 
type variation {v} of type {s_kind:?}" + ))), + }, + r#type::Kind::FixedBinary(fixed) => { + Ok(DataType::FixedSizeBinary(fixed.length)) } + r#type::Kind::String(string) => match string.type_variation_reference { + DEFAULT_CONTAINER_TYPE_REF => Ok(DataType::Utf8), + LARGE_CONTAINER_TYPE_REF => Ok(DataType::LargeUtf8), + v => Err(DataFusionError::NotImplemented(format!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ))), + }, + r#type::Kind::List(list) => { + let inner_type = + from_substrait_type(list.r#type.as_ref().ok_or_else(|| { + DataFusionError::Substrait( + "List type must have inner type".to_string(), + ) + })?)?; + let field = Box::new(Field::new("list_item", inner_type, true)); + match list.type_variation_reference { + DEFAULT_CONTAINER_TYPE_REF => Ok(DataType::List(field)), + LARGE_CONTAINER_TYPE_REF => Ok(DataType::LargeList(field)), + v => Err(DataFusionError::NotImplemented(format!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + )))?, + } + } + r#type::Kind::Decimal(d) => match d.type_variation_reference { + DECIMAL_128_TYPE_REF => { + Ok(DataType::Decimal128(d.precision as u8, d.scale as i8)) + } + DECIMAL_256_TYPE_REF => { + Ok(DataType::Decimal256(d.precision as u8, d.scale as i8)) + } + v => Err(DataFusionError::NotImplemented(format!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ))), + }, _ => Err(DataFusionError::NotImplemented(format!( "Unsupported Substrait type: {s_kind:?}" ))), @@ -910,20 +908,196 @@ fn from_substrait_bound( } } +fn from_substrait_literal(lit: &Literal) -> Result { + let scalar_value = match &lit.literal_type { + Some(LiteralType::Boolean(b)) => ScalarValue::Boolean(Some(*b)), + Some(LiteralType::I8(n)) => match lit.type_variation_reference { + DEFAULT_TYPE_REF => ScalarValue::Int8(Some(*n as i8)), + UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt8(Some(*n as u8)), + others => { + return Err(DataFusionError::Substrait(format!( + "Unknown type variation reference 
{others}", + ))); + } + }, + Some(LiteralType::I16(n)) => match lit.type_variation_reference { + DEFAULT_TYPE_REF => ScalarValue::Int16(Some(*n as i16)), + UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt16(Some(*n as u16)), + others => { + return Err(DataFusionError::Substrait(format!( + "Unknown type variation reference {others}", + ))); + } + }, + Some(LiteralType::I32(n)) => match lit.type_variation_reference { + DEFAULT_TYPE_REF => ScalarValue::Int32(Some(*n)), + UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt32(Some(unsafe { + std::mem::transmute_copy::(n) + })), + others => { + return Err(DataFusionError::Substrait(format!( + "Unknown type variation reference {others}", + ))); + } + }, + Some(LiteralType::I64(n)) => match lit.type_variation_reference { + DEFAULT_TYPE_REF => ScalarValue::Int64(Some(*n)), + UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt64(Some(unsafe { + std::mem::transmute_copy::(n) + })), + others => { + return Err(DataFusionError::Substrait(format!( + "Unknown type variation reference {others}", + ))); + } + }, + Some(LiteralType::Fp32(f)) => ScalarValue::Float32(Some(*f)), + Some(LiteralType::Fp64(f)) => ScalarValue::Float64(Some(*f)), + Some(LiteralType::Timestamp(t)) => match lit.type_variation_reference { + TIMESTAMP_SECOND_TYPE_REF => ScalarValue::TimestampSecond(Some(*t), None), + TIMESTAMP_MILLI_TYPE_REF => ScalarValue::TimestampMillisecond(Some(*t), None), + TIMESTAMP_MICRO_TYPE_REF => ScalarValue::TimestampMicrosecond(Some(*t), None), + TIMESTAMP_NANO_TYPE_REF => ScalarValue::TimestampNanosecond(Some(*t), None), + others => { + return Err(DataFusionError::Substrait(format!( + "Unknown type variation reference {others}", + ))); + } + }, + Some(LiteralType::Date(d)) => ScalarValue::Date32(Some(*d)), + Some(LiteralType::String(s)) => match lit.type_variation_reference { + DEFAULT_CONTAINER_TYPE_REF => ScalarValue::Utf8(Some(s.clone())), + LARGE_CONTAINER_TYPE_REF => ScalarValue::LargeUtf8(Some(s.clone())), + others => { + return 
Err(DataFusionError::Substrait(format!( + "Unknown type variation reference {others}", + ))); + } + }, + Some(LiteralType::Binary(b)) => match lit.type_variation_reference { + DEFAULT_CONTAINER_TYPE_REF => ScalarValue::Binary(Some(b.clone())), + LARGE_CONTAINER_TYPE_REF => ScalarValue::LargeBinary(Some(b.clone())), + others => { + return Err(DataFusionError::Substrait(format!( + "Unknown type variation reference {others}", + ))); + } + }, + Some(LiteralType::FixedBinary(b)) => { + ScalarValue::FixedSizeBinary(b.len() as _, Some(b.clone())) + } + Some(LiteralType::Decimal(d)) => { + let value: [u8; 16] = + d.value + .clone() + .try_into() + .or(Err(DataFusionError::Substrait( + "Failed to parse decimal value".to_string(), + )))?; + let p = d.precision.try_into().map_err(|e| { + DataFusionError::Substrait(format!( + "Failed to parse decimal precision: {e}" + )) + })?; + let s = d.scale.try_into().map_err(|e| { + DataFusionError::Substrait(format!("Failed to parse decimal scale: {e}")) + })?; + ScalarValue::Decimal128( + Some(std::primitive::i128::from_le_bytes(value)), + p, + s, + ) + } + Some(LiteralType::Null(ntype)) => from_substrait_null(ntype)?, + _ => { + return Err(DataFusionError::NotImplemented(format!( + "Unsupported literal_type: {:?}", + lit.literal_type + ))) + } + }; + + Ok(scalar_value) +} + fn from_substrait_null(null_type: &Type) -> Result { if let Some(kind) = &null_type.kind { match kind { - r#type::Kind::I8(_) => Ok(ScalarValue::Int8(None)), - r#type::Kind::I16(_) => Ok(ScalarValue::Int16(None)), - r#type::Kind::I32(_) => Ok(ScalarValue::Int32(None)), - r#type::Kind::I64(_) => Ok(ScalarValue::Int64(None)), + r#type::Kind::Bool(_) => Ok(ScalarValue::Boolean(None)), + r#type::Kind::I8(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_REF => Ok(ScalarValue::Int8(None)), + UNSIGNED_INTEGER_TYPE_REF => Ok(ScalarValue::UInt8(None)), + v => Err(DataFusionError::NotImplemented(format!( + "Unsupported Substrait type variation {v} of type 
{kind:?}" + ))), + }, + r#type::Kind::I16(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_REF => Ok(ScalarValue::Int16(None)), + UNSIGNED_INTEGER_TYPE_REF => Ok(ScalarValue::UInt16(None)), + v => Err(DataFusionError::NotImplemented(format!( + "Unsupported Substrait type variation {v} of type {kind:?}" + ))), + }, + r#type::Kind::I32(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_REF => Ok(ScalarValue::Int32(None)), + UNSIGNED_INTEGER_TYPE_REF => Ok(ScalarValue::UInt32(None)), + v => Err(DataFusionError::NotImplemented(format!( + "Unsupported Substrait type variation {v} of type {kind:?}" + ))), + }, + r#type::Kind::I64(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_REF => Ok(ScalarValue::Int64(None)), + UNSIGNED_INTEGER_TYPE_REF => Ok(ScalarValue::UInt64(None)), + v => Err(DataFusionError::NotImplemented(format!( + "Unsupported Substrait type variation {v} of type {kind:?}" + ))), + }, + r#type::Kind::Fp32(_) => Ok(ScalarValue::Float32(None)), + r#type::Kind::Fp64(_) => Ok(ScalarValue::Float64(None)), + r#type::Kind::Timestamp(ts) => match ts.type_variation_reference { + TIMESTAMP_SECOND_TYPE_REF => Ok(ScalarValue::TimestampSecond(None, None)), + TIMESTAMP_MILLI_TYPE_REF => { + Ok(ScalarValue::TimestampMillisecond(None, None)) + } + TIMESTAMP_MICRO_TYPE_REF => { + Ok(ScalarValue::TimestampMicrosecond(None, None)) + } + TIMESTAMP_NANO_TYPE_REF => { + Ok(ScalarValue::TimestampNanosecond(None, None)) + } + v => Err(DataFusionError::NotImplemented(format!( + "Unsupported Substrait type variation {v} of type {kind:?}" + ))), + }, + r#type::Kind::Date(date) => match date.type_variation_reference { + DATE_32_TYPE_REF => Ok(ScalarValue::Date32(None)), + DATE_64_TYPE_REF => Ok(ScalarValue::Date64(None)), + v => Err(DataFusionError::NotImplemented(format!( + "Unsupported Substrait type variation {v} of type {kind:?}" + ))), + }, + r#type::Kind::Binary(binary) => match binary.type_variation_reference { + 
DEFAULT_CONTAINER_TYPE_REF => Ok(ScalarValue::Binary(None)), + LARGE_CONTAINER_TYPE_REF => Ok(ScalarValue::LargeBinary(None)), + v => Err(DataFusionError::NotImplemented(format!( + "Unsupported Substrait type variation {v} of type {kind:?}" + ))), + }, + // FixedBinary is not supported because `None` doesn't have length + r#type::Kind::String(string) => match string.type_variation_reference { + DEFAULT_CONTAINER_TYPE_REF => Ok(ScalarValue::Utf8(None)), + LARGE_CONTAINER_TYPE_REF => Ok(ScalarValue::LargeUtf8(None)), + v => Err(DataFusionError::NotImplemented(format!( + "Unsupported Substrait type variation {v} of type {kind:?}" + ))), + }, r#type::Kind::Decimal(d) => Ok(ScalarValue::Decimal128( None, d.precision as u8, d.scale as i8, )), _ => Err(DataFusionError::NotImplemented(format!( - "Unsupported null kind: {kind:?}" + "Unsupported Substrait type: {kind:?}" ))), } } else { diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index c0a4dd04f33a..9ad9645ffcc4 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -18,7 +18,7 @@ use std::{collections::HashMap, mem, sync::Arc}; use datafusion::{ - arrow::datatypes::DataType, + arrow::datatypes::{DataType, TimeUnit}, error::{DataFusionError, Result}, logical_expr::{WindowFrame, WindowFrameBound}, prelude::JoinType, @@ -63,6 +63,13 @@ use substrait::{ version, }; +use crate::variation_const::{ + DATE_32_TYPE_REF, DATE_64_TYPE_REF, DECIMAL_128_TYPE_REF, DECIMAL_256_TYPE_REF, + DEFAULT_CONTAINER_TYPE_REF, DEFAULT_TYPE_REF, LARGE_CONTAINER_TYPE_REF, + TIMESTAMP_MICRO_TYPE_REF, TIMESTAMP_MILLI_TYPE_REF, TIMESTAMP_NANO_TYPE_REF, + TIMESTAMP_SECOND_TYPE_REF, UNSIGNED_INTEGER_TYPE_REF, +}; + /// Convert DataFusion LogicalPlan to Substrait Plan pub fn to_substrait_plan(plan: &LogicalPlan) -> Result> { // Parse relation nodes @@ -637,48 +644,7 @@ pub fn to_substrait_rex( ))), }) } - 
Expr::Literal(value) => { - let literal_type = match value { - ScalarValue::Int8(Some(n)) => Some(LiteralType::I8(*n as i32)), - ScalarValue::UInt8(Some(n)) => Some(LiteralType::I8(*n as i32)), - ScalarValue::Int16(Some(n)) => Some(LiteralType::I16(*n as i32)), - ScalarValue::UInt16(Some(n)) => Some(LiteralType::I16(*n as i32)), - ScalarValue::Int32(Some(n)) => Some(LiteralType::I32(*n)), - ScalarValue::UInt32(Some(n)) => Some(LiteralType::I32(unsafe { - mem::transmute_copy::(n) - })), - ScalarValue::Int64(Some(n)) => Some(LiteralType::I64(*n)), - ScalarValue::UInt64(Some(n)) => Some(LiteralType::I64(unsafe { - mem::transmute_copy::(n) - })), - ScalarValue::Boolean(Some(b)) => Some(LiteralType::Boolean(*b)), - ScalarValue::Float32(Some(f)) => Some(LiteralType::Fp32(*f)), - ScalarValue::Float64(Some(f)) => Some(LiteralType::Fp64(*f)), - ScalarValue::Decimal128(v, p, s) if v.is_some() => { - Some(LiteralType::Decimal(Decimal { - value: v.unwrap().to_le_bytes().to_vec(), - precision: *p as i32, - scale: *s as i32, - })) - } - ScalarValue::Utf8(Some(s)) => Some(LiteralType::String(s.clone())), - ScalarValue::LargeUtf8(Some(s)) => Some(LiteralType::String(s.clone())), - ScalarValue::Binary(Some(b)) => Some(LiteralType::Binary(b.clone())), - ScalarValue::LargeBinary(Some(b)) => Some(LiteralType::Binary(b.clone())), - ScalarValue::Date32(Some(d)) => Some(LiteralType::Date(*d)), - _ => Some(try_to_substrait_null(value)?), - }; - - let type_variation_reference = if value.is_unsigned() { 1 } else { 0 }; - - Ok(Expression { - rex_type: Some(RexType::Literal(Literal { - nullable: true, - type_variation_reference, - literal_type, - })), - }) - } + Expr::Literal(value) => to_substrait_literal(value), Expr::Alias(expr, _alias) => to_substrait_rex(expr, schema, extension_info), Expr::WindowFunction(WindowFunction { fun, @@ -728,7 +694,6 @@ pub fn to_substrait_rex( } fn to_substrait_type(dt: &DataType) -> Result { - let default_type_ref = 0; let default_nullability = 
r#type::Nullability::Required as i32; match dt { DataType::Null => Err(DataFusionError::Internal( @@ -736,37 +701,173 @@ fn to_substrait_type(dt: &DataType) -> Result { )), DataType::Boolean => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Bool(r#type::Boolean { - type_variation_reference: default_type_ref, + type_variation_reference: DEFAULT_TYPE_REF, nullability: default_nullability, })), }), DataType::Int8 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I8(r#type::I8 { - type_variation_reference: default_type_ref, + type_variation_reference: DEFAULT_TYPE_REF, + nullability: default_nullability, + })), + }), + DataType::UInt8 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::I8(r#type::I8 { + type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, nullability: default_nullability, })), }), DataType::Int16 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I16(r#type::I16 { - type_variation_reference: default_type_ref, + type_variation_reference: DEFAULT_TYPE_REF, + nullability: default_nullability, + })), + }), + DataType::UInt16 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::I16(r#type::I16 { + type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, nullability: default_nullability, })), }), DataType::Int32 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I32(r#type::I32 { - type_variation_reference: default_type_ref, + type_variation_reference: DEFAULT_TYPE_REF, + nullability: default_nullability, + })), + }), + DataType::UInt32 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::I32(r#type::I32 { + type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, nullability: default_nullability, })), }), DataType::Int64 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I64(r#type::I64 { - type_variation_reference: default_type_ref, + type_variation_reference: DEFAULT_TYPE_REF, + nullability: default_nullability, + })), + }), + DataType::UInt64 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::I64(r#type::I64 { + 
type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, + nullability: default_nullability, + })), + }), + // Float16 is not supported in Substrait + DataType::Float32 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Fp32(r#type::Fp32 { + type_variation_reference: DEFAULT_TYPE_REF, + nullability: default_nullability, + })), + }), + DataType::Float64 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Fp64(r#type::Fp64 { + type_variation_reference: DEFAULT_TYPE_REF, + nullability: default_nullability, + })), + }), + // Timezone is ignored. + DataType::Timestamp(unit, _) => { + let type_variation_reference = match unit { + TimeUnit::Second => TIMESTAMP_SECOND_TYPE_REF, + TimeUnit::Millisecond => TIMESTAMP_MILLI_TYPE_REF, + TimeUnit::Microsecond => TIMESTAMP_MICRO_TYPE_REF, + TimeUnit::Nanosecond => TIMESTAMP_NANO_TYPE_REF, + }; + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Timestamp(r#type::Timestamp { + type_variation_reference, + nullability: default_nullability, + })), + }) + } + DataType::Date32 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Date(r#type::Date { + type_variation_reference: DATE_32_TYPE_REF, + nullability: default_nullability, + })), + }), + DataType::Date64 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Date(r#type::Date { + type_variation_reference: DATE_64_TYPE_REF, + nullability: default_nullability, + })), + }), + DataType::Binary => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Binary(r#type::Binary { + type_variation_reference: DEFAULT_CONTAINER_TYPE_REF, + nullability: default_nullability, + })), + }), + DataType::FixedSizeBinary(length) => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::FixedBinary(r#type::FixedBinary { + length: *length, + type_variation_reference: DEFAULT_TYPE_REF, nullability: default_nullability, })), }), + DataType::LargeBinary => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Binary(r#type::Binary { + type_variation_reference: 
LARGE_CONTAINER_TYPE_REF, + nullability: default_nullability, + })), + }), + DataType::Utf8 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::String(r#type::String { + type_variation_reference: DEFAULT_CONTAINER_TYPE_REF, + nullability: default_nullability, + })), + }), + DataType::LargeUtf8 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::String(r#type::String { + type_variation_reference: LARGE_CONTAINER_TYPE_REF, + nullability: default_nullability, + })), + }), + DataType::List(inner) => { + let inner_type = to_substrait_type(inner.data_type())?; + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::List(Box::new(r#type::List { + r#type: Some(Box::new(inner_type)), + type_variation_reference: DEFAULT_CONTAINER_TYPE_REF, + nullability: default_nullability, + }))), + }) + } + DataType::LargeList(inner) => { + let inner_type = to_substrait_type(inner.data_type())?; + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::List(Box::new(r#type::List { + r#type: Some(Box::new(inner_type)), + type_variation_reference: LARGE_CONTAINER_TYPE_REF, + nullability: default_nullability, + }))), + }) + } + DataType::Struct(fields) => { + let field_types = fields + .iter() + .map(|field| to_substrait_type(field.data_type())) + .collect::>>()?; + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Struct(r#type::Struct { + types: field_types, + type_variation_reference: DEFAULT_TYPE_REF, + nullability: default_nullability, + })), + }) + } DataType::Decimal128(p, s) => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Decimal(r#type::Decimal { - type_variation_reference: default_type_ref, + type_variation_reference: DECIMAL_128_TYPE_REF, + nullability: default_nullability, + scale: *s as i32, + precision: *p as i32, + })), + }), + DataType::Decimal256(p, s) => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Decimal(r#type::Decimal { + type_variation_reference: DECIMAL_256_TYPE_REF, nullability: default_nullability, scale: *s as i32, precision: *p 
as i32, @@ -908,31 +1009,215 @@ fn to_substrait_bounds(window_frame: &WindowFrame) -> Result<(Bound, Bound)> { )) } +fn to_substrait_literal(value: &ScalarValue) -> Result { + let (literal_type, type_variation_reference) = match value { + ScalarValue::Boolean(Some(b)) => (LiteralType::Boolean(*b), DEFAULT_TYPE_REF), + ScalarValue::Int8(Some(n)) => (LiteralType::I8(*n as i32), DEFAULT_TYPE_REF), + ScalarValue::UInt8(Some(n)) => { + (LiteralType::I8(*n as i32), UNSIGNED_INTEGER_TYPE_REF) + } + ScalarValue::Int16(Some(n)) => (LiteralType::I16(*n as i32), DEFAULT_TYPE_REF), + ScalarValue::UInt16(Some(n)) => { + (LiteralType::I16(*n as i32), UNSIGNED_INTEGER_TYPE_REF) + } + ScalarValue::Int32(Some(n)) => (LiteralType::I32(*n), DEFAULT_TYPE_REF), + ScalarValue::UInt32(Some(n)) => ( + LiteralType::I32(unsafe { mem::transmute_copy::(n) }), + UNSIGNED_INTEGER_TYPE_REF, + ), + ScalarValue::Int64(Some(n)) => (LiteralType::I64(*n), DEFAULT_TYPE_REF), + ScalarValue::UInt64(Some(n)) => ( + LiteralType::I64(unsafe { mem::transmute_copy::(n) }), + UNSIGNED_INTEGER_TYPE_REF, + ), + ScalarValue::Float32(Some(f)) => (LiteralType::Fp32(*f), DEFAULT_TYPE_REF), + ScalarValue::Float64(Some(f)) => (LiteralType::Fp64(*f), DEFAULT_TYPE_REF), + ScalarValue::TimestampSecond(Some(t), _) => { + (LiteralType::Timestamp(*t), TIMESTAMP_SECOND_TYPE_REF) + } + ScalarValue::TimestampMillisecond(Some(t), _) => { + (LiteralType::Timestamp(*t), TIMESTAMP_MILLI_TYPE_REF) + } + ScalarValue::TimestampMicrosecond(Some(t), _) => { + (LiteralType::Timestamp(*t), TIMESTAMP_MICRO_TYPE_REF) + } + ScalarValue::TimestampNanosecond(Some(t), _) => { + (LiteralType::Timestamp(*t), TIMESTAMP_NANO_TYPE_REF) + } + ScalarValue::Date32(Some(d)) => (LiteralType::Date(*d), DATE_32_TYPE_REF), + // Date64 literal is not supported in Substrait + ScalarValue::Binary(Some(b)) => { + (LiteralType::Binary(b.clone()), DEFAULT_CONTAINER_TYPE_REF) + } + ScalarValue::LargeBinary(Some(b)) => { + (LiteralType::Binary(b.clone()), 
LARGE_CONTAINER_TYPE_REF) + } + ScalarValue::FixedSizeBinary(_, Some(b)) => { + (LiteralType::FixedBinary(b.clone()), DEFAULT_TYPE_REF) + } + ScalarValue::Utf8(Some(s)) => { + (LiteralType::String(s.clone()), DEFAULT_CONTAINER_TYPE_REF) + } + ScalarValue::LargeUtf8(Some(s)) => { + (LiteralType::String(s.clone()), LARGE_CONTAINER_TYPE_REF) + } + ScalarValue::Decimal128(v, p, s) if v.is_some() => ( + LiteralType::Decimal(Decimal { + value: v.unwrap().to_le_bytes().to_vec(), + precision: *p as i32, + scale: *s as i32, + }), + DECIMAL_128_TYPE_REF, + ), + _ => (try_to_substrait_null(value)?, DEFAULT_TYPE_REF), + }; + + Ok(Expression { + rex_type: Some(RexType::Literal(Literal { + nullable: true, + type_variation_reference, + literal_type: Some(literal_type), + })), + }) +} + fn try_to_substrait_null(v: &ScalarValue) -> Result { - let default_type_ref = 0; + // let default_type_ref = 0; let default_nullability = r#type::Nullability::Nullable as i32; match v { ScalarValue::Int8(None) => Ok(LiteralType::Null(substrait::proto::Type { kind: Some(r#type::Kind::I8(r#type::I8 { - type_variation_reference: default_type_ref, + type_variation_reference: DEFAULT_TYPE_REF, + nullability: default_nullability, + })), + })), + ScalarValue::UInt8(None) => Ok(LiteralType::Null(substrait::proto::Type { + kind: Some(r#type::Kind::I8(r#type::I8 { + type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, nullability: default_nullability, })), })), ScalarValue::Int16(None) => Ok(LiteralType::Null(substrait::proto::Type { kind: Some(r#type::Kind::I16(r#type::I16 { - type_variation_reference: default_type_ref, + type_variation_reference: DEFAULT_TYPE_REF, + nullability: default_nullability, + })), + })), + ScalarValue::UInt16(None) => Ok(LiteralType::Null(substrait::proto::Type { + kind: Some(r#type::Kind::I16(r#type::I16 { + type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, nullability: default_nullability, })), })), ScalarValue::Int32(None) => Ok(LiteralType::Null(substrait::proto::Type { 
kind: Some(r#type::Kind::I32(r#type::I32 { - type_variation_reference: default_type_ref, + type_variation_reference: DEFAULT_TYPE_REF, + nullability: default_nullability, + })), + })), + ScalarValue::UInt32(None) => Ok(LiteralType::Null(substrait::proto::Type { + kind: Some(r#type::Kind::I32(r#type::I32 { + type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, nullability: default_nullability, })), })), ScalarValue::Int64(None) => Ok(LiteralType::Null(substrait::proto::Type { kind: Some(r#type::Kind::I64(r#type::I64 { - type_variation_reference: default_type_ref, + type_variation_reference: DEFAULT_TYPE_REF, + nullability: default_nullability, + })), + })), + ScalarValue::UInt64(None) => Ok(LiteralType::Null(substrait::proto::Type { + kind: Some(r#type::Kind::I64(r#type::I64 { + type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, + nullability: default_nullability, + })), + })), + ScalarValue::Float32(None) => Ok(LiteralType::Null(substrait::proto::Type { + kind: Some(r#type::Kind::Fp32(r#type::Fp32 { + type_variation_reference: DEFAULT_TYPE_REF, + nullability: default_nullability, + })), + })), + ScalarValue::Float64(None) => Ok(LiteralType::Null(substrait::proto::Type { + kind: Some(r#type::Kind::Fp64(r#type::Fp64 { + type_variation_reference: DEFAULT_TYPE_REF, + nullability: default_nullability, + })), + })), + ScalarValue::TimestampSecond(None, _) => { + Ok(LiteralType::Null(substrait::proto::Type { + kind: Some(r#type::Kind::Timestamp(r#type::Timestamp { + type_variation_reference: TIMESTAMP_SECOND_TYPE_REF, + nullability: default_nullability, + })), + })) + } + ScalarValue::TimestampMillisecond(None, _) => { + Ok(LiteralType::Null(substrait::proto::Type { + kind: Some(r#type::Kind::Timestamp(r#type::Timestamp { + type_variation_reference: TIMESTAMP_MILLI_TYPE_REF, + nullability: default_nullability, + })), + })) + } + ScalarValue::TimestampMicrosecond(None, _) => { + Ok(LiteralType::Null(substrait::proto::Type { + kind: 
Some(r#type::Kind::Timestamp(r#type::Timestamp { + type_variation_reference: TIMESTAMP_MICRO_TYPE_REF, + nullability: default_nullability, + })), + })) + } + ScalarValue::TimestampNanosecond(None, _) => { + Ok(LiteralType::Null(substrait::proto::Type { + kind: Some(r#type::Kind::Timestamp(r#type::Timestamp { + type_variation_reference: TIMESTAMP_NANO_TYPE_REF, + nullability: default_nullability, + })), + })) + } + ScalarValue::Date32(None) => Ok(LiteralType::Null(substrait::proto::Type { + kind: Some(r#type::Kind::Date(r#type::Date { + type_variation_reference: DATE_32_TYPE_REF, + nullability: default_nullability, + })), + })), + ScalarValue::Date64(None) => Ok(LiteralType::Null(substrait::proto::Type { + kind: Some(r#type::Kind::Date(r#type::Date { + type_variation_reference: DATE_64_TYPE_REF, + nullability: default_nullability, + })), + })), + ScalarValue::Binary(None) => Ok(LiteralType::Null(substrait::proto::Type { + kind: Some(r#type::Kind::Binary(r#type::Binary { + type_variation_reference: DEFAULT_CONTAINER_TYPE_REF, + nullability: default_nullability, + })), + })), + ScalarValue::LargeBinary(None) => Ok(LiteralType::Null(substrait::proto::Type { + kind: Some(r#type::Kind::Binary(r#type::Binary { + type_variation_reference: LARGE_CONTAINER_TYPE_REF, + nullability: default_nullability, + })), + })), + ScalarValue::FixedSizeBinary(_, None) => { + Ok(LiteralType::Null(substrait::proto::Type { + kind: Some(r#type::Kind::Binary(r#type::Binary { + type_variation_reference: DEFAULT_TYPE_REF, + nullability: default_nullability, + })), + })) + } + ScalarValue::Utf8(None) => Ok(LiteralType::Null(substrait::proto::Type { + kind: Some(r#type::Kind::String(r#type::String { + type_variation_reference: DEFAULT_CONTAINER_TYPE_REF, + nullability: default_nullability, + })), + })), + ScalarValue::LargeUtf8(None) => Ok(LiteralType::Null(substrait::proto::Type { + kind: Some(r#type::Kind::String(r#type::String { + type_variation_reference: LARGE_CONTAINER_TYPE_REF, nullability: 
default_nullability, })), })), @@ -941,7 +1226,7 @@ fn try_to_substrait_null(v: &ScalarValue) -> Result { kind: Some(r#type::Kind::Decimal(r#type::Decimal { scale: *s as i32, precision: *p as i32, - type_variation_reference: default_type_ref, + type_variation_reference: DEFAULT_TYPE_REF, nullability: default_nullability, })), })) diff --git a/datafusion/substrait/src/variation_const.rs b/datafusion/substrait/src/variation_const.rs new file mode 100644 index 000000000000..27ef15153bd8 --- /dev/null +++ b/datafusion/substrait/src/variation_const.rs @@ -0,0 +1,39 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Type variation constants +//! +//! To add support for types not in the [core specification](https://substrait.io/types/type_classes/), +//! we make use of the [simple extensions](https://substrait.io/extensions/#simple-extensions) of substrait +//! type. This module contains the constants used to identify the type variation. +//! +//! The rules of type variations here are: +//! - Default type reference is 0. It is used when the actual type is the same as the original type. +//! - Extended variant type references start from 1, and usually increase by 1. 
+ +pub const DEFAULT_TYPE_REF: u32 = 0; +pub const UNSIGNED_INTEGER_TYPE_REF: u32 = 1; +pub const TIMESTAMP_SECOND_TYPE_REF: u32 = 0; +pub const TIMESTAMP_MILLI_TYPE_REF: u32 = 1; +pub const TIMESTAMP_MICRO_TYPE_REF: u32 = 2; +pub const TIMESTAMP_NANO_TYPE_REF: u32 = 3; +pub const DATE_32_TYPE_REF: u32 = 0; +pub const DATE_64_TYPE_REF: u32 = 1; +pub const DEFAULT_CONTAINER_TYPE_REF: u32 = 0; +pub const LARGE_CONTAINER_TYPE_REF: u32 = 1; +pub const DECIMAL_128_TYPE_REF: u32 = 0; +pub const DECIMAL_256_TYPE_REF: u32 = 1; diff --git a/datafusion/substrait/tests/roundtrip_logical_plan.rs b/datafusion/substrait/tests/roundtrip_logical_plan.rs index 3389658d2aca..965c007e98d3 100644 --- a/datafusion/substrait/tests/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/roundtrip_logical_plan.rs @@ -21,7 +21,7 @@ use datafusion_substrait::logical_plan::{consumer, producer}; mod tests { use crate::{consumer::from_substrait_plan, producer::to_substrait_plan}; - use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion::error::Result; use datafusion::prelude::*; use substrait::proto::extensions::simple_extension_declaration::MappingType; @@ -262,7 +262,65 @@ mod tests { #[tokio::test] async fn qualified_catalog_schema_table_reference() -> Result<()> { - roundtrip("SELECT * FROM datafusion.public.data;").await + roundtrip("SELECT a,b,c,d,e FROM datafusion.public.data;").await + } + + /// Construct a plan that contains several literals of types that are currently supported. + /// This case ignores: + /// - Date64, because this literal is not supported + /// - FixedSizeBinary, because converting a UTF-8 literal to FixedSizeBinary is not supported + /// - List, because this nested type is not supported in arrow_cast + /// - Decimal128 and Decimal256, because they fall back to a UTF8 cast expr rather than a plain literal. 
+ #[tokio::test] + async fn all_type_literal() -> Result<()> { + roundtrip_all_types( + "select * from data where + bool_col = TRUE AND + int8_col = arrow_cast('0', 'Int8') AND + uint8_col = arrow_cast('0', 'UInt8') AND + int16_col = arrow_cast('0', 'Int16') AND + uint16_col = arrow_cast('0', 'UInt16') AND + int32_col = arrow_cast('0', 'Int32') AND + uint32_col = arrow_cast('0', 'UInt32') AND + int64_col = arrow_cast('0', 'Int64') AND + uint64_col = arrow_cast('0', 'UInt64') AND + float32_col = arrow_cast('0', 'Float32') AND + float64_col = arrow_cast('0', 'Float64') AND + sec_timestamp_col = arrow_cast('2020-01-01 00:00:00', 'Timestamp (Second, None)') AND + ms_timestamp_col = arrow_cast('2020-01-01 00:00:00', 'Timestamp (Millisecond, None)') AND + us_timestamp_col = arrow_cast('2020-01-01 00:00:00', 'Timestamp (Microsecond, None)') AND + ns_timestamp_col = arrow_cast('2020-01-01 00:00:00', 'Timestamp (Nanosecond, None)') AND + date32_col = arrow_cast('2020-01-01', 'Date32') AND + binary_col = arrow_cast('binary', 'Binary') AND + large_binary_col = arrow_cast('large_binary', 'LargeBinary') AND + utf8_col = arrow_cast('utf8', 'Utf8') AND + large_utf8_col = arrow_cast('large_utf8', 'LargeUtf8');", + ) + .await + } + + /// Construct a plan that casts columns. Only these SQL types are supported for now. 
+ #[tokio::test] + async fn new_test_grammar() -> Result<()> { + roundtrip_all_types( + "select + bool_col::boolean, + int8_col::tinyint, + uint8_col::tinyint unsigned, + int16_col::smallint, + uint16_col::smallint unsigned, + int32_col::integer, + uint32_col::integer unsigned, + int64_col::bigint, + uint64_col::bigint unsigned, + float32_col::float, + float64_col::double, + decimal_128_col::decimal(10, 2), + date32_col::date, + binary_col::bytea + from data", + ) + .await } async fn assert_expected_plan(sql: &str, expected_plan_str: &str) -> Result<()> { @@ -333,6 +391,23 @@ mod tests { Ok(()) } + async fn roundtrip_all_types(sql: &str) -> Result<()> { + let mut ctx = create_all_type_context().await?; + let df = ctx.sql(sql).await?; + let plan = df.into_optimized_plan()?; + let proto = to_substrait_plan(&plan)?; + let plan2 = from_substrait_plan(&mut ctx, &proto).await?; + let plan2 = ctx.state().optimize(&plan2)?; + + println!("{plan:#?}"); + println!("{plan2:#?}"); + + let plan1str = format!("{plan:?}"); + let plan2str = format!("{plan2:?}"); + assert_eq!(plan1str, plan2str); + Ok(()) + } + async fn function_extension_info(sql: &str) -> Result<(Vec, Vec)> { let ctx = create_context().await?; let df = ctx.sql(sql).await?; @@ -373,4 +448,68 @@ mod tests { .await?; Ok(ctx) } + + /// Cover all supported types + async fn create_all_type_context() -> Result { + let ctx = SessionContext::new(); + let mut explicit_options = CsvReadOptions::new(); + let schema = Schema::new(vec![ + Field::new("bool_col", DataType::Boolean, true), + Field::new("int8_col", DataType::Int8, true), + Field::new("uint8_col", DataType::UInt8, true), + Field::new("int16_col", DataType::Int16, true), + Field::new("uint16_col", DataType::UInt16, true), + Field::new("int32_col", DataType::Int32, true), + Field::new("uint32_col", DataType::UInt32, true), + Field::new("int64_col", DataType::Int64, true), + Field::new("uint64_col", DataType::UInt64, true), + Field::new("float32_col", 
DataType::Float32, true), + Field::new("float64_col", DataType::Float64, true), + Field::new( + "sec_timestamp_col", + DataType::Timestamp(TimeUnit::Second, None), + true, + ), + Field::new( + "ms_timestamp_col", + DataType::Timestamp(TimeUnit::Millisecond, None), + true, + ), + Field::new( + "us_timestamp_col", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + Field::new( + "ns_timestamp_col", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + ), + Field::new("date32_col", DataType::Date32, true), + Field::new("date64_col", DataType::Date64, true), + Field::new("binary_col", DataType::Binary, true), + Field::new("large_binary_col", DataType::LargeBinary, true), + Field::new("fixed_size_binary_col", DataType::FixedSizeBinary(42), true), + Field::new("utf8_col", DataType::Utf8, true), + Field::new("large_utf8_col", DataType::LargeUtf8, true), + Field::new( + "list_col", + DataType::List(Box::new(Field::new("item", DataType::Int64, true))), + true, + ), + Field::new( + "large_list_col", + DataType::LargeList(Box::new(Field::new("item", DataType::Int64, true))), + true, + ), + Field::new("decimal_128_col", DataType::Decimal128(10, 2), true), + Field::new("decimal_256_col", DataType::Decimal256(10, 2), true), + ]); + explicit_options.schema = Some(&schema); + explicit_options.has_header = false; + ctx.register_csv("data", "tests/testdata/empty.csv", explicit_options) + .await?; + + Ok(ctx) + } } diff --git a/datafusion/substrait/tests/testdata/empty.csv b/datafusion/substrait/tests/testdata/empty.csv new file mode 100644 index 000000000000..e69de29bb2d1