diff --git a/datafusion/spark/src/function/array/spark_array.rs b/datafusion/spark/src/function/array/spark_array.rs index bf5842cb5a5a..bb9665613de9 100644 --- a/datafusion/spark/src/function/array/spark_array.rs +++ b/datafusion/spark/src/function/array/spark_array.rs @@ -24,7 +24,7 @@ use arrow::array::{ use arrow::buffer::OffsetBuffer; use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::utils::SingleRowListArrayBuilder; -use datafusion_common::{plan_datafusion_err, plan_err, Result}; +use datafusion_common::{internal_err, plan_datafusion_err, plan_err, Result}; use datafusion_expr::type_coercion::binary::comparison_coercion; use datafusion_expr::{ ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, @@ -72,9 +72,20 @@ impl ScalarUDFImpl for SparkArray { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let data_types = args + .arg_fields + .iter() + .map(|f| f.data_type()) + .cloned() + .collect::>(); + let mut expr_type = DataType::Null; - for arg_type in arg_types { + for arg_type in &data_types { if !arg_type.equals_datatype(&DataType::Null) { expr_type = arg_type.clone(); break; @@ -85,21 +96,12 @@ impl ScalarUDFImpl for SparkArray { expr_type = DataType::Int32; } - Ok(DataType::List(Arc::new(Field::new( + let return_type = DataType::List(Arc::new(Field::new( ARRAY_FIELD_DEFAULT_NAME, expr_type, true, - )))) - } + ))); - fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { - let data_types = args - .arg_fields - .iter() - .map(|f| f.data_type()) - .cloned() - .collect::>(); - let return_type = self.return_type(&data_types)?; Ok(Arc::new(Field::new( "this_field_name_is_irrelevant", return_type, @@ -166,7 +168,6 @@ pub fn make_array_inner(arrays: &[ArrayRef]) -> Result { .build_list_array(), )) } - DataType::LargeList(..) => array_array::(arrays, data_type), _ => array_array::(arrays, data_type), } } diff --git a/datafusion/sqllogictest/test_files/spark/array/array.slt b/datafusion/sqllogictest/test_files/spark/array/array.slt index 09821e6d582d..79dca1c10a7d 100644 --- a/datafusion/sqllogictest/test_files/spark/array/array.slt +++ b/datafusion/sqllogictest/test_files/spark/array/array.slt @@ -70,3 +70,18 @@ query ? SELECT array(array(1,2)); ---- [[1, 2]] + +query ? +SELECT array(arrow_cast(array(1), 'LargeList(Int64)')); +---- +[[1]] + +query ? +SELECT array(arrow_cast(array(1), 'LargeList(Int64)'), arrow_cast(array(), 'LargeList(Int64)')); +---- +[[1], []] + +query ? +SELECT array(arrow_cast(array(1,2), 'LargeList(Int64)'), array(3)); +---- +[[1, 2], [3]]