diff --git a/datafusion/functions-nested/src/array_add.rs b/datafusion/functions-nested/src/array_add.rs new file mode 100644 index 0000000000000..c6edf67bf5a93 --- /dev/null +++ b/datafusion/functions-nested/src/array_add.rs @@ -0,0 +1,203 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_add function. 
+ +use crate::utils::{coerce_array_math_arg_types, make_scalar_function}; +use arrow::array::{ + Array, ArrayRef, Float64Array, GenericListArray, NullBufferBuilder, + OffsetBufferBuilder, OffsetSizeTrait, +}; +use arrow::buffer::NullBuffer; +use arrow::datatypes::{ + DataType, + DataType::{LargeList, List}, + Field, +}; +use datafusion_common::cast::{as_float64_array, as_generic_list_array}; +use datafusion_common::{Result, exec_err, utils::take_function_args}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; +use datafusion_macros::user_doc; +use std::sync::Arc; + +make_udf_expr_and_func!( + ArrayAdd, + array_add, + array1 array2, + "returns the element-wise sum of two numeric arrays.", + array_add_udf +); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns the element-wise sum of two numeric arrays of equal length, computed as `array1[i] + array2[i]` per position. NULL is propagated per element: if either input element at position `i` is NULL, the corresponding output element is NULL (positions are preserved). Returns NULL if either entire input array is NULL. Errors if the per-row lengths differ. Returns an empty array if both inputs are empty.", + syntax_example = "array_add(array1, array2)", + sql_example = r#"```sql +> select array_add([1.0, 2.0, 3.0], [10.0, 20.0, 30.0]); ++---------------------------------------------------------+ +| array_add(List([1.0,2.0,3.0]),List([10.0,20.0,30.0])) | ++---------------------------------------------------------+ +| [11.0, 22.0, 33.0] | ++---------------------------------------------------------+ +```"#, + argument( + name = "array1", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument( + name = "array2", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." 
+ ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ArrayAdd { + signature: Signature, + aliases: Vec, +} + +impl Default for ArrayAdd { + fn default() -> Self { + Self::new() + } +} + +impl ArrayAdd { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec!["list_add".to_string()], + } + } +} + +impl ScalarUDFImpl for ArrayAdd { + fn name(&self) -> &str { + "array_add" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + // After `coerce_types`, both args share the same List/LargeList shape. + Ok(arg_types[0].clone()) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let [_, _] = take_function_args(self.name(), arg_types)?; + coerce_array_math_arg_types(self.name(), arg_types) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(array_add_inner)(&args.args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +fn array_add_inner(args: &[ArrayRef]) -> Result { + let [array1, array2] = take_function_args("array_add", args)?; + match (array1.data_type(), array2.data_type()) { + (List(_), List(_)) => general_array_add::(array1, array2), + (LargeList(_), LargeList(_)) => general_array_add::(array1, array2), + (arg_type1, arg_type2) => exec_err!( + "array_add received unexpected types after coercion: {arg_type1} and {arg_type2}" + ), + } +} + +fn general_array_add( + lhs: &ArrayRef, + rhs: &ArrayRef, +) -> Result { + let lhs = as_generic_list_array::(lhs)?; + let rhs = as_generic_list_array::(rhs)?; + + let lhs_values = as_float64_array(lhs.values())?; + let rhs_values = as_float64_array(rhs.values())?; + let lhs_offsets = lhs.value_offsets(); + let rhs_offsets = rhs.value_offsets(); + + // Row-level validity: a row is valid iff both sides are valid at that row. 
+ let row_nulls = NullBuffer::union(lhs.nulls(), rhs.nulls()); + + let mut out_values: Vec = Vec::with_capacity(lhs_values.len()); + let mut out_inner_nulls = NullBufferBuilder::new(lhs_values.len()); + let mut out_offsets = OffsetBufferBuilder::::new(lhs.len()); + + for row in 0..lhs.len() { + // Whole-row NULL on either side -> NULL output row, no elements. + if row_nulls.as_ref().is_some_and(|nb| nb.is_null(row)) { + out_offsets.push_length(0); + continue; + } + + let start1 = lhs_offsets[row].as_usize(); + let len1 = lhs.value_length(row).as_usize(); + let start2 = rhs_offsets[row].as_usize(); + let len2 = rhs.value_length(row).as_usize(); + + if len1 != len2 { + return exec_err!( + "array_add requires both list inputs to have the same length per row, got {len1} and {len2} at row {row}" + ); + } + + let l_slice = lhs_values.slice(start1, len1); + let r_slice = rhs_values.slice(start2, len2); + + let l_vals = l_slice.values(); + let r_vals = r_slice.values(); + + for i in 0..len1 { + out_values.push(l_vals[i] + r_vals[i]); + } + + // Per-element validity: position `i` is valid iff both lhs[i] and rhs[i] + // are valid. `NullBuffer::union` returns `None` when both sides are + // entirely valid. 
+ match NullBuffer::union(l_slice.nulls(), r_slice.nulls()) { + Some(nb) => out_inner_nulls.append_buffer(&nb), + None => out_inner_nulls.append_n_non_nulls(len1), + } + + out_offsets.push_length(len1); + } + + let values_array = Arc::new(Float64Array::new( + out_values.into(), + out_inner_nulls.finish(), + )); + let field = Arc::new(Field::new_list_field(DataType::Float64, true)); + + Ok(Arc::new(GenericListArray::::try_new( + field, + out_offsets.finish(), + values_array, + row_nulls, + )?)) +} diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs index 1e6dc68cb23ae..43f48e8247951 100644 --- a/datafusion/functions-nested/src/lib.rs +++ b/datafusion/functions-nested/src/lib.rs @@ -40,9 +40,8 @@ pub mod macros; #[macro_use] pub mod macros_lambda; +pub mod array_add; pub mod array_any_match; -pub(crate) mod lambda_utils; - pub mod array_compact; pub mod array_filter; pub mod array_has; @@ -60,6 +59,7 @@ pub mod expr_ext; pub mod extract; pub mod flatten; pub mod inner_product; +pub(crate) mod lambda_utils; pub mod length; pub mod make_array; pub mod map; @@ -89,6 +89,7 @@ use std::sync::Arc; /// Fluent-style API for creating `Expr`s pub mod expr_fn { + pub use super::array_add::array_add; pub use super::array_any_match::array_any_match; pub use super::array_compact::array_compact; pub use super::array_filter::array_filter; @@ -171,6 +172,7 @@ pub fn all_default_nested_functions() -> Vec> { empty::array_empty_udf(), length::array_length_udf(), array_normalize::array_normalize_udf(), + array_add::array_add_udf(), cosine_distance::cosine_distance_udf(), inner_product::inner_product_udf(), distance::array_distance_udf(), diff --git a/datafusion/functions-nested/src/utils.rs b/datafusion/functions-nested/src/utils.rs index eeff003e8e766..bdd71f2ff8f28 100644 --- a/datafusion/functions-nested/src/utils.rs +++ b/datafusion/functions-nested/src/utils.rs @@ -276,6 +276,57 @@ pub(crate) fn get_map_entry_field(data_type: &DataType) -> 
Result<&Fields> { } } +/// Shared `coerce_types` impl for array-math UDFs whose kernels expect +/// `List` / `LargeList` (e.g. `array_add`, `cosine_distance`, +/// `inner_product`, `array_normalize`). +/// +/// Each input must be `Null`, `List`, `LargeList`, or `FixedSizeList`; otherwise +/// returns a plan error naming `name`. `FixedSizeList` is widened to `List`, +/// `Null` is coerced to a list of `Float64`, and if any input is `LargeList` +/// the rest are widened to `LargeList` so the runtime sees a homogeneous pair. +pub(crate) fn coerce_array_math_arg_types( + name: &str, + arg_types: &[DataType], +) -> Result> { + use DataType::{FixedSizeList, LargeList, List, Null}; + use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only}; + + let coercion = Some(&ListCoercion::FixedSizedListToList); + + for arg_type in arg_types { + if !matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) { + return plan_err!("{name} does not support type {arg_type}"); + } + } + + // If any input is `LargeList`, every input must be widened to `LargeList` + // so each caller's runtime dispatch (e.g. `array_add_inner`) sees matching + // list types. Follows the pattern in `ArrayConcat::coerce_types`. 
+ let any_large_list = arg_types.iter().any(|t| matches!(t, LargeList(_))); + + let coerced = arg_types + .iter() + .map(|arg_type| { + if matches!(arg_type, Null) { + let field = Arc::new(Field::new_list_field(DataType::Float64, true)); + return if any_large_list { + LargeList(field) + } else { + List(field) + }; + } + let coerced = + coerced_type_with_base_type_only(arg_type, &DataType::Float64, coercion); + match coerced { + List(field) if any_large_list => LargeList(field), + other => other, + } + }) + .collect(); + + Ok(coerced) +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/sqllogictest/test_files/array_add.slt b/datafusion/sqllogictest/test_files/array_add.slt new file mode 100644 index 0000000000000..e13f6acd269cb --- /dev/null +++ b/datafusion/sqllogictest/test_files/array_add.slt @@ -0,0 +1,237 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +## array_add + +# Basic element-wise sum +query ? +select array_add([1.0, 2.0, 3.0], [10.0, 20.0, 30.0]); +---- +[11.0, 22.0, 33.0] + +# Negative components +query ? +select array_add([1.0, -2.0, 3.0], [-1.0, 2.0, -3.0]); +---- +[0.0, 0.0, 0.0] + +# Single-element arrays +query ? 
+select array_add([5.0], [7.0]); +---- +[12.0] + +# Bare NULL on left -> NULL row +query ? +select array_add(NULL, [1.0, 2.0]); +---- +NULL + +# Bare NULL on right -> NULL row +query ? +select array_add([1.0, 2.0], NULL); +---- +NULL + +# Both bare NULL -> NULL row +query ? +select array_add(NULL, NULL); +---- +NULL + +# NULL element on left propagates to that position only +query ? +select array_add([1.0, NULL, 3.0], [10.0, 20.0, 30.0]); +---- +[11.0, NULL, 33.0] + +# NULL element on right propagates to that position only +query ? +select array_add([1.0, 2.0, 3.0], [10.0, NULL, 30.0]); +---- +[11.0, NULL, 33.0] + +# NULL element on both sides at the same position +query ? +select array_add([1.0, NULL, 3.0], [10.0, NULL, 30.0]); +---- +[11.0, NULL, 33.0] + +# NULL elements at different positions both propagate +query ? +select array_add([1.0, NULL, 3.0], [NULL, 20.0, 30.0]); +---- +[NULL, NULL, 33.0] + +# Length mismatch is an exec error +query error array_add requires both list inputs to have the same length per row +select array_add([1.0, 2.0], [10.0, 20.0, 30.0]); + +# Empty arrays on both sides return empty array +query ? +select array_add(arrow_cast(make_array(), 'List(Float64)'), arrow_cast(make_array(), 'List(Float64)')); +---- +[] + +# Integer literals coerced to Float64 +query ? +select array_add([1, 2, 3], [10, 20, 30]); +---- +[11.0, 22.0, 33.0] + +# Mixed int + float literals coerced to Float64 +query ? +select array_add([1, 2, 3], [0.5, 0.5, 0.5]); +---- +[1.5, 2.5, 3.5] + +# LargeList input on both sides +query ? +select array_add( + arrow_cast([1.0, 2.0, 3.0], 'LargeList(Float64)'), + arrow_cast([10.0, 20.0, 30.0], 'LargeList(Float64)') +); +---- +[11.0, 22.0, 33.0] + +# Mixed List + LargeList -> both widened to LargeList +query ? +select array_add( + [1.0, 2.0, 3.0], + arrow_cast([10.0, 20.0, 30.0], 'LargeList(Float64)') +); +---- +[11.0, 22.0, 33.0] + +# FixedSizeList input (coerced to List) +query ? 
+select array_add( + arrow_cast([1.0, 2.0, 3.0], 'FixedSizeList(3, Float64)'), + arrow_cast([10.0, 20.0, 30.0], 'FixedSizeList(3, Float64)') +); +---- +[11.0, 22.0, 33.0] + +# Float32 inner type on one side +query ? +select array_add( + arrow_cast([1.0, 2.0, 3.0], 'List(Float32)'), + [10.0, 20.0, 30.0] +); +---- +[11.0, 22.0, 33.0] + +# Int64 inner type +query ? +select array_add( + arrow_cast([1, 2, 3], 'List(Int64)'), + arrow_cast([10, 20, 30], 'List(Int64)') +); +---- +[11.0, 22.0, 33.0] + +# Unsupported non-list input (plan error) +query error array_add does not support type +select array_add(1, [1.0, 2.0]); + +# Wrong arg count +query error array_add function requires 2 arguments, got 0 +select array_add(); + +query error array_add function requires 2 arguments, got 1 +select array_add([1.0, 2.0]); + +# Return type matches input variant +query ?T +select array_add([1.0, 2.0], [3.0, 4.0]), arrow_typeof(array_add([1.0, 2.0], [3.0, 4.0])); +---- +[4.0, 6.0] List(Float64) + +# Multi-row query: normal row, NULL-left row, NULL-right row, element-NULL row +query ? +select array_add(a, b) from (values + (make_array(1.0, 2.0, 3.0), make_array(10.0, 20.0, 30.0)), + (NULL, make_array(1.0, 2.0, 3.0)), + (make_array(1.0, 2.0, 3.0), NULL), + (make_array(1.0, NULL, 3.0), make_array(10.0, 20.0, 30.0)) +) as t(a, b); +---- +[11.0, 22.0, 33.0] +NULL +NULL +[11.0, NULL, 33.0] + +# list_add alias +query ? +select list_add([1.0, 2.0], [3.0, 4.0]); +---- +[4.0, 6.0] + +# list_add alias multi-row +query ? +select list_add(a, b) from (values + (make_array(1.0, 2.0), make_array(10.0, 20.0)), + (NULL, make_array(1.0, 2.0)) +) as t(a, b); +---- +[11.0, 22.0] +NULL + +# Decimal element types are coerced to Float64 (lossy) like other array-math UDFs +query ? +select array_add( + arrow_cast([1, 2, 3], 'List(Decimal128(10, 2))'), + arrow_cast([10, 20, 30], 'List(Decimal128(10, 2))') +); +---- +[11.0, 22.0, 33.0] + +# Explicit cast to DOUBLE works as the documented opt-in +query ? 
+select array_add( + arrow_cast(arrow_cast([1, 2, 3], 'List(Decimal128(10, 2))'), 'List(Float64)'), + [10.0, 20.0, 30.0] +); +---- +[11.0, 22.0, 33.0] + +# Chained array_add: result of inner call feeds the outer call +query ? +select array_add(array_add([1.0, 2.0, 3.0], [10.0, 20.0, 30.0]), [100.0, 200.0, 300.0]); +---- +[111.0, 222.0, 333.0] + +# Chained array_add propagates element-level NULLs through both layers +query ? +select array_add( + array_add([1.0, NULL, 3.0], [10.0, 20.0, 30.0]), + [100.0, 200.0, NULL] +); +---- +[111.0, NULL, NULL] + +# Chained array_add over multiple rows +query ? +select array_add(array_add(a, b), c) from (values + (make_array(1.0, 2.0), make_array(10.0, 20.0), make_array(100.0, 200.0)), + (NULL, make_array(1.0, 2.0), make_array(3.0, 4.0)), + (make_array(1.0, 2.0), make_array(10.0, NULL), make_array(100.0, 200.0)) +) as t(a, b, c); +---- +[111.0, 222.0] +NULL +[111.0, NULL] \ No newline at end of file diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 6bf61391eb10e..94f73f42ded65 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -3244,6 +3244,7 @@ _Alias of [current_date](#current_date)._ ## Array Functions - [any_match](#any_match) +- [array_add](#array_add) - [array_any_match](#array_any_match) - [array_any_value](#array_any_value) - [array_append](#array_append) @@ -3300,6 +3301,7 @@ _Alias of [current_date](#current_date)._ - [flatten](#flatten) - [generate_series](#generate_series) - [inner_product](#inner_product) +- [list_add](#list_add) - [list_any_match](#list_any_match) - [list_any_value](#list_any_value) - [list_append](#list_append) @@ -3357,6 +3359,34 @@ _Alias of [current_date](#current_date)._ _Alias of [array_any_match](#array_any_match)._ +### `array_add` + +Returns the element-wise sum of two numeric arrays of equal length, computed as `array1[i] + array2[i]` per position. 
NULL is propagated per element: if either input element at position `i` is NULL, the corresponding output element is NULL (positions are preserved). Returns NULL if either entire input array is NULL. Errors if the per-row lengths differ. Returns an empty array if both inputs are empty. + +```sql +array_add(array1, array2) +``` + +#### Arguments + +- **array1**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **array2**: Array expression. Can be a constant, column, or function, and any combination of array operators. + +#### Example + +```sql +> select array_add([1.0, 2.0, 3.0], [10.0, 20.0, 30.0]); ++---------------------------------------------------------+ +| array_add(List([1.0,2.0,3.0]),List([10.0,20.0,30.0])) | ++---------------------------------------------------------+ +| [11.0, 22.0, 33.0] | ++---------------------------------------------------------+ +``` + +#### Aliases + +- list_add + ### `array_any_match` Returns whether any elements of an array match the given predicate. Returns true if one or more elements match, false if none match (including empty arrays), and null if the predicate returns null for some elements and false for all others. @@ -4745,6 +4775,10 @@ inner_product(array1, array2) - dot_product +### `list_add` + +_Alias of [array_add](#array_add)._ + ### `list_any_match` _Alias of [array_any_match](#array_any_match)._