diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index ddb0a98df537..a3c181368d5f 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -866,12 +866,13 @@ impl CaseBody { // Since each when expression is tested against the base expression using the equality // operator, null base values can never match any when expression. `x = NULL` is falsy, // for all possible values of `x`. - if base_values.null_count() > 0 { + let base_null_count = base_values.logical_null_count(); + if base_null_count > 0 { // Use `is_not_null` since this is a cheap clone of the null buffer from 'base_value'. // We already checked there are nulls, so we can be sure a new buffer will not be // created. let base_not_nulls = is_not_null(base_values.as_ref())?; - let base_all_null = base_values.null_count() == remainder_batch.num_rows(); + let base_all_null = base_null_count == remainder_batch.num_rows(); // If there is an else expression, use that as the default value for the null rows // Otherwise the default `null` value from the result builder will be used. @@ -1545,6 +1546,84 @@ mod tests { Ok(()) } + #[test] + fn case_with_expr_dictionary() -> Result<()> { + let schema = Schema::new(vec![Field::new( + "a", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), + true, + )]); + let keys = UInt8Array::from(vec![0u8, 1u8, 2u8, 3u8]); + let values = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]); + let dictionary = DictionaryArray::new(keys, Arc::new(values)); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dictionary)])?; + + let schema = batch.schema(); + + // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END + let when1 = lit("foo"); + let then1 = lit(123i32); + let when2 = lit("bar"); + let then2 = lit(456i32); + + let expr = generate_case_when_with_type_coercion( + Some(col("a", &schema)?), + vec![(when1, then1), (when2, then2)], + None, + schema.as_ref(), + )?; + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = as_int32_array(&result)?; + + let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]); + + assert_eq!(expected, result); + + Ok(()) + } + + #[test] + fn case_with_expr_all_null_dictionary() -> Result<()> { + let schema = Schema::new(vec![Field::new( + "a", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), + true, + )]); + let keys = UInt8Array::from(vec![2u8, 2u8, 2u8, 2u8]); + let values = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]); + let dictionary = DictionaryArray::new(keys, Arc::new(values)); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dictionary)])?; + + let schema = batch.schema(); + + // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END + let when1 = lit("foo"); + let then1 = lit(123i32); + let when2 = lit("bar"); + let then2 = lit(456i32); + + let expr = generate_case_when_with_type_coercion( + Some(col("a", &schema)?), + vec![(when1, then1), (when2, then2)], + None, + schema.as_ref(), + )?; + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = as_int32_array(&result)?; + + let expected = &Int32Array::from(vec![None, None, None, None]); + + assert_eq!(expected, result); + + Ok(()) + } + #[test] fn case_with_expr_else() -> Result<()> { let batch = case_test_batch()?;