diff --git a/datafusion/core/tests/sql/projection.rs b/datafusion/core/tests/sql/projection.rs index a964acae4424..363e96c364c6 100644 --- a/datafusion/core/tests/sql/projection.rs +++ b/datafusion/core/tests/sql/projection.rs @@ -17,6 +17,7 @@ use datafusion::logical_plan::{provider_as_source, LogicalPlanBuilder, UNNAMED_TABLE}; use datafusion::test_util::scan_empty; +use datafusion_expr::when; use tempfile::TempDir; use super::*; @@ -220,6 +221,48 @@ async fn preserve_nullability_on_projection() -> Result<()> { Ok(()) } +#[tokio::test] +async fn project_cast_dictionary() { + let ctx = SessionContext::new(); + + let host: DictionaryArray = vec![Some("host1"), None, Some("host2")] + .into_iter() + .collect(); + + let batch = RecordBatch::try_from_iter(vec![("host", Arc::new(host) as _)]).unwrap(); + + let t = MemTable::try_new(batch.schema(), vec![vec![batch]]).unwrap(); + + // Note that `host` is a dictionary array but `lit("")` is a DataType::Utf8 that needs to be cast + let expr = when(col("host").is_null(), lit("")) + .otherwise(col("host")) + .unwrap(); + + let projection = None; + let builder = LogicalPlanBuilder::scan( + "cpu_load_short", + provider_as_source(Arc::new(t)), + projection, + ) + .unwrap(); + + let logical_plan = builder.project(vec![expr]).unwrap().build().unwrap(); + + let physical_plan = ctx.create_physical_plan(&logical_plan).await.unwrap(); + let actual = collect(physical_plan, ctx.task_ctx()).await.unwrap(); + + let expected = vec![ + "+------------------------------------------------------------------------------------+", + "| CASE WHEN #cpu_load_short.host IS NULL THEN Utf8(\"\") ELSE #cpu_load_short.host END |", + "+------------------------------------------------------------------------------------+", + "| host1 |", + "| |", + "| host2 |", + "+------------------------------------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &actual); +} + #[tokio::test] async fn projection_on_memory_scan() -> Result<()> { let schema = Schema::new(vec![ diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index 2509c1d6b131..813c641362ab 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -82,6 +82,19 @@ impl PhysicalExpr for TryCastExpr { &array, &self.cast_type, )?)), + ColumnarValue::Scalar(scalar) + if matches!(self.cast_type, DataType::Dictionary(_, _)) => + { + // ScalarValues do not preserve dictionary encoding + // (so they don't survive the round trip), + // https://github.com/apache/arrow-datafusion/issues/2874 + // Until that is fixed, "unpack" the ColumnarValue here + let array = scalar.to_array_of_size(batch.num_rows()); + Ok(ColumnarValue::Array(kernels::cast::cast( + &array, + &self.cast_type, + )?)) + } ColumnarValue::Scalar(scalar) => { let scalar_array = scalar.to_array(); let cast_array = kernels::cast::cast(&scalar_array, &self.cast_type)?; @@ -119,8 +132,8 @@ mod tests { use super::*; use crate::expressions::col; use arrow::array::{ - BasicDecimalArray, DecimalArray, DecimalBuilder, StringArray, - Time64NanosecondArray, + as_string_array, BasicDecimalArray, DecimalArray, DecimalBuilder, + DictionaryArray, StringArray, Time64NanosecondArray, }; use arrow::util::decimal::{BasicDecimal, Decimal128}; use arrow::{ @@ -185,10 +198,13 @@ mod tests { // 3. evaluate the expression // 4. verify that the resulting expression is of type B // 5. verify that the resulting values are downcastable and correct + // + // $VALUE_FN is an expression (like `result.value`) that extracts the value at index `i` macro_rules! generic_test_cast { - ($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr) => {{ + ($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr, $VALUE_FN:expr) => {{ let schema = Schema::new(vec![Field::new("a", $A_TYPE, true)]); - let a = $A_ARRAY::from($A_VEC); + let a = $A_ARRAY::from_iter($A_VEC); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; @@ -222,7 +238,7 @@ mod tests { // verify that the result itself is correct for (i, x) in $VEC.iter().enumerate() { match x { - Some(x) => assert_eq!(result.value(i), *x), + Some(x) => assert_eq!($VALUE_FN(result, i), *x), None => assert!(!result.is_valid(i)), } } @@ -396,7 +412,8 @@ mod tests { Some(convert(3)), Some(convert(4)), Some(convert(5)), - ] + ], + |result: &DecimalArray, i| result.value(i) ); // int16 @@ -413,7 +430,8 @@ mod tests { Some(convert(3)), Some(convert(4)), Some(convert(5)), - ] + ], + |result: &DecimalArray, i| result.value(i) ); // int32 @@ -430,7 +448,8 @@ mod tests { Some(convert(3)), Some(convert(4)), Some(convert(5)), - ] + ], + |result: &DecimalArray, i| result.value(i) ); // int64 @@ -447,7 +466,8 @@ mod tests { Some(convert(3)), Some(convert(4)), Some(convert(5)), - ] + ], + |result: &DecimalArray, i| result.value(i) ); // int64 to different scale @@ -464,7 +484,8 @@ mod tests { Some(convert(300)), Some(convert(400)), Some(convert(500)), - ] + ], + |result: &DecimalArray, i| result.value(i) ); // float32 @@ -481,7 +502,8 @@ mod tests { Some(convert(300)), Some(convert(112)), Some(convert(550)), - ] + ], + |result: &DecimalArray, i| result.value(i) ); // float64 @@ -498,7 +520,8 @@ mod tests { Some(convert(30000)), Some(convert(11234)), Some(convert(55000)), - ] + ], + |result: &DecimalArray, i| result.value(i) ); Ok(()) } @@ -517,7 +540,8 @@ mod tests { Some(3_u32), Some(4_u32), Some(5_u32) - ] + ], + |result: &UInt32Array, i| result.value(i) ); Ok(()) } @@ -530,7 +554,8 @@ mod tests { vec![1, 2, 3, 4, 5], StringArray, DataType::Utf8, - vec![Some("1"), Some("2"), Some("3"), Some("4"), Some("5")] + vec![Some("1"), Some("2"), Some("3"), Some("4"), Some("5")], + |result: &StringArray, i| result.value(i).to_string() ); Ok(()) } @@ -540,10 +565,58 @@ mod tests { generic_test_cast!( StringArray, DataType::Utf8, - vec!["a", "2", "3", "b", "5"], + vec![Some("a"), Some("2"), Some("3"), Some("b"), Some("5")], Int32Array, DataType::Int32, - vec![None, Some(2), Some(3), None, Some(5)] + vec![None, Some(2), Some(3), None, Some(5)], + |result: &Int32Array, i| result.value(i) + ); + Ok(()) + } + + #[test] + fn test_try_cast_string_dict_to_utf8() -> Result<()> { + let dict_type = DataType::Dictionary( + Box::new(DataType::Int32), // key_type + Box::new(DataType::Utf8), // value_type + ); + + // define a type alias so we can use the macro + type DictArrayType = DictionaryArray; + + generic_test_cast!( + DictArrayType, + dict_type, + vec![Some("a"), Some("b")], + StringArray, + DataType::Utf8, + vec![Some("a"), Some("b")], + |result: &StringArray, i| result.value(i).to_string() + ); + Ok(()) + } + + #[allow(clippy::redundant_clone)] + #[test] + fn test_try_cast_utf8_to_string_dict() -> Result<()> { + let dict_type = DataType::Dictionary( + Box::new(DataType::Int32), // key_type + Box::new(DataType::Utf8), // value_type + ); + + // define a type alias so we can use the macro + type DictArrayType = DictionaryArray; + + generic_test_cast!( + StringArray, + DataType::Utf8, + vec![Some("a"), Some("b")], + DictArrayType, + dict_type.clone(), + vec![Some("a"), Some("b")], + |result: &DictArrayType, i| { + as_string_array(result.values()).value(i).to_string() + } ); Ok(()) } @@ -562,7 +635,8 @@ mod tests { original.clone(), TimestampNanosecondArray, DataType::Timestamp(TimeUnit::Nanosecond, None), - expected + expected, + |result: &TimestampNanosecondArray, i| result.value(i) ); Ok(()) }