diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs index 91ef29ca0573..39544dca9485 100644 --- a/datafusion/proto/src/from_proto.rs +++ b/datafusion/proto/src/from_proto.rs @@ -627,9 +627,7 @@ impl TryFrom<&protobuf::PrimitiveScalarType> for ScalarValue { use protobuf::PrimitiveScalarType; Ok(match scalar { - PrimitiveScalarType::Null => { - return Err(proto_error("Untyped null is an invalid scalar value")); - } + PrimitiveScalarType::Null => Self::Null, PrimitiveScalarType::Bool => Self::Boolean(None), PrimitiveScalarType::Uint8 => Self::UInt8(None), PrimitiveScalarType::Int8 => Self::Int8(None), diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index c843de630289..1a8d6b0f4a26 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -851,6 +851,26 @@ mod roundtrip_tests { roundtrip_expr_test(test_expr, ctx); } + #[test] + fn roundtrip_case_with_null() { + let test_expr = Expr::Case { + expr: Some(Box::new(lit(1.0_f32))), + when_then_expr: vec![(Box::new(lit(2.0_f32)), Box::new(lit(3.0_f32)))], + else_expr: Some(Box::new(Expr::Literal(ScalarValue::Null))), + }; + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + } + + #[test] + fn roundtrip_null_literal() { + let test_expr = Expr::Literal(ScalarValue::Null); + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + } + #[test] fn roundtrip_cast() { let test_expr = Expr::Cast { diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs index 387582af939a..b98e73a74009 100644 --- a/datafusion/proto/src/to_proto.rs +++ b/datafusion/proto/src/to_proto.rs @@ -1096,6 +1096,9 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { Value::IntervalDaytimeValue(*s) }) } + ScalarValue::Null => protobuf::ScalarValue { + value: Some(Value::NullValue(PrimitiveScalarType::Null as i32)), + }, _ => { return Err(Error::invalid_scalar_value(val)); }