diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs index 9401689474bb..3d5036d7e438 100644 --- a/datafusion/common/src/cast.rs +++ b/datafusion/common/src/cast.rs @@ -20,132 +20,71 @@ //! but provide an error message rather than a panic, as the corresponding //! kernels in arrow-rs such as `as_boolean_array` do. -use crate::DataFusionError; +use crate::{downcast_value, DataFusionError}; use arrow::array::{ Array, BooleanArray, Date32Array, Decimal128Array, Float32Array, Float64Array, - Int32Array, Int64Array, StringArray, StructArray, UInt32Array, UInt64Array, + Int32Array, Int64Array, ListArray, StringArray, StructArray, UInt32Array, + UInt64Array, }; // Downcast ArrayRef to Date32Array pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array, DataFusionError> { - array.as_any().downcast_ref::().ok_or_else(|| { - DataFusionError::Internal(format!( - "Expected a Date32Array, got: {}", - array.data_type() - )) - }) + Ok(downcast_value!(array, Date32Array)) } // Downcast ArrayRef to StructArray pub fn as_struct_array(array: &dyn Array) -> Result<&StructArray, DataFusionError> { - array.as_any().downcast_ref::().ok_or_else(|| { - DataFusionError::Internal(format!( - "Expected a StructArray, got: {}", - array.data_type() - )) - }) + Ok(downcast_value!(array, StructArray)) } // Downcast ArrayRef to Int32Array pub fn as_int32_array(array: &dyn Array) -> Result<&Int32Array, DataFusionError> { - array.as_any().downcast_ref::().ok_or_else(|| { - DataFusionError::Internal(format!( - "Expected a Int32Array, got: {}", - array.data_type() - )) - }) + Ok(downcast_value!(array, Int32Array)) } // Downcast ArrayRef to Int64Array pub fn as_int64_array(array: &dyn Array) -> Result<&Int64Array, DataFusionError> { - array.as_any().downcast_ref::().ok_or_else(|| { - DataFusionError::Internal(format!( - "Expected a Int64Array, got: {}", - array.data_type() - )) - }) + Ok(downcast_value!(array, Int64Array)) } // Downcast ArrayRef to Decimal128Array pub fn as_decimal128_array( array: &dyn Array, ) -> Result<&Decimal128Array, DataFusionError> { - array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal(format!( - "Expected a Decimal128Array, got: {}", - array.data_type() - )) - }) + Ok(downcast_value!(array, Decimal128Array)) } // Downcast ArrayRef to Float32Array pub fn as_float32_array(array: &dyn Array) -> Result<&Float32Array, DataFusionError> { - array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal(format!( - "Expected a Float32Array, got: {}", - array.data_type() - )) - }) + Ok(downcast_value!(array, Float32Array)) } // Downcast ArrayRef to Float64Array pub fn as_float64_array(array: &dyn Array) -> Result<&Float64Array, DataFusionError> { - array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal(format!( - "Expected a Float64Array, got: {}", - array.data_type() - )) - }) + Ok(downcast_value!(array, Float64Array)) } // Downcast ArrayRef to StringArray pub fn as_string_array(array: &dyn Array) -> Result<&StringArray, DataFusionError> { - array.as_any().downcast_ref::().ok_or_else(|| { - DataFusionError::Internal(format!( - "Expected a StringArray, got: {}", - array.data_type() - )) - }) + Ok(downcast_value!(array, StringArray)) } // Downcast ArrayRef to UInt32Array pub fn as_uint32_array(array: &dyn Array) -> Result<&UInt32Array, DataFusionError> { - array.as_any().downcast_ref::().ok_or_else(|| { - DataFusionError::Internal(format!( - "Expected a UInt32Array, got: {}", - array.data_type() - )) - }) + Ok(downcast_value!(array, UInt32Array)) } // Downcast ArrayRef to UInt64Array pub fn as_uint64_array(array: &dyn Array) -> Result<&UInt64Array, DataFusionError> { - array.as_any().downcast_ref::().ok_or_else(|| { - DataFusionError::Internal(format!( - "Expected a UInt64Array, got: {}", - array.data_type() - )) - }) + Ok(downcast_value!(array, UInt64Array)) } // Downcast ArrayRef to BooleanArray pub fn as_boolean_array(array: &dyn Array) -> Result<&BooleanArray, DataFusionError> { - array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal(format!( - "Expected a BooleanArray, got: {}", - array.data_type() - )) - }) + Ok(downcast_value!(array, BooleanArray)) +} + +// Downcast ArrayRef to ListArray +pub fn as_list_array(array: &dyn Array) -> Result<&ListArray, DataFusionError> { + Ok(downcast_value!(array, ListArray)) } diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 44bf278b9674..96d1ab6722f7 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -24,7 +24,7 @@ use std::ops::{Add, Sub}; use std::str::FromStr; use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; -use crate::cast::{as_decimal128_array, as_struct_array}; +use crate::cast::{as_decimal128_array, as_list_array, as_struct_array}; use crate::delta::shift_months; use crate::error::{DataFusionError, Result}; use arrow::{ @@ -2001,12 +2001,7 @@ impl ScalarValue { DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8), DataType::LargeUtf8 => typed_cast!(array, index, LargeStringArray, LargeUtf8), DataType::List(nested_type) => { - let list_array = - array.as_any().downcast_ref::().ok_or_else(|| { - DataFusionError::Internal( - "Failed to downcast ListArray".to_string(), - ) - })?; + let list_array = as_list_array(array)?; let value = match list_array.is_null(index) { true => None, false => { @@ -2940,7 +2935,7 @@ mod tests { Box::new(Field::new("item", DataType::UInt64, false)), ) .to_array(); - let list_array = list_array_ref.as_any().downcast_ref::().unwrap(); + let list_array = as_list_array(&list_array_ref).unwrap(); assert!(list_array.is_null(0)); assert_eq!(list_array.len(), 1); @@ -2959,7 +2954,7 @@ mod tests { ) .to_array(); - let list_array = list_array_ref.as_any().downcast_ref::().unwrap(); + let list_array = as_list_array(&list_array_ref)?; assert_eq!(list_array.len(), 1); assert_eq!(list_array.values().len(), 3); @@ -3758,7 +3753,7 @@ mod tests { let nl2 = ScalarValue::new_list(Some(vec![s1]), s0.get_datatype()); // iter_to_array for list-of-struct let array = ScalarValue::iter_to_array(vec![nl0, nl1, nl2]).unwrap(); - let array = array.as_any().downcast_ref::().unwrap(); + let array = as_list_array(&array).unwrap(); // Construct expected array with array builders let field_a_builder = StringBuilder::with_capacity(4, 1024); @@ -3922,7 +3917,7 @@ mod tests { ); let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap(); - let array = array.as_any().downcast_ref::().unwrap(); + let array = as_list_array(&array).unwrap(); // Construct expected array with array builders let inner_builder = Int32Array::builder(8); diff --git a/datafusion/core/src/avro_to_arrow/arrow_array_reader.rs b/datafusion/core/src/avro_to_arrow/arrow_array_reader.rs index f40c5045af5b..30c32fc572a8 100644 --- a/datafusion/core/src/avro_to_arrow/arrow_array_reader.rs +++ b/datafusion/core/src/avro_to_arrow/arrow_array_reader.rs @@ -975,9 +975,9 @@ mod test { use crate::arrow::array::Array; use crate::arrow::datatypes::{Field, TimeUnit}; use crate::avro_to_arrow::{Reader, ReaderBuilder}; - use arrow::array::{ListArray, TimestampMicrosecondArray}; + use arrow::array::TimestampMicrosecondArray; use arrow::datatypes::DataType; - use datafusion_common::cast::{as_int32_array, as_int64_array}; + use datafusion_common::cast::{as_int32_array, as_int64_array, as_list_array}; use std::fs::File; fn build_reader(name: &str, batch_size: usize) -> Reader { @@ -1034,11 +1034,7 @@ mod test { let batch = reader.next().unwrap().unwrap(); assert_eq!(batch.num_columns(), 2); assert_eq!(batch.num_rows(), 3); - let a_array = batch - .column(col_id_index) - .as_any() - .downcast_ref::() - .unwrap(); + let a_array = as_list_array(batch.column(col_id_index)).unwrap(); assert_eq!( *a_array.data_type(), DataType::List(Box::new(Field::new("bigint", DataType::Int64, true))) diff --git a/datafusion/core/tests/sql/parquet.rs b/datafusion/core/tests/sql/parquet.rs index e7b1fe5b15f1..7cf4f343fb40 100644 --- a/datafusion/core/tests/sql/parquet.rs +++ b/datafusion/core/tests/sql/parquet.rs @@ -19,7 +19,7 @@ use std::{fs, path::Path}; use ::parquet::arrow::ArrowWriter; use datafusion::datasource::listing::ListingOptions; -use datafusion_common::cast::as_string_array; +use datafusion_common::cast::{as_list_array, as_string_array}; use tempfile::TempDir; use super::*; @@ -235,16 +235,8 @@ async fn parquet_list_columns() { assert_eq!(2, batch.num_columns()); assert_eq!(schema, batch.schema()); - let int_list_array = batch - .column(0) - .as_any() - .downcast_ref::() - .unwrap(); - let utf8_list_array = batch - .column(1) - .as_any() - .downcast_ref::() - .unwrap(); + let int_list_array = as_list_array(batch.column(0)).unwrap(); + let utf8_list_array = as_list_array(batch.column(1)).unwrap(); assert_eq!( int_list_array diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index 943f7b632e07..d4c0b4406adc 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -226,11 +226,11 @@ mod tests { use crate::aggregate::utils::get_accum_scalar_values; use arrow::array::{ ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, - Int64Array, Int8Array, ListArray, UInt16Array, UInt32Array, UInt64Array, - UInt8Array, + Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }; use arrow::array::{Int32Builder, ListBuilder, UInt64Builder}; use arrow::datatypes::DataType; + use datafusion_common::cast::as_list_array; macro_rules! state_to_vec { ($LIST:expr, $DATA_TYPE:ident, $PRIM_TY:ty) => {{ @@ -380,7 +380,7 @@ mod tests { let agg = DistinctCount::new( arrays .iter() - .map(|a| a.as_any().downcast_ref::().unwrap()) + .map(|a| as_list_array(a).unwrap()) .map(|a| a.values().data_type().clone()) .collect::>(), vec![], diff --git a/datafusion/physical-expr/src/expressions/get_indexed_field.rs b/datafusion/physical-expr/src/expressions/get_indexed_field.rs index 2b77a0b95f80..8fbb68d61a9a 100644 --- a/datafusion/physical-expr/src/expressions/get_indexed_field.rs +++ b/datafusion/physical-expr/src/expressions/get_indexed_field.rs @@ -19,7 +19,6 @@ use crate::PhysicalExpr; use arrow::array::Array; -use arrow::array::ListArray; use arrow::compute::concat; use crate::physical_expr::down_cast_any_ref; @@ -27,7 +26,7 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; -use datafusion_common::cast::as_struct_array; +use datafusion_common::cast::{as_list_array, as_struct_array}; use datafusion_common::DataFusionError; use datafusion_common::Result; use datafusion_common::ScalarValue; @@ -91,8 +90,7 @@ impl PhysicalExpr for GetIndexedFieldExpr { Ok(ColumnarValue::Scalar(scalar_null)) } (DataType::List(_), ScalarValue::Int64(Some(i))) => { - let as_list_array = - array.as_any().downcast_ref::().unwrap(); + let as_list_array = as_list_array(&array)?; if *i < 1 || as_list_array.is_empty() { let scalar_null: ScalarValue = array.data_type().try_into()?; @@ -349,10 +347,7 @@ mod tests { let get_list_expr = Arc::new(GetIndexedFieldExpr::new(struct_col_expr, list_field_key)); let result = get_list_expr.evaluate(&batch)?.into_array(batch.num_rows()); - let result = result - .as_any() - .downcast_ref::() - .unwrap_or_else(|| panic!("failed to downcast to ListArray : {:?}", result)); + let result = as_list_array(&result)?; let expected = &build_utf8_lists(list_of_tuples.into_iter().map(|t| t.1).collect()); assert_eq!(expected, result); diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index c84ee24c5c10..1ed83b89ace7 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -2847,8 +2847,7 @@ mod tests { #[test] #[cfg(feature = "regex_expressions")] fn test_regexp_match() -> Result<()> { - use arrow::array::ListArray; - use datafusion_common::cast::as_string_array; + use datafusion_common::cast::{as_list_array, as_string_array}; let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); let execution_props = ExecutionProps::new(); @@ -2873,7 +2872,7 @@ mod tests { let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); // downcast works - let result = result.as_any().downcast_ref::().unwrap(); + let result = as_list_array(&result)?; let first_row = result.value(0); let first_row = as_string_array(&first_row)?; @@ -2887,8 +2886,7 @@ mod tests { #[test] #[cfg(feature = "regex_expressions")] fn test_regexp_match_all_literals() -> Result<()> { - use arrow::array::ListArray; - use datafusion_common::cast::as_string_array; + use datafusion_common::cast::{as_list_array, as_string_array}; let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); let execution_props = ExecutionProps::new(); @@ -2913,7 +2911,7 @@ mod tests { let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); // downcast works - let result = result.as_any().downcast_ref::().unwrap(); + let result = as_list_array(&result)?; let first_row = result.value(0); let first_row = as_string_array(&first_row)?;