Skip to content
20 changes: 7 additions & 13 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -516,25 +516,19 @@ impl BuiltinScalarFunction {
Ok(data_type)
}
BuiltinScalarFunction::ArrayAppend => Ok(input_expr_types[0].clone()),
// Array concat allows multiple dimensions of arrays, i.e. array_concat(1D, 3D, 2D), find the largest one between them.
BuiltinScalarFunction::ArrayConcat => {
let mut expr_type = Null;
let mut max_dims = 0;
for input_expr_type in input_expr_types {
match input_expr_type {
List(field) => {
if !field.data_type().equals_datatype(&Null) {
let dims = self.return_dimension(input_expr_type.clone());
if max_dims < dims {
max_dims = dims;
expr_type = input_expr_type.clone();
}
if let List(field) = input_expr_type {
if !field.data_type().equals_datatype(&Null) {
let dims = self.return_dimension(input_expr_type.clone());
if max_dims < dims {
max_dims = dims;
expr_type = input_expr_type.clone();
}
}
_ => {
return plan_err!(
"The {self} function can only accept list as the args."
)
}
}
}

Expand Down
236 changes: 201 additions & 35 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

use std::sync::Arc;

use arrow::datatypes::{DataType, IntervalUnit};
use arrow::datatypes::{DataType, Field, IntervalUnit};

use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter};
Expand Down Expand Up @@ -553,6 +553,198 @@ fn coerce_arguments_for_signature(
.collect::<Result<Vec<_>>>()
}

// TODO: Move this function to arrow-rs or common array utils module
// base type is the non-list type
fn base_type(data_type: &DataType) -> Result<DataType> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function definitely has a practical use but it should be expanded for all nested data types (list, fixed_list, map, union ...).

match data_type {
DataType::List(field) => match field.data_type() {
DataType::List(_) => base_type(field.data_type()),
base_type => Ok(base_type.clone()),
},

_ => Ok(data_type.clone()),
}
}

// TODO: Move this function to arrow-rs or common array utils module
/// Rebuilds `data_type` with its innermost element type replaced by `base_type`,
/// preserving list nesting depth, field names, and nullability.
///
/// e.g. `Int64` -> `List[Int64]` semantics: a non-list input yields the base
/// type itself, while `List[...]` inputs keep their list structure.
fn coerced_type_from_base_type(
    data_type: &DataType,
    base_type: &DataType,
) -> Result<DataType> {
    match data_type {
        DataType::List(field) => {
            // For nested lists, recurse to swap only the innermost element type.
            let inner = match field.data_type() {
                DataType::List(_) => {
                    coerced_type_from_base_type(field.data_type(), base_type)?
                }
                _ => base_type.clone(),
            };
            Ok(DataType::List(Arc::new(Field::new(
                field.name(),
                inner,
                field.is_nullable(),
            ))))
        }
        _ => Ok(base_type.clone()),
    }
}

// Replace inner nulls with coerced types
// i.e. list[i64], list[null] -> list[i64], list[i64]
//
// Each type keeps its own list nesting depth; only an innermost `Null`
// element type is replaced by the first non-null base type found among
// the inputs. If every input has a null base type, the types are
// returned unchanged.
fn replace_inner_nulls_with_coerced_types_(
    coerced_types: Vec<DataType>,
) -> Result<Vec<DataType>> {
    // Find the first non-null base (innermost) type. Unlike the previous
    // `find(|t| t.is_ok() && ...)` formulation, an error from `base_type`
    // is propagated instead of being silently skipped, and the base type
    // is computed only once per input.
    let mut first_non_null_base_type = None;
    for t in coerced_types.iter() {
        let bt = base_type(t)?;
        if bt != DataType::Null {
            first_non_null_base_type = Some(bt);
            break;
        }
    }

    match first_non_null_base_type {
        Some(data_type) => coerced_types
            .iter()
            .map(|t| {
                if base_type(t)? == DataType::Null {
                    coerced_type_from_base_type(t, &data_type)
                } else {
                    Ok(t.clone())
                }
            })
            .collect::<Result<Vec<_>>>(),
        // All inputs are (nested) nulls: there is nothing to coerce to.
        None => Ok(coerced_types),
    }
}

// Directly replace null with coerced type
// i.e. list[utf8], null -> list[utf8], list[utf8]
//
// Unlike `replace_inner_nulls_with_coerced_types_`, the whole null type is
// replaced by the first non-null input type (including its list nesting).
// If every input has a null base type, the types are returned unchanged.
fn replace_nulls_with_coerced_types(
    coerced_types: Vec<DataType>,
) -> Result<Vec<DataType>> {
    // Find the first type whose base (innermost) type is not Null. Errors
    // from `base_type` are propagated instead of being silently skipped
    // (the previous `is_ok() && ... unwrap()` pattern dropped them), and
    // `base_type` is no longer computed twice for the found element.
    let mut first_non_null_type = None;
    for t in coerced_types.iter() {
        if base_type(t)? != DataType::Null {
            first_non_null_type = Some(t.clone());
            break;
        }
    }

    match first_non_null_type {
        Some(data_type) => coerced_types
            .iter()
            .map(|t| {
                if base_type(t)? == DataType::Null {
                    Ok(data_type.clone())
                } else {
                    Ok(t.clone())
                }
            })
            .collect::<Result<Vec<_>>>(),
        // All inputs are null-based: nothing to coerce to.
        None => Ok(coerced_types),
    }
}

/// Validates the argument types of array functions before coercion,
/// rejecting argument lists that can never be coerced successfully.
fn validate_array_function_arguments(
    fun: &BuiltinScalarFunction,
    input_types: &[DataType],
) -> Result<()> {
    match fun {
        BuiltinScalarFunction::ArrayConcat => {
            // Dimension check: every argument must already be a list.
            let all_lists =
                input_types.iter().all(|t| matches!(t, DataType::List(_)));
            if all_lists {
                Ok(())
            } else {
                plan_err!(
                    "The array_concat function can only accept list as the args"
                )
            }
        }
        // Add more cases for other array-related functions
        _ => Ok(()),
    }
}

/// Computes the coerced type for each input, ignoring null-based types.
///
/// Each type's base (innermost) type is extracted, a common base type is
/// chosen with `comparison_coercion`, and each non-null input is rebuilt
/// around that common base type, keeping its own list nesting.
fn coerced_array_types_without_nulls(input_types: &[DataType]) -> Result<Vec<DataType>> {
    // Guard against empty input: the fold below seeds from the first
    // element, and the previous `first().unwrap()` panicked here.
    if input_types.is_empty() {
        return Ok(vec![]);
    }

    // Get base type for each input type
    // e.g List[Int64] -> Int64
    // List[List[Int64]] -> Int64
    // Int64 -> Int64
    let base_types = input_types
        .iter()
        .map(base_type)
        .collect::<Result<Vec<_>>>()?;

    // Get the coerced type with comparison coercion; when no coercion rule
    // applies, keep the accumulator unchanged.
    let coerced_base_type = base_types
        .iter()
        .skip(1)
        .fold(base_types[0].clone(), |acc, x| {
            comparison_coercion(&acc, x).unwrap_or(acc)
        });

    // Re-build the coerced type from base type, ignore null since it is difficult to determine the type for it at first scan
    let coerced_types = input_types
        .iter()
        .map(|data_type|
            // Special cases for null (Null) or empty array (List[Null]), type is determined based on array function
            if base_type(data_type)? == DataType::Null {
                Ok(data_type.clone())
            } else {
                coerced_type_from_base_type(data_type, &coerced_base_type)
            })
        .collect::<Result<Vec<_>>>()?;

    Ok(coerced_types)
}

/// Applies the per-function rule for coercing `Null` arguments.
///
/// Not every array function treats nulls the same way, so dispatch on the
/// function before rewriting the types.
fn coerced_array_nulls(
    fun: &BuiltinScalarFunction,
    coerced_types: Vec<DataType>,
) -> Result<Vec<DataType>> {
    if matches!(fun, BuiltinScalarFunction::MakeArray) {
        // MakeArray(elements...): each element has the same type, so a null
        // argument is converted to the first non-null type wholesale.
        replace_nulls_with_coerced_types(coerced_types)
    } else if matches!(
        fun,
        BuiltinScalarFunction::ArrayAppend
            | BuiltinScalarFunction::ArrayPrepend
            | BuiltinScalarFunction::ArrayConcat
    ) {
        // ArrayAppend(list, element) / ArrayPrepend(element, list): null can
        // only be the element, so convert it to the list's inner type.
        // ArrayConcat: replace inner nulls; list dimensions are unchanged.
        replace_inner_nulls_with_coerced_types_(coerced_types)
    } else {
        // Other functions: leave the types as-is.
        Ok(coerced_types)
    }
}

// Coerce array argument types for array functions: casts each argument to its
// coerced type, or returns an error for incompatible argument types.
fn coerce_array_args(
    fun: &BuiltinScalarFunction,
    expressions: &[Expr],
    schema: &DFSchema,
) -> Result<Vec<Expr>> {
    // Resolve the current type of each argument against the schema.
    let mut input_types = Vec::with_capacity(expressions.len());
    for expr in expressions {
        input_types.push(expr.get_type(schema)?);
    }

    // TODO: We may move this check outside of type coercion.
    // array_concat is validated here because handling it before null coercion
    // is easier, and it makes sense to reject invalid arguments before
    // type coercion runs.
    validate_array_function_arguments(fun, &input_types)?;

    // Coercion is broken into two steps, since not all array functions share
    // the same coercion rules for nulls.
    let coerced_types = coerced_array_types_without_nulls(&input_types)?;
    let coerced_types = coerced_array_nulls(fun, coerced_types)?;

    // Cast each argument expression to its coerced type.
    let mut coerced_exprs = Vec::with_capacity(expressions.len());
    for (expr, coerced_type) in expressions.iter().zip(coerced_types.iter()) {
        coerced_exprs.push(cast_expr(expr, coerced_type, schema)?);
    }
    Ok(coerced_exprs)
}

fn coerce_arguments_for_fun(
expressions: &[Expr],
schema: &DFSchema,
Expand Down Expand Up @@ -581,48 +773,22 @@ fn coerce_arguments_for_fun(
.collect::<Result<Vec<_>>>()?;
}

if *fun == BuiltinScalarFunction::MakeArray {
// Find the final data type for the function arguments
let current_types = expressions
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;

let new_type = current_types
.iter()
.skip(1)
.fold(current_types.first().unwrap().clone(), |acc, x| {
comparison_coercion(&acc, x).unwrap_or(acc)
});

return expressions
.iter()
.zip(current_types)
.map(|(expr, from_type)| cast_array_expr(expr, &from_type, &new_type, schema))
.collect();
match fun {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am inclined to solve this problem by expanding the signature's structure, because there is one difficulty with user-defined functions. For example, suppose I am an Arrow DataFusion user and I want to define my own ArrayAppend implementation (a function `new_array_append`). How would that function handle nulls?

What do you think about it, @alamb and @jayzhan211?

BuiltinScalarFunction::MakeArray
| BuiltinScalarFunction::ArrayAppend
| BuiltinScalarFunction::ArrayPrepend
| BuiltinScalarFunction::ArrayConcat => {
coerce_array_args(fun, expressions.as_slice(), schema)
}
_ => Ok(expressions),
}
Ok(expressions)
}

/// Cast `expr` to the specified type, if possible.
///
/// Delegates to `Expr::cast_to` on a clone of `expr`, resolving the
/// expression's current type against `schema`; propagates any error from
/// the cast.
fn cast_expr(expr: &Expr, to_type: &DataType, schema: &DFSchema) -> Result<Expr> {
    expr.clone().cast_to(to_type, schema)
}

/// Cast array `expr` to the specified type, if possible.
///
/// A `Null`-typed expression is returned unchanged (its concrete type is
/// decided elsewhere); any other expression is cast to `to_type`.
fn cast_array_expr(
    expr: &Expr,
    from_type: &DataType,
    to_type: &DataType,
    schema: &DFSchema,
) -> Result<Expr> {
    match from_type.equals_datatype(&DataType::Null) {
        true => Ok(expr.clone()),
        false => cast_expr(expr, to_type, schema),
    }
}

/// Returns the coerced exprs for each `input_exprs`.
/// Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the
/// data type of `input_exprs` need to be coerced.
Expand Down
28 changes: 8 additions & 20 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,9 @@ fn array_array(args: &[ArrayRef], data_type: DataType) -> Result<ArrayRef> {
} else if arg.as_any().downcast_ref::<NullArray>().is_some() {
arrays.push(ListOrNull::Null);
} else {
return internal_err!("Unsupported argument type for array");
return internal_err!(
"(array_array) Unsupported argument type for array"
);
}
}

Expand Down Expand Up @@ -674,7 +676,7 @@ pub fn array_append(args: &[ArrayRef]) -> Result<ArrayRef> {

check_datatypes("array_append", &[arr.values(), element])?;
let res = match arr.value_type() {
DataType::List(_) => concat_internal(args)?,
DataType::List(_) => array_concat(args)?,
DataType::Null => {
return Ok(array(&[ColumnarValue::Array(args[1].clone())])?.into_array(1))
}
Expand Down Expand Up @@ -750,7 +752,7 @@ pub fn array_prepend(args: &[ArrayRef]) -> Result<ArrayRef> {

check_datatypes("array_prepend", &[element, arr.values()])?;
let res = match arr.value_type() {
DataType::List(_) => concat_internal(args)?,
DataType::List(_) => array_concat(args)?,
DataType::Null => {
return Ok(array(&[ColumnarValue::Array(args[0].clone())])?.into_array(1))
}
Expand Down Expand Up @@ -810,7 +812,9 @@ fn align_array_dimensions(args: Vec<ArrayRef>) -> Result<Vec<ArrayRef>> {
aligned_args
}

fn concat_internal(args: &[ArrayRef]) -> Result<ArrayRef> {
/// Array_concat/Array_cat SQL function
pub fn array_concat(args: &[ArrayRef]) -> Result<ArrayRef> {
// Dimension check and null conversion is done in `type coercion` step.
let args = align_array_dimensions(args.to_vec())?;

let list_arrays =
Expand Down Expand Up @@ -863,22 +867,6 @@ fn concat_internal(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(Arc::new(list))
}

/// Array_concat/Array_cat SQL function
pub fn array_concat(args: &[ArrayRef]) -> Result<ArrayRef> {
    // Keep only arguments whose innermost type is non-null; reject any
    // argument that is not at least a 2-dimensional list.
    let mut new_args = Vec::with_capacity(args.len());
    for arg in args {
        let (ndim, lower_data_type) =
            compute_array_ndims_with_datatype(Some(arg.clone()))?;
        match ndim {
            None | Some(1) => {
                return not_impl_err!("Array is not type '{lower_data_type:?}'.")
            }
            _ => {
                // Arrays whose innermost type is Null contribute nothing to
                // the concatenation, so drop them.
                if !lower_data_type.equals_datatype(&DataType::Null) {
                    new_args.push(arg.clone());
                }
            }
        }
    }

    concat_internal(new_args.as_slice())
}

macro_rules! general_repeat {
($ELEMENT:expr, $COUNT:expr, $ARRAY_TYPE:ident) => {{
let mut offsets: Vec<i32> = vec![0];
Expand Down
Loading