-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Support nulls and empty for array functions #7338
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ca6acee
ad84d01
71b90b5
3f09433
fe09c92
ff0c107
a010854
f81a359
09c3d69
4297340
c2aaad1
797a6e6
b4e9e73
f8c180e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,7 +19,7 @@ | |
|
|
||
| use std::sync::Arc; | ||
|
|
||
| use arrow::datatypes::{DataType, IntervalUnit}; | ||
| use arrow::datatypes::{DataType, Field, IntervalUnit}; | ||
|
|
||
| use datafusion_common::config::ConfigOptions; | ||
| use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter}; | ||
|
|
@@ -553,6 +553,198 @@ fn coerce_arguments_for_signature( | |
| .collect::<Result<Vec<_>>>() | ||
| } | ||
|
|
||
| // TODO: Move this function to arrow-rs or common array utils module | ||
| // base type is the non-list type | ||
| fn base_type(data_type: &DataType) -> Result<DataType> { | ||
| match data_type { | ||
| DataType::List(field) => match field.data_type() { | ||
| DataType::List(_) => base_type(field.data_type()), | ||
| base_type => Ok(base_type.clone()), | ||
| }, | ||
|
|
||
| _ => Ok(data_type.clone()), | ||
| } | ||
| } | ||
|
|
||
| // TODO: Move this function to arrow-rs or common array utils module | ||
| // Build a list from the given base type | ||
| // e.g. Int64 -> List[Int64] | ||
| fn coerced_type_from_base_type( | ||
| data_type: &DataType, | ||
| base_type: &DataType, | ||
| ) -> Result<DataType> { | ||
| match data_type { | ||
| DataType::List(field) => match field.data_type() { | ||
| DataType::List(_) => Ok(DataType::List(Arc::new(Field::new( | ||
| field.name(), | ||
| coerced_type_from_base_type(field.data_type(), base_type)?, | ||
| field.is_nullable(), | ||
| )))), | ||
| _ => Ok(DataType::List(Arc::new(Field::new( | ||
| field.name(), | ||
| base_type.clone(), | ||
| field.is_nullable(), | ||
| )))), | ||
| }, | ||
|
|
||
| _ => Ok(base_type.clone()), | ||
| } | ||
| } | ||
|
|
||
| // Replace inner nulls with coerced types | ||
| // i.e. list[i64], list[null] -> list[i64], list[i64] | ||
| fn replace_inner_nulls_with_coerced_types_( | ||
| coerced_types: Vec<DataType>, | ||
| ) -> Result<Vec<DataType>> { | ||
| let first_non_null_base_type = coerced_types | ||
| .iter() | ||
| .map(base_type) | ||
| .find(|t| t.is_ok() && t.as_ref().unwrap() != &DataType::Null) | ||
| .map(|t| t.unwrap()); | ||
|
|
||
| if let Some(data_type) = first_non_null_base_type { | ||
| coerced_types | ||
| .iter() | ||
| .map(|t| { | ||
| if base_type(t)? == DataType::Null { | ||
| coerced_type_from_base_type(t, &data_type) | ||
| } else { | ||
| Ok(t.clone()) | ||
| } | ||
| }) | ||
| .collect::<Result<Vec<_>>>() | ||
| } else { | ||
| Ok(coerced_types) | ||
| } | ||
| } | ||
|
|
||
| // Directly replace null with coerced type | ||
| // i.e. list[utf8], null -> list[utf8], list[utf8] | ||
| fn replace_nulls_with_coerced_types( | ||
| coerced_types: Vec<DataType>, | ||
| ) -> Result<Vec<DataType>> { | ||
| let first_non_null_type = coerced_types.iter().find(|&t| { | ||
| let base_t = base_type(t); | ||
| base_t.is_ok() && base_t.as_ref().unwrap() != &DataType::Null | ||
| }); | ||
|
|
||
| if let Some(data_type) = first_non_null_type { | ||
| coerced_types | ||
| .iter() | ||
| .map(|t| { | ||
| if base_type(t)? == DataType::Null { | ||
| Ok(data_type.clone()) | ||
| } else { | ||
| Ok(t.clone()) | ||
| } | ||
| }) | ||
| .collect::<Result<Vec<_>>>() | ||
| } else { | ||
| Ok(coerced_types) | ||
| } | ||
| } | ||
|
|
||
| fn validate_array_function_arguments( | ||
| fun: &BuiltinScalarFunction, | ||
| input_types: &[DataType], | ||
| ) -> Result<()> { | ||
| match fun { | ||
| BuiltinScalarFunction::ArrayConcat => { | ||
| // Dimension check | ||
| for expr_type in input_types.iter() { | ||
| if let DataType::List(_) = expr_type { | ||
| continue; | ||
| } else { | ||
| return plan_err!( | ||
| "The array_concat function can only accept list as the args" | ||
| ); | ||
| } | ||
| } | ||
| Ok(()) | ||
| } | ||
| // Add more cases for other array-related functions | ||
| _ => Ok(()), | ||
| } | ||
| } | ||
|
|
||
| fn coerced_array_types_without_nulls(input_types: &[DataType]) -> Result<Vec<DataType>> { | ||
| // Get base type for each input type | ||
| // e.g List[Int64] -> Int64 | ||
| // List[List[Int64]] -> Int64 | ||
| // Int64 -> Int64 | ||
| let base_types = input_types | ||
| .iter() | ||
| .map(base_type) | ||
| .collect::<Result<Vec<_>>>()?; | ||
|
|
||
| // Get the coerced type with comparison coercion | ||
| let coerced_base_type = base_types | ||
| .iter() | ||
| .skip(1) | ||
| .fold(base_types.first().unwrap().clone(), |acc, x| { | ||
| comparison_coercion(&acc, x).unwrap_or(acc) | ||
| }); | ||
|
|
||
| // Re-build the coerced type from base type, ignore null since it is difficult to determine the type for it at first scan | ||
| let coerced_types = input_types | ||
| .iter() | ||
| .map(|data_type| | ||
| // Special cases for null (Null) or empty array (List[Null]), type is determined based on array function | ||
| if base_type(data_type)? == DataType::Null { | ||
| Ok(data_type.clone()) | ||
| } else { | ||
| coerced_type_from_base_type(data_type, &coerced_base_type) | ||
| }) | ||
| .collect::<Result<Vec<_>>>()?; | ||
|
|
||
| Ok(coerced_types) | ||
| } | ||
|
|
||
| fn coerced_array_nulls( | ||
| fun: &BuiltinScalarFunction, | ||
| coerced_types: Vec<DataType>, | ||
| ) -> Result<Vec<DataType>> { | ||
| // Convert Null to coerced expression | ||
| match fun { | ||
| // MakeArray(elements...): each element has the same type, convert null to the non-null type. | ||
| BuiltinScalarFunction::MakeArray => { | ||
| replace_nulls_with_coerced_types(coerced_types) | ||
| } | ||
| // ArrayAppend(list, element): null is only possible with the element, convert it to the list inner type | ||
| // ArrayPrepend(element, list): null is only possible with the element, convert it to the list inner type | ||
| // ArrayConcat: convert null to non-null at this step, dimension of list is not changed | ||
| BuiltinScalarFunction::ArrayAppend | ||
| | BuiltinScalarFunction::ArrayPrepend | ||
| | BuiltinScalarFunction::ArrayConcat => { | ||
| replace_inner_nulls_with_coerced_types_(coerced_types) | ||
| } | ||
| _ => Ok(coerced_types), | ||
| } | ||
| } | ||
|
|
||
| // Coerce array arguments types for array functions, convert type or return error for incompatible types at this step | ||
| fn coerce_array_args( | ||
| fun: &BuiltinScalarFunction, | ||
| expressions: &[Expr], | ||
| schema: &DFSchema, | ||
| ) -> Result<Vec<Expr>> { | ||
| let input_types = expressions | ||
| .iter() | ||
| .map(|e| e.get_type(schema)) | ||
| .collect::<Result<Vec<_>>>()?; | ||
| // TODO: We may move this check outside of type coercion | ||
| // Array concat is moved here since handle this before null coercion is easier and make senses to block the invalid arguments before type coercion. | ||
| validate_array_function_arguments(fun, input_types.as_slice())?; | ||
| // coercion is break down into two steps, since not all array functions have the same coercion rules for nulls | ||
| let coerced_types = coerced_array_types_without_nulls(input_types.as_slice())?; | ||
| let coerced_types = coerced_array_nulls(fun, coerced_types)?; | ||
| expressions | ||
| .iter() | ||
| .zip(coerced_types.iter()) | ||
| .map(|(expr, coerced_type)| cast_expr(expr, coerced_type, schema)) | ||
| .collect::<Result<Vec<_>>>() | ||
| } | ||
|
|
||
| fn coerce_arguments_for_fun( | ||
| expressions: &[Expr], | ||
| schema: &DFSchema, | ||
|
|
@@ -581,48 +773,22 @@ fn coerce_arguments_for_fun( | |
| .collect::<Result<Vec<_>>>()?; | ||
| } | ||
|
|
||
| if *fun == BuiltinScalarFunction::MakeArray { | ||
| // Find the final data type for the function arguments | ||
| let current_types = expressions | ||
| .iter() | ||
| .map(|e| e.get_type(schema)) | ||
| .collect::<Result<Vec<_>>>()?; | ||
|
|
||
| let new_type = current_types | ||
| .iter() | ||
| .skip(1) | ||
| .fold(current_types.first().unwrap().clone(), |acc, x| { | ||
| comparison_coercion(&acc, x).unwrap_or(acc) | ||
| }); | ||
|
|
||
| return expressions | ||
| .iter() | ||
| .zip(current_types) | ||
| .map(|(expr, from_type)| cast_array_expr(expr, &from_type, &new_type, schema)) | ||
| .collect(); | ||
| match fun { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I am inclined to solve this problem by expanding the signature's structure, because there is one difficulty with user-defined functions. For example, I am an Arrow DataFusion user and I want to define my own function. What do you think about it, @alamb and @jayzhan211? |
||
| BuiltinScalarFunction::MakeArray | ||
| | BuiltinScalarFunction::ArrayAppend | ||
| | BuiltinScalarFunction::ArrayPrepend | ||
| | BuiltinScalarFunction::ArrayConcat => { | ||
| coerce_array_args(fun, expressions.as_slice(), schema) | ||
| } | ||
| _ => Ok(expressions), | ||
| } | ||
| Ok(expressions) | ||
| } | ||
|
|
||
| /// Cast `expr` to the specified type, if possible | ||
| fn cast_expr(expr: &Expr, to_type: &DataType, schema: &DFSchema) -> Result<Expr> { | ||
| expr.clone().cast_to(to_type, schema) | ||
| } | ||
|
|
||
| /// Cast array `expr` to the specified type, if possible | ||
| fn cast_array_expr( | ||
| expr: &Expr, | ||
| from_type: &DataType, | ||
| to_type: &DataType, | ||
| schema: &DFSchema, | ||
| ) -> Result<Expr> { | ||
| if from_type.equals_datatype(&DataType::Null) { | ||
| Ok(expr.clone()) | ||
| } else { | ||
| cast_expr(expr, to_type, schema) | ||
| } | ||
| } | ||
|
|
||
| /// Returns the coerced exprs for each `input_exprs`. | ||
| /// Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the | ||
| /// data type of `input_exprs` need to be coerced. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The function definitely has a practical use but it should be expanded for all nested data types (list, fixed_list, map, union ...).