diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index ba59d113acaab..65b953fd181b7 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -119,6 +119,56 @@ impl PhysicalExpr for TryCastExpr { self.expr.fmt_sql(f)?; write!(f, " AS {:?})", self.cast_type) } + + #[cfg(feature = "proto")] + fn try_to_proto( + &self, + ctx: &datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx<'_>, + ) -> Result> { + use datafusion_proto_models::protobuf; + + Ok(Some(protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some(protobuf::physical_expr_node::ExprType::TryCast(Box::new( + protobuf::PhysicalTryCastNode { + expr: Some(Box::new(ctx.encode_child(&self.expr)?)), + arrow_type: Some(self.cast_type().try_into()?), + }, + ))), + })) + } +} + +#[cfg(feature = "proto")] +impl TryCastExpr { + /// Reconstruct a [`TryCastExpr`] from its protobuf representation. + pub fn try_from_proto( + node: &datafusion_proto_models::protobuf::PhysicalExprNode, + ctx: &datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx<'_>, + ) -> Result> { + use datafusion_physical_expr_common::expect_expr_variant; + use datafusion_physical_expr_common::physical_expr::proto_decode::require_proto_field; + use datafusion_proto_models::protobuf; + + let try_cast = expect_expr_variant!( + node, + protobuf::physical_expr_node::ExprType::TryCast, + "TryCastExpr", + ); + let expr = ctx.decode_required_expression( + try_cast.expr.as_deref(), + "TryCastExpr", + "expr", + )?; + let arrow_type = require_proto_field( + try_cast.arrow_type.as_ref(), + "TryCastExpr", + "arrow_type", + )?; + let cast_type: DataType = arrow_type.try_into()?; + + Ok(Arc::new(TryCastExpr::new(expr, cast_type))) + } } /// Return a PhysicalExpression representing `expr` casted to @@ -593,3 +643,143 @@ mod tests { Ok(()) } } + +#[cfg(all(test, feature = "proto"))] +mod proto_tests { + use super::*; + use crate::expressions::{Column, col}; + use crate::proto_test_util::{ + StubDecoder, StubEncoder, UnreachableDecoder, column_node, + }; + use arrow::datatypes::Field; + use datafusion_common::DataFusionError; + use datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx; + use datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx; + use datafusion_proto_models::datafusion_common::ArrowType; + use datafusion_proto_models::protobuf::{ + PhysicalExprNode, PhysicalTryCastNode, physical_expr_node, + }; + + fn try_cast_fixture() -> TryCastExpr { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); + TryCastExpr::new(col("a", &schema).unwrap(), DataType::Int32) + } + + fn int32_arrow_type() -> ArrowType { + (&DataType::Int32).try_into().unwrap() + } + + fn try_cast_node( + expr: Option>, + arrow_type: Option, + ) -> PhysicalExprNode { + PhysicalExprNode { + expr_id: None, + expr_type: Some(physical_expr_node::ExprType::TryCast(Box::new( + PhysicalTryCastNode { expr, arrow_type }, + ))), + } + } + + #[test] + fn try_to_proto_encodes_try_cast_expr() { + let try_cast = try_cast_fixture(); + let encoder = StubEncoder::ok(); + let ctx = PhysicalExprEncodeCtx::new(&encoder); + + let node = try_cast + .try_to_proto(&ctx) + .unwrap() + .expect("TryCastExpr should encode to Some(node)"); + + assert!(node.expr_id.is_none()); + let try_cast_node = match node.expr_type { + Some(physical_expr_node::ExprType::TryCast(boxed)) => *boxed, + other => panic!("expected a TryCastExpr node, got {other:?}"), + }; + assert!(try_cast_node.expr.is_some()); + + let arrow_type = try_cast_node + .arrow_type + .as_ref() + .expect("try cast type should be encoded"); + let data_type: DataType = arrow_type.try_into().unwrap(); + assert_eq!(data_type, DataType::Int32); + } + + #[test] + fn try_to_proto_propagates_child_encode_error() { + let try_cast = try_cast_fixture(); + let encoder = StubEncoder::failing_on(1); + let ctx = PhysicalExprEncodeCtx::new(&encoder); + let err = try_cast.try_to_proto(&ctx).unwrap_err(); + assert!(matches!(err, DataFusionError::Internal(msg) if msg.contains("call 1"))); + } + + #[test] + fn try_from_proto_decodes_try_cast_expr() { + let node = + try_cast_node(Some(Box::new(column_node("a"))), Some(int32_arrow_type())); + let schema = Schema::empty(); + let decoder = StubDecoder::ok(); + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let decoded = TryCastExpr::try_from_proto(&node, &ctx).unwrap(); + let try_cast = decoded + .downcast_ref::() + .expect("decoded expr should be a TryCastExpr"); + + assert_eq!(try_cast.cast_type(), &DataType::Int32); + assert!(try_cast.expr().downcast_ref::().is_some()); + } + + #[test] + fn try_from_proto_rejects_non_try_cast_node() { + let node = column_node("a"); + let schema = Schema::empty(); + let decoder = UnreachableDecoder; + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let err = TryCastExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!( + matches!(err, DataFusionError::Internal(msg) if msg.contains("PhysicalExprNode is not a TryCastExpr")) + ); + } + + #[test] + fn try_from_proto_rejects_missing_expr() { + let node = try_cast_node(None, Some(int32_arrow_type())); + let schema = Schema::empty(); + let decoder = UnreachableDecoder; + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let err = TryCastExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!( + matches!(err, DataFusionError::Internal(msg) if msg.contains("TryCastExpr is missing required field 'expr'")) + ); + } + + #[test] + fn try_from_proto_rejects_missing_arrow_type() { + let node = try_cast_node(Some(Box::new(column_node("a"))), None); + let schema = Schema::empty(); + let decoder = StubDecoder::ok(); + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let err = TryCastExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!( + matches!(err, DataFusionError::Internal(msg) if msg.contains("TryCastExpr is missing required field 'arrow_type'")) + ); + } + + #[test] + fn try_from_proto_propagates_child_decode_error() { + let node = + try_cast_node(Some(Box::new(column_node("a"))), Some(int32_arrow_type())); + let schema = Schema::empty(); + let decoder = StubDecoder::failing_on(1); + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + let err = TryCastExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!(matches!(err, DataFusionError::Internal(msg) if msg.contains("call 1"))); + } +} diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 8d7c11fb6ab26..5e1e285ecf665 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -347,16 +347,7 @@ pub fn parse_physical_expr_with_converter( .transpose()?, )?), ExprType::Cast(_) => CastExpr::try_from_proto(proto, &decode_ctx)?, - ExprType::TryCast(e) => Arc::new(TryCastExpr::new( - parse_required_physical_expr( - e.expr.as_deref(), - ctx, - "expr", - input_schema, - proto_converter, - )?, - convert_required!(e.arrow_type)?, - )), + ExprType::TryCast(_) => TryCastExpr::try_from_proto(proto, &decode_ctx)?, ExprType::ScalarUdf(e) => { let udf = match &e.fun_definition { Some(buf) => ctx.codec().try_decode_udf(&e.name, buf)?, diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 17d363fa0689f..039c4bc4fea61 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -36,7 +36,7 @@ use datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr; use datafusion_physical_expr::window::{SlidingAggregateWindowExpr, StandardWindowExpr}; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use datafusion_physical_plan::expressions::{ - CaseExpr, DynamicFilterPhysicalExpr, IsNullExpr, Literal, TryCastExpr, + CaseExpr, DynamicFilterPhysicalExpr, IsNullExpr, Literal, }; use datafusion_physical_plan::udaf::AggregateFunctionExpr; use datafusion_physical_plan::windows::{PlainAggregateWindowExpr, WindowUDFExpr}; @@ -363,18 +363,6 @@ pub fn serialize_physical_expr_with_converter( lit.value().try_into()?, )), }) - } else if let Some(cast) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_id, - expr_type: Some(protobuf::physical_expr_node::ExprType::TryCast(Box::new( - protobuf::PhysicalTryCastNode { - expr: Some(Box::new( - proto_converter.physical_expr_to_proto(cast.expr(), codec)?, - )), - arrow_type: Some(cast.cast_type().try_into()?), - }, - ))), - }) } else if let Some(expr) = expr.downcast_ref::() { let mut buf = Vec::new(); codec.try_encode_udf(expr.fun(), &mut buf)?;