diff --git a/datafusion/spark/src/function/map/map_from_entries.rs b/datafusion/spark/src/function/map/map_from_entries.rs
new file mode 100644
index 000000000000..6648979c5dd2
--- /dev/null
+++ b/datafusion/spark/src/function/map/map_from_entries.rs
@@ -0,0 +1,133 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::any::Any;
+
+use crate::function::map::utils::{
+    get_element_type, get_list_offsets, get_list_values,
+    map_from_keys_values_offsets_nulls, map_type_from_key_value_types,
+};
+use arrow::array::{Array, ArrayRef, NullBufferBuilder, StructArray};
+use arrow::buffer::NullBuffer;
+use arrow::datatypes::DataType;
+use datafusion_common::utils::take_function_args;
+use datafusion_common::{exec_err, Result};
+use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use datafusion_functions::utils::make_scalar_function;
+
+/// Spark-compatible `map_from_entries` expression
+/// <https://spark.apache.org/docs/latest/api/sql/index.html#map_from_entries>
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct MapFromEntries {
+    signature: Signature,
+}
+
+impl Default for MapFromEntries {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl MapFromEntries {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::array(Volatility::Immutable),
+        }
+    }
+}
+
+impl ScalarUDFImpl for MapFromEntries {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "map_from_entries"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        let [entries_type] = take_function_args("map_from_entries", arg_types)?;
+        let entries_element_type = get_element_type(entries_type)?;
+        let (keys_type, values_type) = match entries_element_type {
+            DataType::Struct(fields) if fields.len() == 2 => {
+                Ok((fields[0].data_type(), fields[1].data_type()))
+            }
+            wrong_type => exec_err!(
+                "map_from_entries: expected array<struct<key, value>>, got {:?}",
+                wrong_type
+            ),
+        }?;
+        Ok(map_type_from_key_value_types(keys_type, values_type))
+    }
+
+    fn invoke_with_args(
+        &self,
+        args: datafusion_expr::ScalarFunctionArgs,
+    ) -> Result<ColumnarValue> {
+        make_scalar_function(map_from_entries_inner, vec![])(&args.args)
+    }
+}
+
+fn map_from_entries_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
+    let [entries] = take_function_args("map_from_entries", args)?;
+    let entries_offsets = get_list_offsets(entries)?;
+    let entries_values = get_list_values(entries)?;
+
+    let (flat_keys, flat_values) =
+        match entries_values.as_any().downcast_ref::<StructArray>() {
+            Some(a) => Ok((a.column(0), a.column(1))),
+            None => exec_err!(
+                "map_from_entries: expected array<struct<key, value>>, got {:?}",
+                entries_values.data_type()
+            ),
+        }?;
+
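+    // Per-row null propagation: the resulting map for a row is valid only if
+    // none of the entry structs in that row's offset window is NULL; otherwise
+    // the whole row becomes NULL (see the slt tests below for this behavior).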
+    let entries_with_nulls = entries_values.nulls().and_then(|entries_inner_nulls| {
+        let mut builder = NullBufferBuilder::new_with_len(0);
+        let mut cur_offset = entries_offsets
+            .first()
+            .map(|offset| *offset as usize)
+            .unwrap_or(0);
+
+        for next_offset in entries_offsets.iter().skip(1) {
+            let num_entries = *next_offset as usize - cur_offset;
+            builder.append(
+                entries_inner_nulls
+                    .slice(cur_offset, num_entries)
+                    .null_count()
+                    == 0,
+            );
+            cur_offset = *next_offset as usize;
+        }
+        builder.finish()
+    });
+
+    let res_nulls = NullBuffer::union(entries.nulls(), entries_with_nulls.as_ref());
+
+    map_from_keys_values_offsets_nulls(
+        flat_keys,
+        flat_values,
+        &entries_offsets,
+        &entries_offsets,
+        None,
+        res_nulls.as_ref(),
+    )
+}
diff --git a/datafusion/spark/src/function/map/mod.rs b/datafusion/spark/src/function/map/mod.rs
index 21d1e0f108c0..2f596b19b422 100644
--- a/datafusion/spark/src/function/map/mod.rs
+++ b/datafusion/spark/src/function/map/mod.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 pub mod map_from_arrays;
+pub mod map_from_entries;
 mod utils;
 
 use datafusion_expr::ScalarUDF;
@@ -23,6 +24,7 @@ use datafusion_functions::make_udf_function;
 use std::sync::Arc;
 
 make_udf_function!(map_from_arrays::MapFromArrays, map_from_arrays);
+make_udf_function!(map_from_entries::MapFromEntries, map_from_entries);
 
 pub mod expr_fn {
     use datafusion_functions::export_functions;
@@ -32,8 +34,14 @@ pub mod expr_fn {
         "Creates a map from arrays of keys and values.",
         keys values
     ));
+
+    export_functions!((
+        map_from_entries,
+        "Creates a map from array<struct<key, value>>.",
+        arg1
+    ));
 }
 
 pub fn functions() -> Vec<Arc<ScalarUDF>> {
-    vec![map_from_arrays()]
+    vec![map_from_arrays(), map_from_entries()]
 }
diff --git a/datafusion/spark/src/function/map/utils.rs b/datafusion/spark/src/function/map/utils.rs
index fa4fc5fae44d..b568f45403c3 100644
--- a/datafusion/spark/src/function/map/utils.rs
+++ b/datafusion/spark/src/function/map/utils.rs
@@ -157,8 +157,15 @@ fn map_deduplicate_keys(
     let offsets_len = keys_offsets.len();
     let mut new_offsets = Vec::with_capacity(offsets_len);
 
-    let mut cur_keys_offset = 0;
-    let mut cur_values_offset = 0;
+    let mut cur_keys_offset = keys_offsets
+        .first()
+        .map(|offset| *offset as usize)
+        .unwrap_or(0);
+    let mut cur_values_offset = values_offsets
+        .first()
+        .map(|offset| *offset as usize)
+        .unwrap_or(0);
+
     let mut new_last_offset = 0;
     new_offsets.push(new_last_offset);
 
@@ -176,36 +183,38 @@
             let mut keys_mask_one = [false].repeat(num_keys_entries);
             let mut values_mask_one = [false].repeat(num_values_entries);
 
-            if num_keys_entries != num_values_entries {
-                let key_is_valid = keys_nulls.is_none_or(|buf| buf.is_valid(row_idx));
-                let value_is_valid = values_nulls.is_none_or(|buf| buf.is_valid(row_idx));
-                if key_is_valid && value_is_valid {
+            let key_is_valid = keys_nulls.is_none_or(|buf| buf.is_valid(row_idx));
+            let value_is_valid = values_nulls.is_none_or(|buf| buf.is_valid(row_idx));
+
+            if key_is_valid && value_is_valid {
+                if num_keys_entries != num_values_entries {
                     return exec_err!("map_deduplicate_keys: keys and values lists in the same row must have equal lengths");
+                } else if num_keys_entries != 0 {
+                    let mut seen_keys = HashSet::new();
+
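+                    // Walk the row's entries in reverse so the last occurrence
+                    // of each key is the one that survives (LAST_WIN).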
+                    for cur_entry_idx in (0..num_keys_entries).rev() {
+                        let key = ScalarValue::try_from_array(
+                            &flat_keys,
+                            cur_keys_offset + cur_entry_idx,
+                        )?
+                        .compacted();
+                        if seen_keys.contains(&key) {
+                            // TODO: implement configuration and logic for spark.sql.mapKeyDedupPolicy=EXCEPTION (this is default spark-config)
+                            // exec_err!("invalid argument: duplicate keys in map")
+                            // https://github.com/apache/spark/blob/cf3a34e19dfcf70e2d679217ff1ba21302212472/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala#L4961
+                        } else {
+                            // This code implements deduplication logic for spark.sql.mapKeyDedupPolicy=LAST_WIN (this is NOT default spark-config)
+                            keys_mask_one[cur_entry_idx] = true;
+                            values_mask_one[cur_entry_idx] = true;
+                            seen_keys.insert(key);
+                            new_last_offset += 1;
+                        }
+                    }
+                }
                 }
-                // else the result entry is NULL
+            } else {
+                // the result entry is NULL
                 // both current row offsets are skipped
                 // keys or values in the current row are marked false in the masks
-            } else if num_keys_entries != 0 {
-                let mut seen_keys = HashSet::new();
-
-                for cur_entry_idx in (0..num_keys_entries).rev() {
-                    let key = ScalarValue::try_from_array(
-                        &flat_keys,
-                        cur_keys_offset + cur_entry_idx,
-                    )?
-                    .compacted();
-                    if seen_keys.contains(&key) {
-                        // TODO: implement configuration and logic for spark.sql.mapKeyDedupPolicy=EXCEPTION (this is default spark-config)
-                        // exec_err!("invalid argument: duplicate keys in map")
-                        // https://github.com/apache/spark/blob/cf3a34e19dfcf70e2d679217ff1ba21302212472/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala#L4961
-                    } else {
-                        // This code implements deduplication logic for spark.sql.mapKeyDedupPolicy=LAST_WIN (this is NOT default spark-config)
-                        keys_mask_one[cur_entry_idx] = true;
-                        values_mask_one[cur_entry_idx] = true;
-                        seen_keys.insert(key);
-                        new_last_offset += 1;
-                    }
-                }
             }
             keys_mask_builder.append_array(&keys_mask_one.into());
             values_mask_builder.append_array(&values_mask_one.into());
diff --git a/datafusion/sqllogictest/test_files/spark/map/map_from_entries.slt b/datafusion/sqllogictest/test_files/spark/map/map_from_entries.slt
new file mode 100644
index 000000000000..19b46886a027
--- /dev/null
+++ b/datafusion/sqllogictest/test_files/spark/map/map_from_entries.slt
@@ -0,0 +1,164 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Spark doctests
+query ?
+SELECT map_from_entries(array[struct(1, 'a'), struct(2, 'b')]);
+----
+{1: a, 2: b}
+
+query ?
+SELECT map_from_entries(array[struct(1, cast(null as string)), struct(2, 'b')]);
+----
+{1: NULL, 2: b}
+
+query ?
+SELECT map_from_entries(data)
+from values
+    (array[struct(1, 'a'), struct(2, 'b')]),
+    (array[struct(3, 'c')])
+as tab(data);
+----
+{1: a, 2: b}
+{3: c}
+
+# Tests with NULL and empty input struct arrays
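+# An empty entries array yields an empty map; a NULL input array yields NULL.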
+query ?
+SELECT map_from_entries(data)
+from values
+    (cast(array[] as array<struct<a int, b string>>)),
+    (cast(NULL as array<struct<a int, b string>>))
+as tab(data);
+----
+{}
+NULL
+
+# Test with NULL key, should fail
+query error DataFusion error: Arrow error: Invalid argument error: Found unmasked nulls for non-nullable StructArray field "key"
+SELECT map_from_entries(array[struct(NULL, 1)]);
+
+# Tests with NULL and array of Null type, should fail
+query error DataFusion error: Execution error: map_from_entries: expected array<struct<key, value>>, got Null
+SELECT map_from_entries(NULL);
+
+query error DataFusion error: Execution error: map_from_entries: expected array<struct<key, value>>, got Null
+SELECT map_from_entries(array[NULL]);
+
+# Test with NULL array and NULL entries in arrays
+# output is NULL if any entry is NULL
+query ?
+SELECT map_from_entries(data)
+from values
+    (
+        array[
+            struct(1 as a, 'a' as b),
+            cast(NULL as struct<a int, b string>),
+            cast(NULL as struct<a int, b string>)
+        ]
+    ),
+    (NULL),
+    (
+        array[
+            struct(2 as a, 'b' as b),
+            struct(3 as a, 'c' as b)
+        ]
+    ),
+    (
+        array[
+            struct(4 as a, 'd' as b),
+            cast(NULL as struct<a int, b string>),
+            struct(5 as a, 'e' as b),
+            struct(6 as a, 'f' as b)
+        ]
+    )
+as tab(data);
+----
+NULL
+NULL
+{2: b, 3: c}
+NULL
+
+# Test with multiple rows: good, empty and nullable
+query ?
+SELECT map_from_entries(data)
+from values
+    (NULL),
+    (array[
+        struct(1 as a, 'b' as b),
+        struct(2 as a, cast(NULL as string) as b),
+        struct(3 as a, 'd' as b)
+    ]),
+    (array[]),
+    (NULL)
+as tab(data);
+----
+NULL
+{1: b, 2: NULL, 3: d}
+{}
+NULL
+
+# Test with complex types
+query ?
+SELECT map_from_entries(array[
+    struct(array('a', 'b'), struct(1, 2, 3)),
+    struct(array('c', 'd'), struct(4, 5, 6))
+]);
+----
+{[a, b]: {c0: 1, c1: 2, c2: 3}, [c, d]: {c0: 4, c1: 5, c2: 6}}
+
+# Test with nested function calls
+query ?
+SELECT
+    map_from_entries(
+        array[
+            struct(
+                'outer_key1',
+                -- value for outer_key1: a map itself
+                map_from_entries(
+                    array[
+                        struct('inner_a', 1),
+                        struct('inner_b', 2)
+                    ]
+                )
+            ),
+            struct(
+                'outer_key2',
+                -- value for outer_key2: another map
+                map_from_entries(
+                    array[
+                        struct('inner_x', 10),
+                        struct('inner_y', 20),
+                        struct('inner_z', 30)
+                    ]
+                )
+            )
+        ]
+    ) AS nested_map;
+----
+{outer_key1: {inner_a: 1, inner_b: 2}, outer_key2: {inner_x: 10, inner_y: 20, inner_z: 30}}
+
+# Test with duplicate keys
+query ?
+SELECT map_from_entries(array(
+    struct(true, 'a'),
+    struct(false, 'b'),
+    struct(true, 'c'),
+    struct(false, cast(NULL as string)),
+    struct(true, 'd')
+));
+----
+{false: NULL, true: d}
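+
+# Under the LAST_WIN dedup policy implemented above, each key keeps the value
+# of its last occurrence: false -> NULL, true -> 'd'.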