Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 4 additions & 44 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1499,7 +1499,7 @@ fn from_substrait_struct_type(
let field = Field::new(
next_struct_field_name(i, dfs_names, name_idx)?,
from_substrait_type(f, dfs_names, name_idx)?,
is_substrait_type_nullable(f)?,
true, // We assume everything to be nullable since that's easier than ensuring it matches
);
fields.push(field);
}
Expand Down Expand Up @@ -1543,47 +1543,6 @@ fn from_substrait_named_struct(base_schema: &NamedStruct) -> Result<DFSchemaRef>
Ok(DFSchemaRef::new(DFSchema::try_from(Schema::new(fields?))?))
}

fn is_substrait_type_nullable(dtype: &Type) -> Result<bool> {
fn is_nullable(nullability: i32) -> bool {
nullability != substrait::proto::r#type::Nullability::Required as i32
}

let nullable = match dtype
.kind
.as_ref()
.ok_or_else(|| substrait_datafusion_err!("Type must contain Kind"))?
{
r#type::Kind::Bool(val) => is_nullable(val.nullability),
r#type::Kind::I8(val) => is_nullable(val.nullability),
r#type::Kind::I16(val) => is_nullable(val.nullability),
r#type::Kind::I32(val) => is_nullable(val.nullability),
r#type::Kind::I64(val) => is_nullable(val.nullability),
r#type::Kind::Fp32(val) => is_nullable(val.nullability),
r#type::Kind::Fp64(val) => is_nullable(val.nullability),
r#type::Kind::String(val) => is_nullable(val.nullability),
r#type::Kind::Binary(val) => is_nullable(val.nullability),
r#type::Kind::Timestamp(val) => is_nullable(val.nullability),
r#type::Kind::Date(val) => is_nullable(val.nullability),
r#type::Kind::Time(val) => is_nullable(val.nullability),
r#type::Kind::IntervalYear(val) => is_nullable(val.nullability),
r#type::Kind::IntervalDay(val) => is_nullable(val.nullability),
r#type::Kind::TimestampTz(val) => is_nullable(val.nullability),
r#type::Kind::Uuid(val) => is_nullable(val.nullability),
r#type::Kind::FixedChar(val) => is_nullable(val.nullability),
r#type::Kind::Varchar(val) => is_nullable(val.nullability),
r#type::Kind::FixedBinary(val) => is_nullable(val.nullability),
r#type::Kind::Decimal(val) => is_nullable(val.nullability),
r#type::Kind::PrecisionTimestamp(val) => is_nullable(val.nullability),
r#type::Kind::PrecisionTimestampTz(val) => is_nullable(val.nullability),
r#type::Kind::Struct(val) => is_nullable(val.nullability),
r#type::Kind::List(val) => is_nullable(val.nullability),
r#type::Kind::Map(val) => is_nullable(val.nullability),
r#type::Kind::UserDefined(val) => is_nullable(val.nullability),
r#type::Kind::UserDefinedTypeReference(_) => true, // not implemented, assume nullable
};
Ok(nullable)
}

fn from_substrait_bound(
bound: &Option<Bound>,
is_lower: bool,
Expand Down Expand Up @@ -1763,8 +1722,9 @@ fn from_substrait_literal(
for (i, field) in s.fields.iter().enumerate() {
let name = next_struct_field_name(i, dfs_names, name_idx)?;
let sv = from_substrait_literal(field, dfs_names, name_idx)?;
builder = builder
.with_scalar(Field::new(name, sv.data_type(), field.nullable), sv);
// We assume everything to be nullable, since Arrow's strict about things matching
// and it's hard to match otherwise.
builder = builder.with_scalar(Field::new(name, sv.data_type(), true), sv);
}
builder.build()?
}
Expand Down
6 changes: 3 additions & 3 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2258,8 +2258,8 @@ mod test {
),
)))?;

let c0 = Field::new("c0", DataType::Boolean, false);
let c1 = Field::new("c1", DataType::Int32, false);
let c0 = Field::new("c0", DataType::Boolean, true);
let c1 = Field::new("c1", DataType::Int32, true);
let c2 = Field::new("c2", DataType::Utf8, true);
round_trip_literal(
ScalarStructBuilder::new()
Expand Down Expand Up @@ -2319,7 +2319,7 @@ mod test {
round_trip_type(DataType::Struct(
vec![
Field::new("c0", DataType::Int32, true),
Field::new("c1", DataType::Utf8, false),
Field::new("c1", DataType::Utf8, true),
]
.into(),
))?;
Expand Down