diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 777ab0d9013e..ba02aeb6703a 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -101,6 +101,8 @@ pub enum ScalarValue { IntervalMonthDayNano(Option), /// struct of nested ScalarValue Struct(Option>, Box>), + /// Dictionary type: index type and value + Dictionary(Box, Box), } // manual implementation of `PartialEq` that uses OrderedFloat to @@ -176,6 +178,8 @@ impl PartialEq for ScalarValue { (IntervalMonthDayNano(_), _) => false, (Struct(v1, t1), Struct(v2, t2)) => v1.eq(v2) && t1.eq(t2), (Struct(_, _), _) => false, + (Dictionary(k1, v1), Dictionary(k2, v2)) => k1.eq(k2) && v1.eq(v2), + (Dictionary(_, _), _) => false, (Null, Null) => true, (Null, _) => false, } @@ -278,6 +282,15 @@ impl PartialOrd for ScalarValue { } } (Struct(_, _), _) => None, + (Dictionary(k1, v1), Dictionary(k2, v2)) => { + // Don't compare if the key types don't match (it is effectively a different datatype) + if k1 == k2 { + v1.partial_cmp(v2) + } else { + None + } + } + (Dictionary(_, _), _) => None, (Null, Null) => Some(Ordering::Equal), (Null, _) => None, } @@ -335,35 +348,85 @@ impl std::hash::Hash for ScalarValue { v.hash(state); t.hash(state); } + Dictionary(k, v) => { + k.hash(state); + v.hash(state); + } // stable hash for Null value Null => 1.hash(state), } } } -// return the index into the dictionary values for array@index as well -// as a reference to the dictionary values array. Returns None for the -// index if the array is NULL at index +/// return a reference to the values array and the index into it for a +/// dictionary array #[inline] fn get_dict_value( array: &ArrayRef, index: usize, -) -> Result<(&ArrayRef, Option)> { - let dict_array = array.as_any().downcast_ref::>().unwrap(); +) -> (&ArrayRef, Option) { + let dict_array = as_dictionary_array::(array); + (dict_array.values(), dict_array.key(index)) +} - // look up the index in the values dictionary - let keys_col = dict_array.keys(); - if !keys_col.is_valid(index) { - return Ok((dict_array.values(), None)); - } - let values_index = keys_col.value(index).to_usize().ok_or_else(|| { - DataFusionError::Internal(format!( - "Can not convert index to usize in dictionary of type creating group by value {:?}", - keys_col.data_type() - )) - })?; - - Ok((dict_array.values(), Some(values_index))) +/// Create a dictionary array representing `value` repeated `size` +/// times +fn dict_from_scalar( + value: &ScalarValue, + size: usize, +) -> ArrayRef { + // values array is one element long (the value) + let values_array = value.to_array_of_size(1); + + // Create a key array with `size` elements, each of 0 + let key_array: PrimitiveArray = std::iter::repeat(Some(K::default_value())) + .take(size) + .collect(); + + // create a new DictionaryArray + // + // Note: this path could be made faster by using the ArrayData + // APIs and skipping validation, if it every comes up in + // performance traces. + Arc::new( + DictionaryArray::::try_new(&key_array, &values_array) + // should always be valid by construction above + .expect("Can not construct dictionary array"), + ) +} + +/// Create a dictionary array representing all the values in values +fn dict_from_values( + values_array: &dyn Array, +) -> Result { + // Create a key array with `size` elements of 0..array_len for all + // non-null value elements + let key_array: PrimitiveArray = (0..values_array.len()) + .map(|index| { + if values_array.is_valid(index) { + let native_index = K::Native::from_usize(index).ok_or_else(|| { + DataFusionError::Internal(format!( + "Can not create index of type {} from value {}", + K::DATA_TYPE, + index + )) + })?; + Ok(Some(native_index)) + } else { + Ok(None) + } + }) + .collect::>>()? + .into_iter() + .collect(); + + // create a new DictionaryArray + // + // Note: this path could be made faster by using the ArrayData + // APIs and skipping validation, if it every comes up in + // performance traces. + let dict_array = DictionaryArray::::try_new(&key_array, values_array)?; + Ok(Arc::new(dict_array)) } macro_rules! typed_cast_tz { @@ -603,6 +666,9 @@ impl ScalarValue { DataType::Interval(IntervalUnit::MonthDayNano) } ScalarValue::Struct(_, fields) => DataType::Struct(fields.as_ref().clone()), + ScalarValue::Dictionary(k, v) => { + DataType::Dictionary(k.clone(), Box::new(v.get_datatype())) + } ScalarValue::Null => DataType::Null, } } @@ -660,6 +726,7 @@ impl ScalarValue { ScalarValue::IntervalDayTime(v) => v.is_none(), ScalarValue::IntervalMonthDayNano(v) => v.is_none(), ScalarValue::Struct(v, _) => v.is_none(), + ScalarValue::Dictionary(_, v) => v.is_null(), } } @@ -971,7 +1038,54 @@ impl ScalarValue { Arc::new(StructArray::from(field_values)) } - _ => { + DataType::Dictionary(key_type, value_type) => { + // create the values array + let value_scalars = scalars + .into_iter() + .map(|scalar| match scalar { + ScalarValue::Dictionary(inner_key_type, scalar) => { + if &inner_key_type == key_type { + Ok(*scalar) + } else{ + panic!("Expected inner key type of {} but found: {}, value was ({:?})", key_type, inner_key_type, scalar); + } + }, + _ => { + Err(DataFusionError::Internal(format!( + "Expected scalar of type {} but found: {} {:?}", + value_type, scalar, scalar + ))) + }, + }) + .collect::>>()?; + + let values = Self::iter_to_array(value_scalars)?; + assert_eq!(values.data_type(), value_type.as_ref()); + + match key_type.as_ref() { + DataType::Int8 => dict_from_values::(&values)?, + DataType::Int16 => dict_from_values::(&values)?, + DataType::Int32 => dict_from_values::(&values)?, + DataType::Int64 => dict_from_values::(&values)?, + DataType::UInt8 => dict_from_values::(&values)?, + DataType::UInt16 => dict_from_values::(&values)?, + DataType::UInt32 => dict_from_values::(&values)?, + DataType::UInt64 => dict_from_values::(&values)?, + _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), + } + } + // explicitly enumerate unsupported types so newly added + // types must be aknowledged + DataType::Float16 + | DataType::Time32(_) + | DataType::Time64(_) + | DataType::Duration(_) + | DataType::FixedSizeBinary(_) + | DataType::FixedSizeList(_, _) + | DataType::Interval(_) + | DataType::LargeList(_) + | DataType::Union(_, _, _) + | DataType::Map(_, _) => { return Err(DataFusionError::Internal(format!( "Unsupported creation of {:?} array from ScalarValue {:?}", data_type, @@ -1267,6 +1381,20 @@ impl ScalarValue { Arc::new(StructArray::from(field_values)) } }, + ScalarValue::Dictionary(key_type, v) => { + // values array is one element long (the value) + match key_type.as_ref() { + DataType::Int8 => dict_from_scalar::(v, size), + DataType::Int16 => dict_from_scalar::(v, size), + DataType::Int32 => dict_from_scalar::(v, size), + DataType::Int64 => dict_from_scalar::(v, size), + DataType::UInt8 => dict_from_scalar::(v, size), + DataType::UInt16 => dict_from_scalar::(v, size), + DataType::UInt32 => dict_from_scalar::(v, size), + DataType::UInt64 => dict_from_scalar::(v, size), + _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), + } + } ScalarValue::Null => new_null_array(&DataType::Null, size), } } @@ -1380,29 +1508,30 @@ impl ScalarValue { tz_opt ) } - DataType::Dictionary(index_type, _) => { - let (values, values_index) = match **index_type { - DataType::Int8 => get_dict_value::(array, index)?, - DataType::Int16 => get_dict_value::(array, index)?, - DataType::Int32 => get_dict_value::(array, index)?, - DataType::Int64 => get_dict_value::(array, index)?, - DataType::UInt8 => get_dict_value::(array, index)?, - DataType::UInt16 => get_dict_value::(array, index)?, - DataType::UInt32 => get_dict_value::(array, index)?, - DataType::UInt64 => get_dict_value::(array, index)?, - _ => { - return Err(DataFusionError::Internal(format!( - "Index type not supported while creating scalar from dictionary: {}", - array.data_type(), - ))); - } + DataType::Dictionary(key_type, _) => { + let (values_array, values_index) = match key_type.as_ref() { + DataType::Int8 => get_dict_value::(array, index), + DataType::Int16 => get_dict_value::(array, index), + DataType::Int32 => get_dict_value::(array, index), + DataType::Int64 => get_dict_value::(array, index), + DataType::UInt8 => get_dict_value::(array, index), + DataType::UInt16 => get_dict_value::(array, index), + DataType::UInt32 => get_dict_value::(array, index), + DataType::UInt64 => get_dict_value::(array, index), + _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), }; + // look up the index in the values dictionary + let value = match values_index { + Some(values_index) => { + ScalarValue::try_from_array(values_array, values_index) + } + // else entry was null, so return null + None => values_array.data_type().try_into(), + }?; - match values_index { - Some(values_index) => Self::try_from_array(values, values_index)?, - // was null - None => values.data_type().try_into()?, - } + println!("AAL creating dictionary scalar with value {:?}", value); + + Self::Dictionary(key_type.clone(), Box::new(value)) } DataType::Struct(fields) => { let array = @@ -1494,10 +1623,6 @@ impl ScalarValue { /// comparisons where comparing a single row at a time is necessary. #[inline] pub fn eq_array(&self, array: &ArrayRef, index: usize) -> bool { - if let DataType::Dictionary(key_type, _) = array.data_type() { - return self.eq_array_dictionary(array, index, key_type); - } - match self { ScalarValue::Decimal128(v, precision, scale) => { ScalarValue::eq_array_decimal(array, index, v, *precision, *scale) @@ -1564,35 +1689,27 @@ impl ScalarValue { eq_array_primitive!(array, index, IntervalMonthDayNanoArray, val) } ScalarValue::Struct(_, _) => unimplemented!(), + ScalarValue::Dictionary(key_type, v) => { + let (values_array, values_index) = match key_type.as_ref() { + DataType::Int8 => get_dict_value::(array, index), + DataType::Int16 => get_dict_value::(array, index), + DataType::Int32 => get_dict_value::(array, index), + DataType::Int64 => get_dict_value::(array, index), + DataType::UInt8 => get_dict_value::(array, index), + DataType::UInt16 => get_dict_value::(array, index), + DataType::UInt32 => get_dict_value::(array, index), + DataType::UInt64 => get_dict_value::(array, index), + _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), + }; + // was the value in the array non null? + match values_index { + Some(values_index) => v.eq_array(values_array, values_index), + None => v.is_null(), + } + } ScalarValue::Null => array.data().is_null(index), } } - - /// Compares a dictionary array with indexes of type `key_type` - /// with the array @ index for equality with self - fn eq_array_dictionary( - &self, - array: &ArrayRef, - index: usize, - key_type: &DataType, - ) -> bool { - let (values, values_index) = match key_type { - DataType::Int8 => get_dict_value::(array, index).unwrap(), - DataType::Int16 => get_dict_value::(array, index).unwrap(), - DataType::Int32 => get_dict_value::(array, index).unwrap(), - DataType::Int64 => get_dict_value::(array, index).unwrap(), - DataType::UInt8 => get_dict_value::(array, index).unwrap(), - DataType::UInt16 => get_dict_value::(array, index).unwrap(), - DataType::UInt32 => get_dict_value::(array, index).unwrap(), - DataType::UInt64 => get_dict_value::(array, index).unwrap(), - _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), - }; - - match values_index { - Some(values_index) => self.eq_array(values, values_index), - None => self.is_null(), - } - } } macro_rules! impl_scalar { @@ -1777,9 +1894,10 @@ impl TryFrom<&DataType> for ScalarValue { DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { ScalarValue::TimestampNanosecond(None, tz_opt.clone()) } - DataType::Dictionary(_index_type, value_type) => { - value_type.as_ref().try_into()? - } + DataType::Dictionary(index_type, value_type) => ScalarValue::Dictionary( + index_type.clone(), + Box::new(value_type.as_ref().try_into()?), + ), DataType::List(ref nested_type) => { ScalarValue::List(None, Box::new(nested_type.data_type().clone())) } @@ -1879,6 +1997,7 @@ impl fmt::Display for ScalarValue { )?, None => write!(f, "NULL")?, }, + ScalarValue::Dictionary(_k, v) => write!(f, "{}", v)?, ScalarValue::Null => write!(f, "NULL")?, }; Ok(()) @@ -1947,6 +2066,7 @@ impl fmt::Debug for ScalarValue { None => write!(f, "Struct(NULL)"), } } + ScalarValue::Dictionary(k, v) => write!(f, "Dictionary({:?}, {:?})", k, v), ScalarValue::Null => write!(f, "NULL"), } } @@ -1992,6 +2112,8 @@ impl ScalarType for TimestampNanosecondType { mod tests { use super::*; use crate::from_slice::FromSlice; + use arrow::compute::kernels; + use arrow::datatypes::ArrowPrimitiveType; use std::cmp::Ordering; use std::sync::Arc; @@ -2303,6 +2425,38 @@ mod tests { ); } + #[test] + fn scalar_iter_to_dictionary() { + fn make_val(v: Option) -> ScalarValue { + let key_type = DataType::Int32; + let value = ScalarValue::Utf8(v); + ScalarValue::Dictionary(Box::new(key_type), Box::new(value)) + } + + let scalars = vec![ + make_val(Some("Foo".into())), + make_val(None), + make_val(Some("Bar".into())), + ]; + + let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); + let array = as_dictionary_array::(&array); + let values_array = as_string_array(array.values()); + + let values = array + .keys_iter() + .map(|k| { + k.map(|k| { + assert!(values_array.is_valid(k)); + values_array.value(k) + }) + }) + .collect::>(); + + let expected = vec![Some("Foo"), None, Some("Bar")]; + assert_eq!(values, expected); + } + #[test] fn scalar_iter_to_array_mismatched_types() { use ScalarValue::*; @@ -2334,7 +2488,11 @@ mod tests { let data_type = DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)); let data_type = &data_type; - assert_eq!(ScalarValue::Utf8(None), data_type.try_into().unwrap()) + let expected = ScalarValue::Dictionary( + Box::new(DataType::Int8), + Box::new(ScalarValue::Utf8(None)), + ); + assert_eq!(expected, data_type.try_into().unwrap()) } #[test] @@ -2385,6 +2543,7 @@ mod tests { /// Test each value in `scalar` with the corresponding element /// at `array`. Assumes each element is unique (aka not equal /// with all other indexes) + #[derive(Debug)] struct TestCase { array: ArrayRef, scalars: Vec, @@ -2439,7 +2598,7 @@ mod tests { /// create a test case for DictionaryArray<$INDEX_TY> macro_rules! make_str_dict_test_case { - ($INPUT:expr, $INDEX_TY:ident, $SCALAR_TY:ident) => {{ + ($INPUT:expr, $INDEX_TY:ident) => {{ TestCase { array: Arc::new( $INPUT @@ -2449,7 +2608,12 @@ mod tests { ), scalars: $INPUT .iter() - .map(|v| ScalarValue::$SCALAR_TY(v.map(|v| v.to_string()))) + .map(|v| { + ScalarValue::Dictionary( + Box::new($INDEX_TY::DATA_TYPE), + Box::new(ScalarValue::Utf8(v.map(|v| v.to_string()))), + ) + }) .collect(), } }}; @@ -2518,18 +2682,21 @@ mod tests { ), make_test_case!(i32_vals, IntervalYearMonthArray, IntervalYearMonth), make_test_case!(i64_vals, IntervalDayTimeArray, IntervalDayTime), - make_str_dict_test_case!(str_vals, Int8Type, Utf8), - make_str_dict_test_case!(str_vals, Int16Type, Utf8), - make_str_dict_test_case!(str_vals, Int32Type, Utf8), - make_str_dict_test_case!(str_vals, Int64Type, Utf8), - make_str_dict_test_case!(str_vals, UInt8Type, Utf8), - make_str_dict_test_case!(str_vals, UInt16Type, Utf8), - make_str_dict_test_case!(str_vals, UInt32Type, Utf8), - make_str_dict_test_case!(str_vals, UInt64Type, Utf8), + make_str_dict_test_case!(str_vals, Int8Type), + make_str_dict_test_case!(str_vals, Int16Type), + make_str_dict_test_case!(str_vals, Int32Type), + make_str_dict_test_case!(str_vals, Int64Type), + make_str_dict_test_case!(str_vals, UInt8Type), + make_str_dict_test_case!(str_vals, UInt16Type), + make_str_dict_test_case!(str_vals, UInt32Type), + make_str_dict_test_case!(str_vals, UInt64Type), ]; for case in cases { + println!("**** Test Case *****"); let TestCase { array, scalars } = case; + println!("Input array type: {}", array.data_type()); + println!("Input scalars: {:#?}", scalars); assert_eq!(array.len(), scalars.len()); for (index, scalar) in scalars.into_iter().enumerate() { @@ -3159,4 +3326,42 @@ mod tests { DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".to_owned())) ); } + + #[test] + fn cast_round_trip() { + check_scalar_cast(ScalarValue::Int8(Some(5)), DataType::Int16); + check_scalar_cast(ScalarValue::Int8(None), DataType::Int16); + + check_scalar_cast(ScalarValue::Float64(Some(5.5)), DataType::Int16); + + check_scalar_cast(ScalarValue::Float64(None), DataType::Int16); + + check_scalar_cast( + ScalarValue::Utf8(Some("foo".to_string())), + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + ); + + check_scalar_cast( + ScalarValue::Utf8(None), + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + ); + } + + // mimics how casting work on scalar values by `casting` `scalar` to `desired_type` + fn check_scalar_cast(scalar: ScalarValue, desired_type: DataType) { + // convert from scalar --> Array to call cast + let scalar_array = scalar.to_array(); + // cast the actual value + let cast_array = kernels::cast::cast(&scalar_array, &desired_type).unwrap(); + + // turn it back to a scalar + let cast_scalar = ScalarValue::try_from_array(&cast_array, 0).unwrap(); + assert_eq!(cast_scalar.get_datatype(), desired_type); + + // Some time later the "cast" scalar is turned back into an array: + let array = cast_scalar.to_array_of_size(10); + + // The datatype should be "Dictionary" but is actually Utf8!!! + assert_eq!(array.data_type(), &desired_type) + } } diff --git a/datafusion/core/tests/path_partition.rs b/datafusion/core/tests/path_partition.rs index a88445d781a9..821d174f2d99 100644 --- a/datafusion/core/tests/path_partition.rs +++ b/datafusion/core/tests/path_partition.rs @@ -149,9 +149,10 @@ async fn parquet_distinct_partition_col() -> Result<()> { assert_eq!(min_limit, resulting_limit); - let month = match ScalarValue::try_from_array(results[0].column(1), 0)? { - ScalarValue::Utf8(Some(month)) => month, - s => panic!("Expected count as Int64 found {}", s.get_datatype()), + let s = ScalarValue::try_from_array(results[0].column(1), 0)?; + let month = match extract_as_utf(&s) { + Some(month) => month, + s => panic!("Expected month as Dict(_, Utf8) found {:?}", s), }; let sql_on_partition_boundary = format!( @@ -172,6 +173,15 @@ async fn parquet_distinct_partition_col() -> Result<()> { Ok(()) } +fn extract_as_utf(v: &ScalarValue) -> Option { + if let ScalarValue::Dictionary(_, v) = v { + if let ScalarValue::Utf8(v) = v.as_ref() { + return v.clone(); + } + } + None +} + #[tokio::test] async fn csv_filter_with_file_col() -> Result<()> { let ctx = SessionContext::new(); diff --git a/datafusion/core/tests/sql/projection.rs b/datafusion/core/tests/sql/projection.rs index a964acae4424..363e96c364c6 100644 --- a/datafusion/core/tests/sql/projection.rs +++ b/datafusion/core/tests/sql/projection.rs @@ -17,6 +17,7 @@ use datafusion::logical_plan::{provider_as_source, LogicalPlanBuilder, UNNAMED_TABLE}; use datafusion::test_util::scan_empty; +use datafusion_expr::when; use tempfile::TempDir; use super::*; @@ -220,6 +221,48 @@ async fn preserve_nullability_on_projection() -> Result<()> { Ok(()) } +#[tokio::test] +async fn project_cast_dictionary() { + let ctx = SessionContext::new(); + + let host: DictionaryArray = vec![Some("host1"), None, Some("host2")] + .into_iter() + .collect(); + + let batch = RecordBatch::try_from_iter(vec![("host", Arc::new(host) as _)]).unwrap(); + + let t = MemTable::try_new(batch.schema(), vec![vec![batch]]).unwrap(); + + // Note that `host` is a dictionary array but `lit("")` is a DataType::Utf8 that needs to be cast + let expr = when(col("host").is_null(), lit("")) + .otherwise(col("host")) + .unwrap(); + + let projection = None; + let builder = LogicalPlanBuilder::scan( + "cpu_load_short", + provider_as_source(Arc::new(t)), + projection, + ) + .unwrap(); + + let logical_plan = builder.project(vec![expr]).unwrap().build().unwrap(); + + let physical_plan = ctx.create_physical_plan(&logical_plan).await.unwrap(); + let actual = collect(physical_plan, ctx.task_ctx()).await.unwrap(); + + let expected = vec![ + "+------------------------------------------------------------------------------------+", + "| CASE WHEN #cpu_load_short.host IS NULL THEN Utf8(\"\") ELSE #cpu_load_short.host END |", + "+------------------------------------------------------------------------------------+", + "| host1 |", + "| |", + "| host2 |", + "+------------------------------------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &actual); +} + #[tokio::test] async fn projection_on_memory_scan() -> Result<()> { let schema = Schema::new(vec![ diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index e7a8b8c5663f..5b3ef852e156 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -54,7 +54,7 @@ impl DistinctCount { name: String, data_type: DataType, ) -> Self { - let state_data_types = input_data_types.into_iter().map(state_type).collect(); + let state_data_types = input_data_types; Self { name, @@ -65,15 +65,6 @@ impl DistinctCount { } } -/// return the type to use to accumulate state for the specified input type -fn state_type(data_type: DataType) -> DataType { - match data_type { - // when aggregating dictionary values, use the underlying value type - DataType::Dictionary(_key_type, value_type) => *value_type, - t => t, - } -} - impl AggregateExpr for DistinctCount { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 417306221b7b..93bcfae2ec7e 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -1060,14 +1060,26 @@ impl PhysicalExpr for BinaryExpr { } } -/// The binary_array_op_dyn_scalar macro includes types that extend beyond the primitive, -/// such as Utf8 strings. +/// unwrap underlying (non dictionary) value, if any, to pass to a scalar kernel +fn unwrap_dict_value(v: ScalarValue) -> ScalarValue { + if let ScalarValue::Dictionary(_key_type, v) = v { + unwrap_dict_value(*v) + } else { + v + } +} + +/// The binary_array_op_dyn_scalar macro includes types that extend +/// beyond the primitive, such as Utf8 strings. #[macro_export] macro_rules! binary_array_op_dyn_scalar { ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{ - let result: Result> = match $RIGHT { + // unwrap underlying (non dictionary) value + let right = unwrap_dict_value($RIGHT); + + let result: Result> = match right { ScalarValue::Boolean(b) => compute_bool_op_dyn_scalar!($LEFT, b, $OP, $OP_TYPE), - ScalarValue::Decimal128(..) => compute_decimal_op_scalar!($LEFT, $RIGHT, $OP, DecimalArray), + ScalarValue::Decimal128(..) => compute_decimal_op_scalar!($LEFT, right, $OP, DecimalArray), ScalarValue::Utf8(v) => compute_utf8_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), ScalarValue::LargeUtf8(v) => compute_utf8_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), ScalarValue::Int8(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), @@ -1080,13 +1092,16 @@ macro_rules! binary_array_op_dyn_scalar { ScalarValue::UInt64(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), ScalarValue::Float32(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), ScalarValue::Float64(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), - ScalarValue::Date32(_) => compute_op_scalar!($LEFT, $RIGHT, $OP, Date32Array), - ScalarValue::Date64(_) => compute_op_scalar!($LEFT, $RIGHT, $OP, Date64Array), - ScalarValue::TimestampSecond(..) => compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampSecondArray), - ScalarValue::TimestampMillisecond(..) => compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampMillisecondArray), - ScalarValue::TimestampMicrosecond(..) => compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampMicrosecondArray), - ScalarValue::TimestampNanosecond(..) => compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampNanosecondArray), - other => Err(DataFusionError::Internal(format!("Data type {:?} not supported for scalar operation '{}' on dyn array", other, stringify!($OP)))) + ScalarValue::Date32(_) => compute_op_scalar!($LEFT, right, $OP, Date32Array), + ScalarValue::Date64(_) => compute_op_scalar!($LEFT, right, $OP, Date64Array), + ScalarValue::TimestampSecond(..) => compute_op_scalar!($LEFT, right, $OP, TimestampSecondArray), + ScalarValue::TimestampMillisecond(..) => compute_op_scalar!($LEFT, right, $OP, TimestampMillisecondArray), + ScalarValue::TimestampMicrosecond(..) => compute_op_scalar!($LEFT, right, $OP, TimestampMicrosecondArray), + ScalarValue::TimestampNanosecond(..) => compute_op_scalar!($LEFT, right, $OP, TimestampNanosecondArray), + other => Err(DataFusionError::Internal(format!( + "Data type {:?} not supported for scalar operation '{}' on dyn array", + other, stringify!($OP))) + ) }; Some(result) }}