From 50a5bc5a74b27e4c58dbfc03ba382cf5618ddcfc Mon Sep 17 00:00:00 2001 From: byteink Date: Thu, 11 May 2023 16:06:39 +0800 Subject: [PATCH] Fix case evaluation with NULL --- .../physical-expr/src/expressions/case.rs | 35 +++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 2d97d57324fe..72783145a04d 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -23,7 +23,7 @@ use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; use arrow::array::*; use arrow::compute::kernels::zip::zip; -use arrow::compute::{and, eq_dyn, is_null, not, or, or_kleene}; +use arrow::compute::{and, and_kleene, eq_dyn, is_not_null, is_null, not, or, or_kleene}; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::{cast::as_boolean_array, DataFusionError, Result}; @@ -138,6 +138,8 @@ impl CaseExpr { let when_value = when_value.into_array(batch.num_rows()); // build boolean array representing which rows match the "when" value let when_match = eq_dyn(&when_value, base_value.as_ref())?; + // Treat nulls as false + let when_match = and_kleene(&when_match, &is_not_null(&when_match)?)?; let then_value = self.when_then_expr[i] .1 @@ -152,7 +154,7 @@ impl CaseExpr { current_value = zip(&when_match, then_value.as_ref(), current_value.as_ref())?; - remainder = and(&remainder, &or_kleene(¬(&when_match)?, &base_nulls)?)?; + remainder = and(&remainder, ¬(&when_match)?)?; } if let Some(e) = &self.else_expr { @@ -526,6 +528,35 @@ mod tests { Ok(()) } + #[test] + fn case_with_expr_when_null() -> Result<()> { + let batch = case_test_batch()?; + let schema = batch.schema(); + + // CASE a WHEN NULL THEN 0 WHEN a THEN 123 ELSE 999 END + let when1 = lit(ScalarValue::Utf8(None)); + let then1 = lit(0i32); + let when2 = col("a", &schema)?; + let then2 = lit(123i32); + let else_value = lit(999i32); + + let expr = generate_case_when_with_type_coercion( + Some(col("a", &schema)?), + vec![(when1, then1), (when2, then2)], + Some(else_value), + schema.as_ref(), + )?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = as_int32_array(&result)?; + + let expected = + &Int32Array::from(vec![Some(123), Some(123), Some(999), Some(123)]); + + assert_eq!(expected, result); + + Ok(()) + } + #[test] fn case_without_expr_divide_by_zero() -> Result<()> { let batch = case_test_batch1()?;