diff --git a/Cargo.lock b/Cargo.lock index 012573deb452d..d569a4471d5fe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2335,6 +2335,7 @@ name = "datafusion-physical-expr" version = "53.1.0" dependencies = [ "arrow", + "arrow-schema", "criterion", "datafusion-common", "datafusion-expr", diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index c989bab3048ad..7a4e472b6dd56 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -31,6 +31,7 @@ use crate::{LogicalPlan, Projection, Subquery, WindowFunctionDefinition, utils}; use arrow::compute::can_cast_types; use arrow::datatypes::FieldRef; use arrow::datatypes::{DataType, Field}; +use arrow_schema::extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY}; use datafusion_common::datatype::FieldExt; use datafusion_common::{ Column, DataFusionError, ExprSchema, Result, ScalarValue, Spans, TableReference, @@ -73,17 +74,35 @@ pub trait ExprSchemable { /// For `TryCast`, `force_nullable` is `true` since a failed cast returns NULL. fn cast_output_field( source_field: &FieldRef, - target_type: &DataType, + target_field: &FieldRef, force_nullable: bool, ) -> Arc<Field> { let mut f = source_field .as_ref() .clone() - .with_data_type(target_type.clone()) + .with_data_type(target_field.data_type().clone()) .with_metadata(source_field.metadata().clone()); + + // Extension type information is never propagated from the source field + // through a cast because there is no guarantee the output data type + // is a valid storage type for the extension. + f.metadata_mut().remove(EXTENSION_TYPE_NAME_KEY); + f.metadata_mut().remove(EXTENSION_TYPE_METADATA_KEY); + + // Metadata from target field overrides metadata from the source field. + // In most cases the target field will not have any metadata (created from + // a DataType), in which case this does nothing. 
Where the target field + // represents an extension type or includes other type-like metadata, + // this allows an optimizer rule or planner to insert the appropriate + // behaviour. + for (k, v) in target_field.metadata() { + f.metadata_mut().insert(k.clone(), v.clone()); + } + if force_nullable { f = f.with_nullable(true); } + Arc::new(f) } @@ -594,21 +613,16 @@ impl ExprSchemable for Expr { func.return_field_from_args(args) } - // _ => Ok((self.get_type(schema)?, self.nullable(schema)?)), - Expr::Cast(Cast { expr, field }) => { - expr.to_field(schema).map(|(_table_ref, src)| { - cast_output_field(&src, field.data_type(), false) - }) - } + Expr::Cast(Cast { expr, field }) => expr + .to_field(schema) + .map(|(_table_ref, src)| cast_output_field(&src, field, false)), Expr::Placeholder(Placeholder { id: _, field: Some(field), }) => Ok(Arc::clone(field).renamed(&schema_name)), - Expr::TryCast(TryCast { expr, field }) => { - expr.to_field(schema).map(|(_table_ref, src)| { - cast_output_field(&src, field.data_type(), true) - }) - } + Expr::TryCast(TryCast { expr, field }) => expr + .to_field(schema) + .map(|(_table_ref, src)| cast_output_field(&src, field, true)), Expr::LambdaVariable(LambdaVariable { field: Some(field), .. 
}) => Ok(Arc::clone(field).renamed(&schema_name)), @@ -1044,6 +1058,7 @@ mod tests { .with_metadata(meta.clone()); // col, alias, and cast should be metadata-preserving + // (cast: extension-type keys are dropped; target-field metadata overrides) assert_eq!(meta, expr.metadata(&schema).unwrap()); assert_eq!(meta, expr.clone().alias("bar").metadata(&schema).unwrap()); assert_eq!( diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index b755353d75658..2f071e59e9615 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -45,6 +45,7 @@ recursive_protection = ["dep:recursive"] [dependencies] arrow = { workspace = true } +arrow-schema = { workspace = true } datafusion-common = { workspace = true } datafusion-expr = { workspace = true } datafusion-expr-common = { workspace = true } diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index ad214a89ceb71..90d534f6ccfc7 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -24,8 +24,10 @@ use crate::physical_expr::PhysicalExpr; use arrow::compute::{CastOptions, can_cast_types}; use arrow::datatypes::{DataType, DataType::*, FieldRef, Schema}; use arrow::record_batch::RecordBatch; +use arrow_schema::extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY}; use datafusion_common::datatype::DataTypeExt; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; +use datafusion_common::metadata::format_type_and_metadata; use datafusion_common::nested_struct::{ requires_nested_struct_cast, validate_data_type_compatibility, }; @@ -151,17 +153,26 @@ impl CastExpr { } fn resolved_target_field(&self, input_schema: &Schema) -> Result<FieldRef> { + let output_nullability = self.nullable(input_schema)?; if is_default_target_field(&self.target_field) { self.expr.return_field(input_schema).map(|field| { - Arc::new( - field - .as_ref() - .clone() - .with_data_type(self.cast_type().clone()), - ) + let cast_type = 
self.cast_type(); + let mut out_field = + field.as_ref().clone().with_data_type(cast_type.clone()); + + // Extension type information is never propagated from the source field + out_field.metadata_mut().remove(EXTENSION_TYPE_NAME_KEY); + out_field.metadata_mut().remove(EXTENSION_TYPE_METADATA_KEY); + + Arc::new(out_field.with_nullable(output_nullability)) }) } else { - Ok(Arc::clone(&self.target_field)) + Ok(Arc::new( + self.target_field + .as_ref() + .clone() + .with_nullable(output_nullability), + )) } } @@ -233,14 +244,8 @@ impl PhysicalExpr for CastExpr { } fn nullable(&self, input_schema: &Schema) -> Result<bool> { - // A cast is nullable if **either** the child is nullable or the - // target field allows nulls. This conservative rule prevents - // optimizers from assuming a non-null result when a null input could - // still propagate. `return_field()` continues to expose the exact - // target metadata separately. - let child_nullable = self.expr.nullable(input_schema)?; - let target_nullable = self.resolved_target_field(input_schema)?.is_nullable(); - Ok(child_nullable || target_nullable) + // Casts do not change the nullability of the input + self.expr.nullable(input_schema) } fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> { @@ -332,25 +337,38 @@ pub fn cast_with_target_field( target_field: FieldRef, cast_options: Option<CastOptions<'static>>, ) -> Result<Arc<dyn PhysicalExpr>> { - let expr_type = expr.data_type(input_schema)?; + let expr_field = expr.return_field(input_schema)?; + let expr_type = expr_field.data_type(); let cast_type = target_field.data_type(); - if expr_type == *cast_type && is_default_target_field(&target_field) { + if expr_type == cast_type && is_default_target_field(&target_field) { return Ok(Arc::clone(&expr)); } - let can_build_cast = if requires_nested_struct_cast(&expr_type, cast_type) { + let can_build_cast = if target_field.extension_type_name().is_some() { + // Disallow casts to an extension type because we do not yet have a + // mechanism to ensure the target type will be 
valid. We allow a cast + // from an extension type (which casts the storage and discards the + // extension information) for backward compatibility. + false + } else if requires_nested_struct_cast(expr_type, cast_type) { // Allow casts involving structs (including nested inside Lists, Dictionaries, // etc.) that pass name-based compatibility validation. This validation is // applied at planning time (now) to fail fast, rather than deferring errors // to execution time. The name-based casting logic will be executed at runtime // via ColumnarValue::cast_to. - can_cast_named_struct_types(&expr_type, cast_type) + can_cast_named_struct_types(expr_type, cast_type) } else { - can_cast_types(&expr_type, cast_type) + can_cast_types(expr_type, cast_type) }; if !can_build_cast { - return not_impl_err!("Unsupported CAST from {expr_type} to {cast_type}"); + let expr_type_disp = + format_type_and_metadata(expr_type, Some(expr_field.metadata())); + let cast_type_disp = + format_type_and_metadata(cast_type, Some(target_field.metadata())); + return not_impl_err!( + "Unsupported CAST from {expr_type_disp} to {cast_type_disp}" + ); } Ok(Arc::new(CastExpr::new_with_target_field( @@ -944,12 +962,12 @@ mod tests { let field = expr.return_field(&schema)?; assert_eq!(field.name(), "cast_target"); assert_eq!(field.data_type(), &Int64); - assert_eq!(field.is_nullable(), target_nullable); + assert_eq!(field.is_nullable(), child_nullable); assert_eq!( field.metadata().get("target_meta").map(String::as_str), Some("1") ); - assert_eq!(expr.nullable(&schema)?, child_nullable || target_nullable); + assert_eq!(expr.nullable(&schema)?, child_nullable); } Ok(()) diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index d0d0508a106a5..d66427189905c 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -294,17 +294,25 @@ pub fn create_physical_expr( }; Ok(expressions::case(expr, when_then_expr, else_expr)?) 
} - Expr::Cast(Cast { expr, field }) => expressions::cast_with_target_field( - create_physical_expr(expr, input_dfschema, execution_props)?, - input_schema, - Arc::clone(field), - None, - ), + Expr::Cast(Cast { expr, .. }) => { + // The output field is calculated using a combination of the input + // and target fields. Use the calculated target field when forming the cast + // so that the logical and physical schemas align. + let (_, resolved_field) = e.to_field(input_dfschema)?; + expressions::cast_with_target_field( + create_physical_expr(expr, input_dfschema, execution_props)?, + input_schema, + Arc::clone(&resolved_field), + None, + ) + } Expr::TryCast(TryCast { expr, field }) => { + // The physical try_cast does not support a target field, + // so error if the target field carries metadata that would be dropped. if !field.metadata().is_empty() { let (_, src_field) = expr.to_field(input_dfschema)?; return plan_err!( - "TryCast from {} to {} is not supported", + "Unsupported TRY_CAST from {} to {}", format_type_and_metadata( src_field.data_type(), Some(src_field.metadata()), @@ -657,9 +665,13 @@ mod tests { let physical = lower_cast_expr(&cast_expr, &schema)?; let cast = as_planner_cast(&physical); - assert_eq!(cast.target_field(), &target_field); - assert_eq!(physical.return_field(&schema)?, target_field); - assert!(physical.nullable(&schema)?); + let (_, logical_field) = + cast_expr.to_field(&DFSchema::try_from(schema.clone())?)?; + assert_eq!(cast.target_field(), &logical_field); + assert_eq!(physical.return_field(&schema)?, logical_field); + + // Like other casts, nullability from the input is preserved + assert!(!physical.nullable(&schema)?); Ok(()) } @@ -676,6 +688,8 @@ mod tests { assert_eq!(cast.cast_type(), &DataType::Int64); assert_eq!(returned_field.name(), "a"); assert_eq!(returned_field.data_type(), &DataType::Int64); + + // Ensure a cast to a DataType preserves the nullability of the input assert!(!physical.nullable(&schema)?); Ok(()) @@ -697,9 
+711,17 @@ mod tests { let physical = lower_cast_expr(&cast_expr, &schema)?; let cast = as_planner_cast(&physical); - assert_eq!(cast.target_field(), &target_field); - assert_eq!(physical.return_field(&schema)?, target_field); - assert!(physical.nullable(&schema)?); + let (_, logical_field) = + cast_expr.to_field(&DFSchema::try_from(schema.clone())?)?; + + // The logical output field preserves the nullability of the input + assert!(!logical_field.is_nullable()); + + assert_eq!(cast.target_field(), &logical_field); + assert_eq!(physical.return_field(&schema)?, logical_field); + + // The physical expression reports the same nullability as the logical field + assert!(!physical.nullable(&schema)?); Ok(()) } diff --git a/datafusion/sqllogictest/test_files/cast_extension_type_metadata.slt b/datafusion/sqllogictest/test_files/cast_extension_type_metadata.slt index 425d8ac16eaee..51f040cc0af2b 100644 --- a/datafusion/sqllogictest/test_files/cast_extension_type_metadata.slt +++ b/datafusion/sqllogictest/test_files/cast_extension_type_metadata.slt @@ -17,33 +17,8 @@ # Regression tests for logical CAST targets that carry explicit field metadata. 
-query ?T -SELECT - CAST( - arrow_cast(X'00010203040506070809000102030506', 'FixedSizeBinary(16)') - AS UUID - ), - arrow_metadata( - CAST( - arrow_cast(X'00010203040506070809000102030506', 'FixedSizeBinary(16)') - AS UUID - ), - 'ARROW:extension:name' - ); ----- -00010203040506070809000102030506 arrow.uuid +statement error DataFusion error: Optimizer rule 'simplify_expressions' failed[\s\S]*Unsupported CAST from FixedSizeBinary\(16\) to FixedSizeBinary\(16\)<\{"ARROW:extension:name": "arrow\.uuid"\}> +SELECT CAST(arrow_cast(X'00010203040506070809000102030506', 'FixedSizeBinary(16)') AS UUID); -query ?T -SELECT - CAST(raw AS UUID), - arrow_metadata(CAST(raw AS UUID), 'ARROW:extension:name') -FROM ( - VALUES ( - arrow_cast(X'00010203040506070809000102030506', 'FixedSizeBinary(16)') - ) -) AS uuids(raw); ----- -00010203040506070809000102030506 arrow.uuid - -statement error DataFusion error: Optimizer rule 'simplify_expressions' failed[\s\S]*TryCast from FixedSizeBinary\(16\) to FixedSizeBinary\(16\)<\{"ARROW:extension:name": "arrow\.uuid"\}> is not supported +statement error DataFusion error: Optimizer rule 'simplify_expressions' failed[\s\S]*Unsupported TRY_CAST from FixedSizeBinary\(16\) to FixedSizeBinary\(16\)<\{"ARROW:extension:name": "arrow\.uuid"\}> SELECT TRY_CAST(arrow_cast(X'00010203040506070809000102030506', 'FixedSizeBinary(16)') AS UUID);