diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs new file mode 100644 index 000000000000..16a2a7422aa2 --- /dev/null +++ b/datafusion/common/src/cast.rs @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module provides DataFusion specific casting functions +//! that provide error handling. They are intended to "never fail" +//! but provide an error message rather than a panic, as the corresponding +//! kernels in arrow-rs such as `as_boolean_array` do. + +use crate::DataFusionError; +use arrow::array::{Array, Date32Array}; + +// Downcast ArrayRef to Date32Array +pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array, DataFusionError> { + array.as_any().downcast_ref::().ok_or_else(|| { + DataFusionError::Internal(format!( + "Expected a Date32Array, got: {}", + array.data_type() + )) + }) +} diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 8330a360004a..1d33032ba0f6 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -16,6 +16,7 @@ // under the License. pub mod bisect; +pub mod cast; mod column; pub mod delta; mod dfschema; diff --git a/datafusion/physical-expr/src/datetime_expressions.rs b/datafusion/physical-expr/src/datetime_expressions.rs index 0214d1bfb9be..3dc9956841ea 100644 --- a/datafusion/physical-expr/src/datetime_expressions.rs +++ b/datafusion/physical-expr/src/datetime_expressions.rs @@ -27,7 +27,7 @@ use arrow::{ }; use arrow::{ array::{ - Date32Array, Date64Array, TimestampMicrosecondArray, TimestampMillisecondArray, + Date64Array, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, }, compute::kernels::temporal, @@ -36,6 +36,7 @@ use arrow::{ }; use chrono::prelude::*; use chrono::Duration; +use datafusion_common::cast::as_date32_array; use datafusion_common::{DataFusionError, Result}; use datafusion_common::{ScalarType, ScalarValue}; use datafusion_expr::ColumnarValue; @@ -377,10 +378,10 @@ pub fn date_bin(args: &[ColumnarValue]) -> Result { macro_rules! extract_date_part { ($ARRAY: expr, $FN:expr) => { match $ARRAY.data_type() { - DataType::Date32 => { - let array = $ARRAY.as_any().downcast_ref::().unwrap(); - Ok($FN(array)?) - } + DataType::Date32 => match as_date32_array($ARRAY) { + Ok(array) => Ok($FN(array)?), + Err(e) => Err(e), + }, DataType::Date64 => { let array = $ARRAY.as_any().downcast_ref::().unwrap(); Ok($FN(array)?) @@ -448,16 +449,16 @@ pub fn date_part(args: &[ColumnarValue]) -> Result { }; let arr = match date_part.to_lowercase().as_str() { - "year" => extract_date_part!(array, temporal::year), - "quarter" => extract_date_part!(array, temporal::quarter), - "month" => extract_date_part!(array, temporal::month), - "week" => extract_date_part!(array, temporal::week), - "day" => extract_date_part!(array, temporal::day), - "doy" => extract_date_part!(array, temporal::doy), - "dow" => extract_date_part!(array, temporal::num_days_from_sunday), - "hour" => extract_date_part!(array, temporal::hour), - "minute" => extract_date_part!(array, temporal::minute), - "second" => extract_date_part!(array, temporal::second), + "year" => extract_date_part!(&array, temporal::year), + "quarter" => extract_date_part!(&array, temporal::quarter), + "month" => extract_date_part!(&array, temporal::month), + "week" => extract_date_part!(&array, temporal::week), + "day" => extract_date_part!(&array, temporal::day), + "doy" => extract_date_part!(&array, temporal::doy), + "dow" => extract_date_part!(&array, temporal::num_days_from_sunday), + "hour" => extract_date_part!(&array, temporal::hour), + "minute" => extract_date_part!(&array, temporal::minute), + "second" => extract_date_part!(&array, temporal::second), _ => Err(DataFusionError::Execution(format!( "Date part '{}' not supported", date_part diff --git a/datafusion/physical-expr/src/expressions/datetime.rs b/datafusion/physical-expr/src/expressions/datetime.rs index 97f038068794..e2776023af44 100644 --- a/datafusion/physical-expr/src/expressions/datetime.rs +++ b/datafusion/physical-expr/src/expressions/datetime.rs @@ -17,8 +17,8 @@ use crate::PhysicalExpr; use arrow::array::{ - Array, ArrayRef, Date32Array, Date64Array, TimestampMicrosecondArray, - TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, + Array, ArrayRef, Date64Array, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, }; use arrow::compute::unary; use arrow::datatypes::{ @@ -26,6 +26,7 @@ use arrow::datatypes::{ TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; use arrow::record_batch::RecordBatch; +use datafusion_common::cast::as_date32_array; use datafusion_common::scalar::{ date32_add, date64_add, microseconds_add, milliseconds_add, nanoseconds_add, seconds_add, @@ -153,7 +154,7 @@ pub fn evaluate_array( ) -> Result { let ret = match array.data_type() { DataType::Date32 => { - let array = array.as_any().downcast_ref::().unwrap(); + let array = as_date32_array(&array)?; Arc::new(unary::(array, |days| { date32_add(days, scalar, sign).unwrap() })) as ArrayRef diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 2d34115220ec..0f82852f1d32 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -36,6 +36,7 @@ use arrow::{ use crate::PhysicalExpr; use arrow::array::*; use arrow::datatypes::TimeUnit; +use datafusion_common::cast::as_date32_array; use datafusion_common::ScalarValue; use datafusion_common::ScalarValue::{ Binary, Boolean, Date32, Date64, Decimal128, Int16, Int32, Int64, Int8, LargeBinary, @@ -589,7 +590,7 @@ impl PhysicalExpr for InListExpr { )) } DataType::Date32 => { - let array = array.as_any().downcast_ref::().unwrap(); + let array = as_date32_array(&array)?; Ok(set_contains_for_primitive!( array, set, diff --git a/datafusion/row/src/writer.rs b/datafusion/row/src/writer.rs index e796be2a9912..c9d56368d3bc 100644 --- a/datafusion/row/src/writer.rs +++ b/datafusion/row/src/writer.rs @@ -22,6 +22,7 @@ use arrow::array::*; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use arrow::util::bit_util::{round_upto_power_of_2, set_bit_raw, unset_bit_raw}; +use datafusion_common::cast::as_date32_array; use datafusion_common::Result; use std::cmp::max; use std::sync::Arc; @@ -326,8 +327,10 @@ pub(crate) fn write_field_date32( col_idx: usize, row_idx: usize, ) { - let from = from.as_any().downcast_ref::().unwrap(); - to.set_date32(col_idx, from.value(row_idx)); + match as_date32_array(from) { + Ok(from) => to.set_date32(col_idx, from.value(row_idx)), + Err(e) => panic!("{}", e), + }; } pub(crate) fn write_field_date64(