improve error handling and add some more types (#4352)
retikulum committed Nov 27, 2022
1 parent 0496904 commit da54fa5
Showing 14 changed files with 118 additions and 141 deletions.
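
The pattern applied throughout this commit is replacing ad-hoc `.as_any().downcast_ref::<...>().unwrap()` calls with the fallible helpers in `datafusion_common::cast`. A minimal before/after sketch of that pattern (the function names and array contents are illustrative, not taken from the diff; `arrow` and `datafusion-common` are assumed to be available as dependencies):

```rust
use arrow::array::{Array, Int32Array};
use arrow::datatypes::Int32Type;
use datafusion_common::cast::as_primitive_array;
use datafusion_common::Result;

// Before: panics with a generic downcast failure if the type is wrong.
fn first_value_old(array: &dyn Array) -> i32 {
    array.as_any().downcast_ref::<Int32Array>().unwrap().value(0)
}

// After: a type mismatch surfaces as a DataFusionError the caller can propagate.
fn first_value_new(array: &dyn Array) -> Result<i32> {
    Ok(as_primitive_array::<Int32Type>(array)?.value(0))
}

fn main() -> Result<()> {
    let array = Int32Array::from(vec![7, 8, 9]);
    assert_eq!(first_value_old(&array), 7);
    assert_eq!(first_value_new(&array)?, 7);
    Ok(())
}
```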
47 changes: 43 additions & 4 deletions datafusion/common/src/cast.rs
@@ -21,10 +21,14 @@
//! kernels in arrow-rs such as `as_boolean_array` do.

use crate::{downcast_value, DataFusionError};
use arrow::array::{
Array, BooleanArray, Date32Array, Decimal128Array, Float32Array, Float64Array,
Int32Array, Int64Array, ListArray, StringArray, StructArray, UInt32Array,
UInt64Array,
use arrow::{
array::{
Array, BooleanArray, Date32Array, Decimal128Array, DictionaryArray, Float32Array,
Float64Array, GenericBinaryArray, GenericListArray, Int32Array, Int64Array,
LargeListArray, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray,
StructArray, UInt32Array, UInt64Array,
},
datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType},
};

// Downcast ArrayRef to Date32Array
@@ -88,3 +92,38 @@ pub fn as_boolean_array(array: &dyn Array) -> Result<&BooleanArray, DataFusionError> {
pub fn as_list_array(array: &dyn Array) -> Result<&ListArray, DataFusionError> {
Ok(downcast_value!(array, ListArray))
}

// Downcast ArrayRef to DictionaryArray
pub fn as_dictionary_array<T: ArrowDictionaryKeyType>(
array: &dyn Array,
) -> Result<&DictionaryArray<T>, DataFusionError> {
Ok(downcast_value!(array, DictionaryArray, T))
}

// Downcast ArrayRef to GenericBinaryArray
pub fn as_generic_binary_array<T: OffsetSizeTrait>(
array: &dyn Array,
) -> Result<&GenericBinaryArray<T>, DataFusionError> {
Ok(downcast_value!(array, GenericBinaryArray, T))
}

// Downcast ArrayRef to GenericListArray
pub fn as_generic_list_array<T: OffsetSizeTrait>(
array: &dyn Array,
) -> Result<&GenericListArray<T>, DataFusionError> {
Ok(downcast_value!(array, GenericListArray, T))
}

// Downcast ArrayRef to LargeListArray
pub fn as_large_list_array(
array: &dyn Array,
) -> Result<&LargeListArray, DataFusionError> {
Ok(downcast_value!(array, LargeListArray))
}

// Downcast ArrayRef to PrimitiveArray
pub fn as_primitive_array<T: ArrowPrimitiveType>(
array: &dyn Array,
) -> Result<&PrimitiveArray<T>, DataFusionError> {
Ok(downcast_value!(array, PrimitiveArray, T))
}
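
The newly added helpers also cover dictionary, binary, and list arrays. A minimal sketch of how `as_generic_binary_array` might be called (the `first_len` helper and the array contents are illustrative assumptions; `arrow` and `datafusion-common` are assumed to be dependencies):

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, BinaryArray};
use datafusion_common::cast::as_generic_binary_array;
use datafusion_common::Result;

fn first_len(array: &ArrayRef) -> Result<usize> {
    // i32 offsets correspond to BinaryArray; i64 would select LargeBinaryArray.
    let binary = as_generic_binary_array::<i32>(array)?;
    Ok(binary.value(0).len())
}

fn main() -> Result<()> {
    let array: ArrayRef = Arc::new(BinaryArray::from(vec![b"ab".as_ref(), b"cde".as_ref()]));
    assert_eq!(first_len(&array)?, 2);
    Ok(())
}
```

Parameterising on the offset type (`i32` vs `i64`) lets one helper serve both the regular and the "large" variants of the binary and list arrays.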
8 changes: 5 additions & 3 deletions datafusion/common/src/scalar.rs
@@ -25,7 +25,9 @@ use std::ops::{Add, Sub};
use std::str::FromStr;
use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};

use crate::cast::{as_decimal128_array, as_list_array, as_struct_array};
use crate::cast::{
as_decimal128_array, as_dictionary_array, as_list_array, as_struct_array,
};
use crate::delta::shift_months;
use crate::error::{DataFusionError, Result};
use arrow::{
@@ -721,7 +723,7 @@ fn get_dict_value<K: ArrowDictionaryKeyType>(
array: &ArrayRef,
index: usize,
) -> (&ArrayRef, Option<usize>) {
let dict_array = as_dictionary_array::<K>(array);
let dict_array = as_dictionary_array::<K>(array).unwrap();
(dict_array.values(), dict_array.key(index))
}

@@ -3212,7 +3214,7 @@ mod tests {
];

let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap();
let array = as_dictionary_array::<Int32Type>(&array);
let array = as_dictionary_array::<Int32Type>(&array).unwrap();
let values_array = as_string_array(array.values()).unwrap();

let values = array
17 changes: 8 additions & 9 deletions datafusion/core/src/physical_plan/joins/hash_join.rs
@@ -22,12 +22,11 @@ use ahash::RandomState;

use arrow::{
array::{
as_dictionary_array, ArrayData, ArrayRef, BooleanArray, Date32Array, Date64Array,
Decimal128Array, DictionaryArray, LargeStringArray, PrimitiveArray,
Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray,
Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
TimestampSecondArray, UInt32BufferBuilder, UInt32Builder, UInt64BufferBuilder,
UInt64Builder,
ArrayData, ArrayRef, BooleanArray, Date32Array, Date64Array, Decimal128Array,
DictionaryArray, LargeStringArray, PrimitiveArray, Time32MillisecondArray,
Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray,
TimestampMicrosecondArray, TimestampMillisecondArray, TimestampSecondArray,
UInt32BufferBuilder, UInt32Builder, UInt64BufferBuilder, UInt64Builder,
},
compute,
datatypes::{
@@ -54,7 +53,7 @@ use arrow::array::{
UInt8Array,
};

use datafusion_common::cast::{as_boolean_array, as_string_array};
use datafusion_common::cast::{as_boolean_array, as_dictionary_array, as_string_array};

use hashbrown::raw::RawTable;

@@ -1127,9 +1126,9 @@ macro_rules! equal_rows_elem {
macro_rules! equal_rows_elem_with_string_dict {
($key_array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident, $null_equals_null: ident) => {{
let left_array: &DictionaryArray<$key_array_type> =
as_dictionary_array::<$key_array_type>($l);
as_dictionary_array::<$key_array_type>($l).unwrap();
let right_array: &DictionaryArray<$key_array_type> =
as_dictionary_array::<$key_array_type>($r);
as_dictionary_array::<$key_array_type>($r).unwrap();

let (left_values, left_values_index) = {
let keys_col = left_array.keys();
14 changes: 7 additions & 7 deletions datafusion/core/src/physical_plan/sorts/sort.rs
@@ -951,7 +951,7 @@ mod tests {
use arrow::array::*;
use arrow::compute::SortOptions;
use arrow::datatypes::*;
use datafusion_common::cast::as_string_array;
use datafusion_common::cast::{as_primitive_array, as_string_array};
use futures::FutureExt;
use std::collections::{BTreeMap, HashMap};

@@ -995,11 +995,11 @@
assert_eq!(c1.value(0), "a");
assert_eq!(c1.value(c1.len() - 1), "e");

let c2 = as_primitive_array::<UInt32Type>(&columns[1]);
let c2 = as_primitive_array::<UInt32Type>(&columns[1])?;
assert_eq!(c2.value(0), 1);
assert_eq!(c2.value(c2.len() - 1), 5,);

let c7 = as_primitive_array::<UInt8Type>(&columns[6]);
let c7 = as_primitive_array::<UInt8Type>(&columns[6])?;
assert_eq!(c7.value(0), 15);
assert_eq!(c7.value(c7.len() - 1), 254,);

@@ -1067,11 +1067,11 @@
assert_eq!(c1.value(0), "a");
assert_eq!(c1.value(c1.len() - 1), "e");

let c2 = as_primitive_array::<UInt32Type>(&columns[1]);
let c2 = as_primitive_array::<UInt32Type>(&columns[1])?;
assert_eq!(c2.value(0), 1);
assert_eq!(c2.value(c2.len() - 1), 5,);

let c7 = as_primitive_array::<UInt8Type>(&columns[6]);
let c7 = as_primitive_array::<UInt8Type>(&columns[6])?;
assert_eq!(c7.value(0), 15);
assert_eq!(c7.value(c7.len() - 1), 254,);

@@ -1271,8 +1271,8 @@
assert_eq!(DataType::Float32, *columns[0].data_type());
assert_eq!(DataType::Float64, *columns[1].data_type());

let a = as_primitive_array::<Float32Type>(&columns[0]);
let b = as_primitive_array::<Float64Type>(&columns[1]);
let a = as_primitive_array::<Float32Type>(&columns[0])?;
let b = as_primitive_array::<Float64Type>(&columns[1])?;

// convert result to strings to allow comparing to expected result containing NaN
let result: Vec<(Option<String>, Option<String>)> = (0..result[0].num_rows())
7 changes: 4 additions & 3 deletions datafusion/core/src/physical_plan/windows/mod.rs
@@ -171,6 +171,7 @@ mod tests {
use arrow::array::*;
use arrow::datatypes::{DataType, Field, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion_common::cast::as_primitive_array;
use futures::FutureExt;

fn create_test_schema(partitions: usize) -> Result<(Arc<CsvExec>, SchemaRef)> {
@@ -228,15 +229,15 @@

// c3 is small int

let count: &Int64Array = as_primitive_array(&columns[0]);
let count: &Int64Array = as_primitive_array(&columns[0])?;
assert_eq!(count.value(0), 100);
assert_eq!(count.value(99), 100);

let max: &Int8Array = as_primitive_array(&columns[1]);
let max: &Int8Array = as_primitive_array(&columns[1])?;
assert_eq!(max.value(0), 125);
assert_eq!(max.value(99), 125);

let min: &Int8Array = as_primitive_array(&columns[2]);
let min: &Int8Array = as_primitive_array(&columns[2])?;
assert_eq!(min.value(0), -117);
assert_eq!(min.value(99), -117);

15 changes: 4 additions & 11 deletions datafusion/core/tests/custom_sources.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use arrow::array::{Int32Array, Int64Array, PrimitiveArray};
use arrow::array::{Int32Array, Int64Array};
use arrow::compute::kernels::aggregate;
use arrow::datatypes::{DataType, Field, Int32Type, Schema, SchemaRef};
use arrow::error::Result as ArrowResult;
@@ -38,6 +38,7 @@ use datafusion::{
};
use datafusion::{error::Result, physical_plan::DisplayFormatType};

use datafusion_common::cast::as_primitive_array;
use futures::stream::Stream;
use std::any::Any;
use std::pin::Pin;
@@ -162,18 +163,10 @@ impl ExecutionPlan for CustomExecutionPlan {
.map(|i| ColumnStatistics {
null_count: Some(batch.column(*i).null_count()),
min_value: Some(ScalarValue::Int32(aggregate::min(
batch
.column(*i)
.as_any()
.downcast_ref::<PrimitiveArray<Int32Type>>()
.unwrap(),
as_primitive_array::<Int32Type>(batch.column(*i)).unwrap(),
))),
max_value: Some(ScalarValue::Int32(aggregate::max(
batch
.column(*i)
.as_any()
.downcast_ref::<PrimitiveArray<Int32Type>>()
.unwrap(),
as_primitive_array::<Int32Type>(batch.column(*i)).unwrap(),
))),
..Default::default()
})
7 changes: 4 additions & 3 deletions datafusion/core/tests/provider_filter_pushdown.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use arrow::array::{as_primitive_array, Int32Builder, Int64Array};
use arrow::array::{Int32Builder, Int64Array};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
@@ -31,6 +31,7 @@ use datafusion::physical_plan::{
};
use datafusion::prelude::*;
use datafusion::scalar::ScalarValue;
use datafusion_common::cast::as_primitive_array;
use datafusion_common::DataFusionError;
use datafusion_expr::expr::{BinaryExpr, Cast};
use std::ops::Deref;
@@ -215,7 +216,7 @@ async fn assert_provider_row_count(value: i64, expected_count: i64) -> Result<()> {
.aggregate(vec![], vec![count(col("flag"))])?;

let results = df.collect().await?;
let result_col: &Int64Array = as_primitive_array(results[0].column(0));
let result_col: &Int64Array = as_primitive_array(results[0].column(0))?;
assert_eq!(result_col.value(0), expected_count);

ctx.register_table("data", Arc::new(provider))?;
@@ -225,7 +226,7 @@ async fn assert_provider_row_count(value: i64, expected_count: i64) -> Result<()> {
.collect()
.await?;

let sql_result_col: &Int64Array = as_primitive_array(sql_results[0].column(0));
let sql_result_col: &Int64Array = as_primitive_array(sql_results[0].column(0))?;
assert_eq!(sql_result_col.value(0), expected_count);

Ok(())
20 changes: 4 additions & 16 deletions datafusion/core/tests/sql/parquet.rs
@@ -19,7 +19,7 @@ use std::{fs, path::Path};

use ::parquet::arrow::ArrowWriter;
use datafusion::datasource::listing::ListingOptions;
use datafusion_common::cast::{as_list_array, as_string_array};
use datafusion_common::cast::{as_list_array, as_primitive_array, as_string_array};
use tempfile::TempDir;

use super::*;
@@ -239,11 +239,7 @@ async fn parquet_list_columns() {
let utf8_list_array = as_list_array(batch.column(1)).unwrap();

assert_eq!(
int_list_array
.value(0)
.as_any()
.downcast_ref::<PrimitiveArray<Int64Type>>()
.unwrap(),
as_primitive_array::<Int64Type>(&int_list_array.value(0)).unwrap(),
&PrimitiveArray::<Int64Type>::from(vec![Some(1), Some(2), Some(3),])
);

@@ -253,22 +249,14 @@
);

assert_eq!(
int_list_array
.value(1)
.as_any()
.downcast_ref::<PrimitiveArray<Int64Type>>()
.unwrap(),
as_primitive_array::<Int64Type>(&int_list_array.value(1)).unwrap(),
&PrimitiveArray::<Int64Type>::from(vec![None, Some(1),])
);

assert!(utf8_list_array.is_null(1));

assert_eq!(
int_list_array
.value(2)
.as_any()
.downcast_ref::<PrimitiveArray<Int64Type>>()
.unwrap(),
as_primitive_array::<Int64Type>(&int_list_array.value(2)).unwrap(),
&PrimitiveArray::<Int64Type>::from(vec![Some(4),])
);

7 changes: 4 additions & 3 deletions datafusion/core/tests/user_defined_aggregates.rs
@@ -22,7 +22,7 @@ use std::sync::Arc;

use datafusion::{
arrow::{
array::{as_primitive_array, ArrayRef, Float64Array, TimestampNanosecondArray},
array::{ArrayRef, Float64Array, TimestampNanosecondArray},
datatypes::{DataType, Field, Float64Type, TimeUnit, TimestampNanosecondType},
record_batch::RecordBatch,
},
@@ -37,6 +37,7 @@ use datafusion::{
prelude::SessionContext,
scalar::ScalarValue,
};
use datafusion_common::cast::as_primitive_array;

#[tokio::test]
/// Basic query with a udaf returning a structure
@@ -227,8 +228,8 @@ impl Accumulator for FirstSelector {
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
// cast arguments to the appropriate type (DataFusion will type
// check these based on the declared allowed input types)
let v = as_primitive_array::<Float64Type>(&values[0]);
let t = as_primitive_array::<TimestampNanosecondType>(&values[1]);
let v = as_primitive_array::<Float64Type>(&values[0])?;
let t = as_primitive_array::<TimestampNanosecondType>(&values[1])?;

// Update the actual values
for (value, time) in v.iter().zip(t.iter()) {
19 changes: 4 additions & 15 deletions datafusion/physical-expr/src/aggregate/median.rs
@@ -19,12 +19,13 @@

use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};
use arrow::array::{Array, ArrayRef, PrimitiveArray, PrimitiveBuilder};
use arrow::array::{Array, ArrayRef, PrimitiveBuilder};
use arrow::compute::sort;
use arrow::datatypes::{
ArrowPrimitiveType, DataType, Field, Float32Type, Float64Type, Int16Type, Int32Type,
Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
use datafusion_common::cast::as_primitive_array;
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::{Accumulator, AggregateState};
use std::any::Any;
@@ -102,12 +103,7 @@ macro_rules! median {
return Ok(ScalarValue::Null);
}
let sorted = sort(&combined, None)?;
let array = sorted
.as_any()
.downcast_ref::<PrimitiveArray<$TY>>()
.ok_or(DataFusionError::Internal(
"median! macro failed to cast array to expected type".to_string(),
))?;
let array = as_primitive_array::<$TY>(&sorted)?;
let len = sorted.len();
let mid = len / 2;
if len % 2 == 0 {
@@ -209,14 +205,7 @@ fn combine_arrays<T: ArrowPrimitiveType>(arrays: &[ArrayRef]) -> Result<ArrayRef> {
let len = arrays.iter().map(|a| a.len() - a.null_count()).sum();
let mut builder: PrimitiveBuilder<T> = PrimitiveBuilder::with_capacity(len);
for array in arrays {
let array = array
.as_any()
.downcast_ref::<PrimitiveArray<T>>()
.ok_or_else(|| {
DataFusionError::Internal(
"combine_arrays failed to cast array to expected type".to_string(),
)
})?;
let array = as_primitive_array::<T>(array)?;
for i in 0..array.len() {
if !array.is_null(i) {
builder.append_value(array.value(i));
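
As an illustration of the error path this commit standardises on, the sketch below (an assumed usage, not code from the change itself) passes a mismatched array type to `as_primitive_array` and observes that the failure comes back as an `Err` rather than a panic:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, StringArray};
use arrow::datatypes::Int32Type;
use datafusion_common::cast::as_primitive_array;

fn main() {
    // A string array deliberately passed where an Int32 array is expected.
    let array: ArrayRef = Arc::new(StringArray::from(vec!["a", "b"]));

    // The helper reports the mismatch as a DataFusionError instead of
    // panicking, so callers such as the median accumulator above can
    // propagate it with `?`.
    let result = as_primitive_array::<Int32Type>(&array);
    assert!(result.is_err());
}
```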
