diff --git a/benchmarks/src/tpch.rs b/benchmarks/src/tpch.rs index 51f67548a828..59d2654239df 100644 --- a/benchmarks/src/tpch.rs +++ b/benchmarks/src/tpch.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{Array, ArrayRef, Decimal128Array, Float64Array, StringArray}; +use arrow::array::{Array, ArrayRef, Float64Array, StringArray}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use std::fs; @@ -24,7 +24,9 @@ use std::path::Path; use std::sync::Arc; use std::time::Instant; -use datafusion::common::cast::{as_date32_array, as_int32_array, as_int64_array}; +use datafusion::common::cast::{ + as_date32_array, as_decimal128_array, as_int32_array, as_int64_array, +}; use datafusion::common::ScalarValue; use datafusion::logical_expr::Cast; use datafusion::prelude::*; @@ -434,7 +436,7 @@ fn col_to_scalar(column: &ArrayRef, row_index: usize) -> ScalarValue { ScalarValue::Float64(Some(array.value(row_index))) } DataType::Decimal128(p, s) => { - let array = column.as_any().downcast_ref::().unwrap(); + let array = as_decimal128_array(column).unwrap(); ScalarValue::Decimal128(Some(array.value(row_index)), *p, *s) } DataType::Date32 => { diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs index 32cbc0af7748..4618050c59ac 100644 --- a/datafusion/common/src/cast.rs +++ b/datafusion/common/src/cast.rs @@ -21,7 +21,9 @@ //! kernels in arrow-rs such as `as_boolean_array` do. use crate::DataFusionError; -use arrow::array::{Array, Date32Array, Int32Array, Int64Array, StructArray}; +use arrow::array::{ + Array, Date32Array, Decimal128Array, Int32Array, Int64Array, StructArray, +}; // Downcast ArrayRef to Date32Array pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array, DataFusionError> { @@ -62,3 +64,18 @@ pub fn as_int64_array(array: &dyn Array) -> Result<&Int64Array, DataFusionError> )) }) } + +// Downcast ArrayRef to Decimal128Array +pub fn as_decimal128_array( + array: &dyn Array, +) -> Result<&Decimal128Array, DataFusionError> { + array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "Expected a Decimal128Array, got: {}", + array.data_type() + )) + }) +} diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 76c7cacccc7e..5caac82bf167 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -24,7 +24,7 @@ use std::ops::{Add, Sub}; use std::str::FromStr; use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; -use crate::cast::as_struct_array; +use crate::cast::{as_decimal128_array, as_struct_array}; use crate::delta::shift_months; use crate::error::{DataFusionError, Result}; use arrow::{ @@ -1882,12 +1882,13 @@ impl ScalarValue { index: usize, precision: u8, scale: u8, - ) -> ScalarValue { - let array = array.as_any().downcast_ref::().unwrap(); + ) -> Result { + let array = as_decimal128_array(array)?; if array.is_null(index) { - ScalarValue::Decimal128(None, precision, scale) + Ok(ScalarValue::Decimal128(None, precision, scale)) } else { - ScalarValue::Decimal128(Some(array.value(index)), precision, scale) + let value = array.value(index); + Ok(ScalarValue::Decimal128(Some(value), precision, scale)) } } @@ -1903,7 +1904,7 @@ impl ScalarValue { DataType::Decimal128(precision, scale) => { ScalarValue::get_decimal_value_from_array( array, index, *precision, *scale, - ) + )? } DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean), DataType::Float64 => typed_cast!(array, index, Float64Array, Float64), @@ -2074,14 +2075,16 @@ impl ScalarValue { value: &Option, precision: u8, scale: u8, - ) -> bool { - let array = array.as_any().downcast_ref::().unwrap(); + ) -> Result { + let array = as_decimal128_array(array)?; if array.precision() != precision || array.scale() != scale { - return false; + return Ok(false); } - match value { - None => array.is_null(index), - Some(v) => !array.is_null(index) && array.value(index) == *v, + let is_null = array.is_null(index); + if let Some(v) = value { + Ok(!array.is_null(index) && array.value(index) == *v) + } else { + Ok(is_null) } } @@ -2106,6 +2109,7 @@ impl ScalarValue { match self { ScalarValue::Decimal128(v, precision, scale) => { ScalarValue::eq_array_decimal(array, index, v, *precision, *scale) + .unwrap() } ScalarValue::Boolean(val) => { eq_array_primitive!(array, index, BooleanArray, val) @@ -2697,14 +2701,14 @@ mod tests { // decimal scalar to array let array = decimal_value.to_array(); - let array = array.as_any().downcast_ref::().unwrap(); + let array = as_decimal128_array(&array)?; assert_eq!(1, array.len()); assert_eq!(DataType::Decimal128(10, 1), array.data_type().clone()); assert_eq!(123i128, array.value(0)); // decimal scalar to array with size let array = decimal_value.to_array_of_size(10); - let array_decimal = array.as_any().downcast_ref::().unwrap(); + let array_decimal = as_decimal128_array(&array)?; assert_eq!(10, array.len()); assert_eq!(DataType::Decimal128(10, 1), array.data_type().clone()); assert_eq!(123i128, array_decimal.value(0)); diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 91b5baaa5c0e..0d732c538887 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -40,7 +40,7 @@ unicode_expressions = ["unicode-segmentation"] [dependencies] ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] } -arrow = { version = "26.0.0", features = ["prettyprint"] } +arrow = { version = "26.0.0", features = ["prettyprint", "dyn_cmp_dict"] } arrow-buffer = "26.0.0" arrow-schema = "26.0.0" blake2 = { version = "^0.10.2", optional = true } diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index a2d769a48cba..88dcaaccca04 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -75,6 +75,7 @@ use arrow::record_batch::RecordBatch; use crate::physical_expr::down_cast_any_ref; use crate::{AnalysisContext, ExprBoundaries, PhysicalExpr}; +use datafusion_common::cast::as_decimal128_array; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::type_coercion::binary::binary_operator_data_type; @@ -122,7 +123,7 @@ impl std::fmt::Display for BinaryExpr { macro_rules! compute_decimal_op_dyn_scalar { ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{ - let ll = $LEFT.as_any().downcast_ref::().unwrap(); + let ll = as_decimal128_array($LEFT).unwrap(); if let ScalarValue::Decimal128(Some(_), _, _) = $RIGHT { Ok(Arc::new(paste::expr! {[<$OP _decimal_scalar>]}( ll, @@ -137,7 +138,7 @@ macro_rules! compute_decimal_op_dyn_scalar { macro_rules! compute_decimal_op_scalar { ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ - let ll = $LEFT.as_any().downcast_ref::().unwrap(); + let ll = as_decimal128_array($LEFT).unwrap(); if let ScalarValue::Decimal128(Some(_), _, _) = $RIGHT { Ok(Arc::new(paste::expr! {[<$OP _decimal_scalar>]}( ll, diff --git a/datafusion/physical-expr/src/expressions/binary/adapter.rs b/datafusion/physical-expr/src/expressions/binary/adapter.rs index 12b8fab89d76..a4726ba2cae4 100644 --- a/datafusion/physical-expr/src/expressions/binary/adapter.rs +++ b/datafusion/physical-expr/src/expressions/binary/adapter.rs @@ -23,6 +23,7 @@ use std::sync::Arc; use super::kernels_arrow::*; use arrow::array::*; use arrow::datatypes::DataType; +use datafusion_common::cast::as_decimal128_array; use datafusion_common::Result; /// create a `dyn_op` wrapper function for the specified operation @@ -39,7 +40,7 @@ macro_rules! make_dyn_comp_op { // arrow has native support // https://github.com/apache/arrow-rs/issues/1200 (DataType::Decimal128(_, _), DataType::Decimal128(_, _)) => { - [<$OP _decimal>](as_decimal_array(left), as_decimal_array(right)) + [<$OP _decimal>](as_decimal128_array(left).unwrap(), as_decimal128_array(right).unwrap()) }, // By default call the arrow kernel _ => {