diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index ca3ca18e4d77..0d2c1f2e3cb7 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -176,6 +176,8 @@ pub enum BuiltinScalarFunction { ArrayToString, /// array_intersect ArrayIntersect, + /// array_union + ArrayUnion, /// cardinality Cardinality, /// construct an array from columns @@ -401,6 +403,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArraySlice => Volatility::Immutable, BuiltinScalarFunction::ArrayToString => Volatility::Immutable, BuiltinScalarFunction::ArrayIntersect => Volatility::Immutable, + BuiltinScalarFunction::ArrayUnion => Volatility::Immutable, BuiltinScalarFunction::Cardinality => Volatility::Immutable, BuiltinScalarFunction::MakeArray => Volatility::Immutable, BuiltinScalarFunction::Ascii => Volatility::Immutable, @@ -581,6 +584,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArraySlice => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayToString => Ok(Utf8), BuiltinScalarFunction::ArrayIntersect => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::ArrayUnion => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::Cardinality => Ok(UInt64), BuiltinScalarFunction::MakeArray => match input_expr_types.len() { 0 => Ok(List(Arc::new(Field::new("item", Null, true)))), @@ -885,6 +889,7 @@ impl BuiltinScalarFunction { Signature::variadic_any(self.volatility()) } BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayUnion => Signature::any(2, self.volatility()), BuiltinScalarFunction::Cardinality => Signature::any(1, self.volatility()), BuiltinScalarFunction::MakeArray => { // 0 or more arguments of arbitrary type @@ -1508,6 +1513,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { "array_join", "list_join", ], + BuiltinScalarFunction::ArrayUnion => &["array_union", "list_union"], 
BuiltinScalarFunction::Cardinality => &["cardinality"], BuiltinScalarFunction::MakeArray => &["make_array", "make_list"], BuiltinScalarFunction::ArrayIntersect => &["array_intersect", "list_intersect"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 0e0ad46da101..0d920beb416f 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -717,6 +717,8 @@ scalar_expr!( array delimiter, "converts each element to its text representation." ); +scalar_expr!(ArrayUnion, array_union, array1 array2, "returns an array of the elements in the union of array1 and array2 without duplicates."); + scalar_expr!( Cardinality, cardinality, diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 54452e3653a8..20004440d7a7 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -27,6 +27,7 @@ use arrow::datatypes::{DataType, Field, UInt64Type}; use arrow::row::{RowConverter, SortField}; use arrow_buffer::NullBuffer; +use arrow_schema::FieldRef; use datafusion_common::cast::{ as_generic_string_array, as_int64_array, as_list_array, as_string_array, }; @@ -36,8 +37,8 @@ use datafusion_common::{ DataFusionError, Result, }; -use hashbrown::HashSet; use itertools::Itertools; +use std::collections::HashSet; macro_rules! downcast_arg { ($ARG:expr, $ARRAY_TYPE:ident) => {{ @@ -1340,6 +1341,86 @@ macro_rules! 
to_string { }}; } +fn union_generic_lists( + l: &GenericListArray, + r: &GenericListArray, + field: &FieldRef, +) -> Result> { + let converter = RowConverter::new(vec![SortField::new(l.value_type().clone())])?; + + let nulls = NullBuffer::union(l.nulls(), r.nulls()); + let l_values = l.values().clone(); + let r_values = r.values().clone(); + let l_values = converter.convert_columns(&[l_values])?; + let r_values = converter.convert_columns(&[r_values])?; + + // Might be worth adding an upstream OffsetBufferBuilder + let mut offsets = Vec::::with_capacity(l.len() + 1); + offsets.push(OffsetSize::usize_as(0)); + let mut rows = Vec::with_capacity(l_values.num_rows() + r_values.num_rows()); + let mut dedup = HashSet::new(); + for (l_w, r_w) in l.offsets().windows(2).zip(r.offsets().windows(2)) { + let l_slice = l_w[0].as_usize()..l_w[1].as_usize(); + let r_slice = r_w[0].as_usize()..r_w[1].as_usize(); + for i in l_slice { + let left_row = l_values.row(i); + if dedup.insert(left_row) { + rows.push(left_row); + } + } + for i in r_slice { + let right_row = r_values.row(i); + if dedup.insert(right_row) { + rows.push(right_row); + } + } + offsets.push(OffsetSize::usize_as(rows.len())); + dedup.clear(); + } + + let values = converter.convert_rows(rows)?; + let offsets = OffsetBuffer::new(offsets.into()); + let result = values[0].clone(); + Ok(GenericListArray::::new( + field.clone(), + offsets, + result, + nulls, + )) +} + +/// Array_union SQL function +pub fn array_union(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_union needs two arguments"); + } + let array1 = &args[0]; + let array2 = &args[1]; + match (array1.data_type(), array2.data_type()) { + (DataType::Null, _) => Ok(array2.clone()), + (_, DataType::Null) => Ok(array1.clone()), + (DataType::List(field_ref), DataType::List(_)) => { + check_datatypes("array_union", &[&array1, &array2])?; + let list1 = array1.as_list::(); + let list2 = array2.as_list::(); + let result = 
union_generic_lists::(list1, list2, field_ref)?; + Ok(Arc::new(result)) + } + (DataType::LargeList(field_ref), DataType::LargeList(_)) => { + check_datatypes("array_union", &[&array1, &array2])?; + let list1 = array1.as_list::(); + let list2 = array2.as_list::(); + let result = union_generic_lists::(list1, list2, field_ref)?; + Ok(Arc::new(result)) + } + _ => { + internal_err!( + "array_union only support list with offsets of type int32 and int64" + ) + } + } +} + /// Array_to_string SQL function pub fn array_to_string(args: &[ArrayRef]) -> Result { let arr = &args[0]; diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 9185ade313eb..80c0eaf054fd 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -407,7 +407,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::MakeArray => { Arc::new(|args| make_scalar_function(array_expressions::make_array)(args)) } - + BuiltinScalarFunction::ArrayUnion => { + Arc::new(|args| make_scalar_function(array_expressions::array_union)(args)) + } // struct functions BuiltinScalarFunction::Struct => Arc::new(struct_expressions::struct_expr), diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 62b226e33339..793378a1ea87 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -635,6 +635,7 @@ enum ScalarFunction { StringToArray = 117; ToTimestampNanos = 118; ArrayIntersect = 119; + ArrayUnion = 120; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 7602e1a36657..a78da2a51c9d 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -20908,6 +20908,7 @@ impl serde::Serialize for ScalarFunction { Self::StringToArray => "StringToArray", Self::ToTimestampNanos => "ToTimestampNanos", Self::ArrayIntersect => 
"ArrayIntersect", + Self::ArrayUnion => "ArrayUnion", }; serializer.serialize_str(variant) } @@ -21039,6 +21040,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "StringToArray", "ToTimestampNanos", "ArrayIntersect", + "ArrayUnion", ]; struct GeneratedVisitor; @@ -21199,6 +21201,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "StringToArray" => Ok(ScalarFunction::StringToArray), "ToTimestampNanos" => Ok(ScalarFunction::ToTimestampNanos), "ArrayIntersect" => Ok(ScalarFunction::ArrayIntersect), + "ArrayUnion" => Ok(ScalarFunction::ArrayUnion), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 825481a18822..7b7b0afb9216 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2562,6 +2562,7 @@ pub enum ScalarFunction { StringToArray = 117, ToTimestampNanos = 118, ArrayIntersect = 119, + ArrayUnion = 120, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2690,6 +2691,7 @@ impl ScalarFunction { ScalarFunction::StringToArray => "StringToArray", ScalarFunction::ToTimestampNanos => "ToTimestampNanos", ScalarFunction::ArrayIntersect => "ArrayIntersect", + ScalarFunction::ArrayUnion => "ArrayUnion", } } /// Creates an enum from field names used in the ProtoBuf definition. 
@@ -2815,6 +2817,7 @@ impl ScalarFunction { "StringToArray" => Some(Self::StringToArray), "ToTimestampNanos" => Some(Self::ToTimestampNanos), "ArrayIntersect" => Some(Self::ArrayIntersect), + "ArrayUnion" => Some(Self::ArrayUnion), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 674492edef43..f7e38757e923 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -484,6 +484,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ArraySlice => Self::ArraySlice, ScalarFunction::ArrayToString => Self::ArrayToString, ScalarFunction::ArrayIntersect => Self::ArrayIntersect, + ScalarFunction::ArrayUnion => Self::ArrayUnion, ScalarFunction::Cardinality => Self::Cardinality, ScalarFunction::Array => Self::MakeArray, ScalarFunction::NullIf => Self::NullIf, @@ -1424,6 +1425,12 @@ pub fn parse_expr( ScalarFunction::ArrayNdims => { Ok(array_ndims(parse_expr(&args[0], registry)?)) } + ScalarFunction::ArrayUnion => Ok(array( + args.to_owned() + .iter() + .map(|expr| parse_expr(expr, registry)) + .collect::, _>>()?, + )), ScalarFunction::Sqrt => Ok(sqrt(parse_expr(&args[0], registry)?)), ScalarFunction::Cbrt => Ok(cbrt(parse_expr(&args[0], registry)?)), ScalarFunction::Sin => Ok(sin(parse_expr(&args[0], registry)?)), diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 946f2c6964a5..2bb7f89c7d4d 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1487,6 +1487,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ArraySlice => Self::ArraySlice, BuiltinScalarFunction::ArrayToString => Self::ArrayToString, BuiltinScalarFunction::ArrayIntersect => Self::ArrayIntersect, + BuiltinScalarFunction::ArrayUnion => Self::ArrayUnion, 
BuiltinScalarFunction::Cardinality => Self::Cardinality, BuiltinScalarFunction::MakeArray => Self::Array, BuiltinScalarFunction::NullIf => Self::NullIf, diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index ad81f37e0764..96689cd7e13e 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1919,6 +1919,101 @@ select array_to_string(make_array(), ',') ---- (empty) + +## array_union (aliases: `list_union`) + +# array_union scalar function #1 +query ? +select array_union([1, 2, 3, 4], [5, 6, 3, 4]); +---- +[1, 2, 3, 4, 5, 6] + +# array_union scalar function #2 +query ? +select array_union([1, 2, 3, 4], [5, 6, 7, 8]); +---- +[1, 2, 3, 4, 5, 6, 7, 8] + +# array_union scalar function #3 +query ? +select array_union([1,2,3], []); +---- +[1, 2, 3] + +# array_union scalar function #4 +query ? +select array_union([1, 2, 3, 4], [5, 4]); +---- +[1, 2, 3, 4, 5] + +# array_union scalar function #5 +statement ok +CREATE TABLE arrays_with_repeating_elements_for_union +AS VALUES + ([1], [2]), + ([2, 3], [3]), + ([3], [3, 4]) +; + +query ? +select array_union(column1, column2) from arrays_with_repeating_elements_for_union; +---- +[1, 2] +[2, 3] +[3, 4] + +statement ok +drop table arrays_with_repeating_elements_for_union; + +# array_union scalar function #6 +query ? +select array_union([], []); +---- +NULL + +# array_union scalar function #7 +query ? +select array_union([[null]], []); +---- +[[]] + +# array_union scalar function #8 +query ? +select array_union([null], [null]); +---- +[] + +# array_union scalar function #9 +query ? +select array_union(null, []); +---- +NULL + +# array_union scalar function #10 +query ? +select array_union(null, null); +---- +NULL + +# array_union scalar function #11 +query ? +select array_union([1.2, 3.0], [1.2, 3.0, 5.7]); +---- +[1.2, 3.0, 5.7] + +# array_union scalar function #12 +query ? 
+select array_union(['hello'], ['hello','datafusion']); +---- +[hello, datafusion] + + + + + + + + # list_to_string scalar function #4 (function alias `array_to_string`) query TTT select list_to_string(['h', 'e', 'l', 'l', 'o'], ','), list_to_string([1, 2, 3, 4, 5], '-'), list_to_string([1.0, 2.0, 3.0], '|'); diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 27384dccffe0..bec3ba9bb28c 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -233,6 +233,7 @@ Unlike to some databases the math functions in Datafusion works the same way as | array_slice(array, index) | Returns a slice of the array. `array_slice([1, 2, 3, 4, 5, 6, 7, 8], 3, 6) -> [3, 4, 5, 6]` | | array_to_string(array, delimiter) | Converts each element to its text representation. `array_to_string([1, 2, 3, 4], ',') -> 1,2,3,4` | | array_intersect(array1, array2) | Returns an array of the elements in the intersection of array1 and array2. `array_intersect([1, 2, 3, 4], [5, 6, 3, 4]) -> [3, 4]` | +| array_union(array1, array2) | Returns an array of the elements in the union of array1 and array2 without duplicates. `array_union([1, 2, 3, 4], [5, 6, 3, 4]) -> [1, 2, 3, 4, 5, 6]` | | cardinality(array) | Returns the total number of elements in the array. `cardinality([[1, 2, 3], [4, 5, 6]]) -> 6` | | make_array(value1, [value2 [, ...]]) | Returns an Arrow array using the specified input expressions. 
`make_array(1, 2, 3) -> [1, 2, 3]` | | trim_array(array, n) | Deprecated | diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index be05084fb249..2959e8202437 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -2211,6 +2211,44 @@ array_to_string(array, delimiter) - list_join - list_to_string +### `array_union` + +Returns an array of the elements that are present in either array (all elements from both arrays) without duplicates. + +``` +array_union(array1, array2) +``` + +#### Arguments + +- **array1**: Array expression. + Can be a constant, column, or function, and any combination of array operators. +- **array2**: Array expression. + Can be a constant, column, or function, and any combination of array operators. + +#### Example + +``` +❯ select array_union([1, 2, 3, 4], [5, 6, 3, 4]); ++----------------------------------------------------+ +| array_union([1, 2, 3, 4], [5, 6, 3, 4]); | ++----------------------------------------------------+ +| [1, 2, 3, 4, 5, 6] | ++----------------------------------------------------+ +❯ select array_union([1, 2, 3, 4], [5, 6, 7, 8]); ++----------------------------------------------------+ +| array_union([1, 2, 3, 4], [5, 6, 7, 8]); | ++----------------------------------------------------+ +| [1, 2, 3, 4, 5, 6, 7, 8] | ++----------------------------------------------------+ +``` + +--- + +#### Aliases + +- list_union + ### `cardinality` Returns the total number of elements in the array.