diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index fe5d981ece70..9504098db674 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -1499,7 +1499,7 @@ fn from_substrait_struct_type( let field = Field::new( next_struct_field_name(i, dfs_names, name_idx)?, from_substrait_type(f, dfs_names, name_idx)?, - is_substrait_type_nullable(f)?, + true, // We assume everything to be nullable since that's easier than ensuring it matches ); fields.push(field); } @@ -1543,47 +1543,6 @@ fn from_substrait_named_struct(base_schema: &NamedStruct) -> Result Ok(DFSchemaRef::new(DFSchema::try_from(Schema::new(fields?))?)) } -fn is_substrait_type_nullable(dtype: &Type) -> Result { - fn is_nullable(nullability: i32) -> bool { - nullability != substrait::proto::r#type::Nullability::Required as i32 - } - - let nullable = match dtype - .kind - .as_ref() - .ok_or_else(|| substrait_datafusion_err!("Type must contain Kind"))? - { - r#type::Kind::Bool(val) => is_nullable(val.nullability), - r#type::Kind::I8(val) => is_nullable(val.nullability), - r#type::Kind::I16(val) => is_nullable(val.nullability), - r#type::Kind::I32(val) => is_nullable(val.nullability), - r#type::Kind::I64(val) => is_nullable(val.nullability), - r#type::Kind::Fp32(val) => is_nullable(val.nullability), - r#type::Kind::Fp64(val) => is_nullable(val.nullability), - r#type::Kind::String(val) => is_nullable(val.nullability), - r#type::Kind::Binary(val) => is_nullable(val.nullability), - r#type::Kind::Timestamp(val) => is_nullable(val.nullability), - r#type::Kind::Date(val) => is_nullable(val.nullability), - r#type::Kind::Time(val) => is_nullable(val.nullability), - r#type::Kind::IntervalYear(val) => is_nullable(val.nullability), - r#type::Kind::IntervalDay(val) => is_nullable(val.nullability), - r#type::Kind::TimestampTz(val) => is_nullable(val.nullability), - r#type::Kind::Uuid(val) => is_nullable(val.nullability), - r#type::Kind::FixedChar(val) => is_nullable(val.nullability), - r#type::Kind::Varchar(val) => is_nullable(val.nullability), - r#type::Kind::FixedBinary(val) => is_nullable(val.nullability), - r#type::Kind::Decimal(val) => is_nullable(val.nullability), - r#type::Kind::PrecisionTimestamp(val) => is_nullable(val.nullability), - r#type::Kind::PrecisionTimestampTz(val) => is_nullable(val.nullability), - r#type::Kind::Struct(val) => is_nullable(val.nullability), - r#type::Kind::List(val) => is_nullable(val.nullability), - r#type::Kind::Map(val) => is_nullable(val.nullability), - r#type::Kind::UserDefined(val) => is_nullable(val.nullability), - r#type::Kind::UserDefinedTypeReference(_) => true, // not implemented, assume nullable - }; - Ok(nullable) -} - fn from_substrait_bound( bound: &Option, is_lower: bool, @@ -1763,8 +1722,9 @@ fn from_substrait_literal( for (i, field) in s.fields.iter().enumerate() { let name = next_struct_field_name(i, dfs_names, name_idx)?; let sv = from_substrait_literal(field, dfs_names, name_idx)?; - builder = builder - .with_scalar(Field::new(name, sv.data_type(), field.nullable), sv); + // We assume everything to be nullable, since Arrow's strict about things matching + // and it's hard to match otherwise. + builder = builder.with_scalar(Field::new(name, sv.data_type(), true), sv); } builder.build()? } diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index c0469d333164..4ee24d868dc1 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -2258,8 +2258,8 @@ mod test { ), )))?; - let c0 = Field::new("c0", DataType::Boolean, false); - let c1 = Field::new("c1", DataType::Int32, false); + let c0 = Field::new("c0", DataType::Boolean, true); + let c1 = Field::new("c1", DataType::Int32, true); let c2 = Field::new("c2", DataType::Utf8, true); round_trip_literal( ScalarStructBuilder::new() @@ -2319,7 +2319,7 @@ mod test { round_trip_type(DataType::Struct( vec![ Field::new("c0", DataType::Int32, true), - Field::new("c1", DataType::Utf8, false), + Field::new("c1", DataType::Utf8, true), ] .into(), ))?;