Skip to content

Commit

Permalink
Improve Error Handling and Readibility for downcasting StructArray (#…
Browse files Browse the repository at this point in the history
…4061)

* improve error messages for StructArray

* refactor newly added Date32Array downcasting and correct error string

* beautify code

* changes after code review

* fix formatting
  • Loading branch information
retikulum committed Nov 3, 2022
1 parent 61429f8 commit 761e167
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 17 deletions.
6 changes: 3 additions & 3 deletions benchmarks/src/tpch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
// under the License.

use arrow::array::{
Array, ArrayRef, Date32Array, Decimal128Array, Float64Array, Int32Array, Int64Array,
StringArray,
Array, ArrayRef, Decimal128Array, Float64Array, Int32Array, Int64Array, StringArray,
};
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
Expand All @@ -27,6 +26,7 @@ use std::path::Path;
use std::sync::Arc;
use std::time::Instant;

use datafusion::common::cast::as_date32_array;
use datafusion::common::ScalarValue;
use datafusion::logical_expr::Cast;
use datafusion::prelude::*;
Expand Down Expand Up @@ -440,7 +440,7 @@ fn col_to_scalar(column: &ArrayRef, row_index: usize) -> ScalarValue {
ScalarValue::Decimal128(Some(array.value(row_index)), *p, *s)
}
DataType::Date32 => {
let array = column.as_any().downcast_ref::<Date32Array>().unwrap();
let array = as_date32_array(column).unwrap();
ScalarValue::Date32(Some(array.value(row_index)))
}
DataType::Utf8 => {
Expand Down
12 changes: 11 additions & 1 deletion datafusion/common/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
//! kernels in arrow-rs such as `as_boolean_array` do.

use crate::DataFusionError;
use arrow::array::{Array, Date32Array};
use arrow::array::{Array, Date32Array, StructArray};

// Downcast ArrayRef to Date32Array
pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array, DataFusionError> {
Expand All @@ -32,3 +32,13 @@ pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array, DataFusionErro
))
})
}

// Downcast ArrayRef to StructArray
pub fn as_struct_array(array: &dyn Array) -> Result<&StructArray, DataFusionError> {
array.as_any().downcast_ref::<StructArray>().ok_or_else(|| {
DataFusionError::Internal(format!(
"Expected a StructArray, got: {}",
array.data_type()
))
})
}
14 changes: 3 additions & 11 deletions datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ use arrow::{
use chrono::{Datelike, Duration, NaiveDate, NaiveDateTime};
use ordered_float::OrderedFloat;

use crate::cast::as_struct_array;
use crate::delta::shift_months;
use crate::error::{DataFusionError, Result};

Expand Down Expand Up @@ -2008,15 +2009,7 @@ impl ScalarValue {
Self::Dictionary(key_type.clone(), Box::new(value))
}
DataType::Struct(fields) => {
let array =
array
.as_any()
.downcast_ref::<StructArray>()
.ok_or_else(|| {
DataFusionError::Internal(
"Failed to downcast ArrayRef to StructArray".to_string(),
)
})?;
let array = as_struct_array(array)?;
let mut field_values: Vec<ScalarValue> = Vec::new();
for col_index in 0..array.num_columns() {
let col_array = array.column(col_index);
Expand Down Expand Up @@ -3611,8 +3604,7 @@ mod tests {
// iter_to_array for struct scalars
let array =
ScalarValue::iter_to_array(vec![s0.clone(), s1.clone(), s2.clone()]).unwrap();
let array = array.as_any().downcast_ref::<StructArray>().unwrap();

let array = as_struct_array(&array).unwrap();
let expected = StructArray::from(vec![
(
field_a.clone(),
Expand Down
5 changes: 3 additions & 2 deletions datafusion/physical-expr/src/expressions/get_indexed_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@

use crate::PhysicalExpr;
use arrow::array::Array;
use arrow::array::{ListArray, StructArray};
use arrow::array::ListArray;
use arrow::compute::concat;

use crate::physical_expr::down_cast_any_ref;
use arrow::{
datatypes::{DataType, Schema},
record_batch::RecordBatch,
};
use datafusion_common::cast::as_struct_array;
use datafusion_common::DataFusionError;
use datafusion_common::Result;
use datafusion_common::ScalarValue;
Expand Down Expand Up @@ -122,7 +123,7 @@ impl PhysicalExpr for GetIndexedFieldExpr {
}
}
(DataType::Struct(_), ScalarValue::Utf8(Some(k))) => {
let as_struct_array = array.as_any().downcast_ref::<StructArray>().unwrap();
let as_struct_array = as_struct_array(&array)?;
match as_struct_array.column_by_name(k) {
None => Err(DataFusionError::Execution(format!("get indexed field {} not found in struct", k))),
Some(col) => Ok(ColumnarValue::Array(col.clone()))
Expand Down

0 comments on commit 761e167

Please sign in to comment.