diff --git a/Cargo.lock b/Cargo.lock
index 502f4772e..1e65aac2f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1224,9 +1224,11 @@ dependencies = [
  "datafusion",
  "datafusion-ext-commons",
  "itertools 0.14.0",
+ "jni",
  "log",
  "num",
  "paste",
+ "regex",
  "serde_json",
  "sonic-rs",
 ]
@@ -3495,9 +3497,9 @@ dependencies = [
 
 [[package]]
 name = "regex"
-version = "1.11.1"
+version = "1.12.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191"
+checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276"
 dependencies = [
  "aho-corasick",
  "memchr",
@@ -3507,9 +3509,9 @@ dependencies = [
 
 [[package]]
 name = "regex-automata"
-version = "0.4.9"
+version = "0.4.14"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908"
+checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f"
 dependencies = [
  "aho-corasick",
  "memchr",
diff --git a/auron-core/src/main/java/org/apache/auron/jni/JniBridge.java b/auron-core/src/main/java/org/apache/auron/jni/JniBridge.java
index d08536087..48f10ab01 100644
--- a/auron-core/src/main/java/org/apache/auron/jni/JniBridge.java
+++ b/auron-core/src/main/java/org/apache/auron/jni/JniBridge.java
@@ -23,6 +23,8 @@
 import java.nio.ByteBuffer;
 import java.util.List;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.regex.Pattern;
+import java.util.regex.PatternSyntaxException;
 import org.apache.auron.configuration.AuronConfiguration;
 import org.apache.auron.configuration.ConfigOption;
 import org.apache.auron.functions.AuronUDFWrapperContext;
@@ -39,6 +41,7 @@
 @SuppressWarnings("unused")
 public class JniBridge {
     private static final ConcurrentHashMap<String, Object> resourcesMap = new ConcurrentHashMap<>();
+    private static final ConcurrentHashMap<String, Pattern> regexCache = new ConcurrentHashMap<>();
     private static final List<BufferPoolMXBean> directMXBeans =
             ManagementFactory.getPlatformMXBeans(BufferPoolMXBean.class);
@@ -134,6 +137,38 @@ public static String getEngineName() {
         return AuronAdaptor.getInstance().getEngineName();
     }
 
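+    /**
+     * Splits {@code text} into a flattened array [key0, value0, key1, value1, ...] using the
+     * two Java regex delimiters. For example, strToMapSplit("a:1,b:2", ",", ":") returns
+     * {"a", "1", "b", "2"}; an entry without a key/value delimiter gets a null value slot.
+     */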
+    public static String[] strToMapSplit(String text, String pairDelim, String keyValueDelim) {
+        Pattern pairPattern = getCachedPattern(pairDelim, "pairDelim");
+        Pattern keyValuePattern = getCachedPattern(keyValueDelim, "keyValueDelim");
+
+        String[] entries = pairPattern.split(text, -1);
+        String[] flattened = new String[entries.length * 2];
+        for (int i = 0; i < entries.length; i++) {
+            String[] kv = keyValuePattern.split(entries[i], 2);
+            flattened[i * 2] = kv[0];
+            flattened[i * 2 + 1] = kv.length > 1 ? kv[1] : null;
+        }
+        return flattened;
+    }
+
+    private static Pattern getCachedPattern(String pattern, String argName) {
+        Pattern cached = regexCache.get(pattern);
+        if (cached != null) {
+            return cached;
+        }
+
+        final Pattern compiled;
+        try {
+            compiled = Pattern.compile(pattern);
+        } catch (PatternSyntaxException e) {
+            throw new RuntimeException(
+                    "str_to_map " + argName + " arg must be a valid Java regex: " + e.getMessage(), e);
+        }
+
+        Pattern existing = regexCache.putIfAbsent(pattern, compiled);
+        return existing != null ? existing : compiled;
+    }
+
     static <T> T getConfValue(String confKey) {
         Class<?> confClass = AuronAdaptor.getInstance().getAuronConfiguration().getClass();
diff --git a/native-engine/auron-jni-bridge/src/jni_bridge.rs b/native-engine/auron-jni-bridge/src/jni_bridge.rs
index 85b7598d3..aef3b3053 100644
--- a/native-engine/auron-jni-bridge/src/jni_bridge.rs
+++ b/native-engine/auron-jni-bridge/src/jni_bridge.rs
@@ -642,6 +642,8 @@ pub struct JniBridge<'a> {
     pub method_booleanConf_ret: ReturnType,
     pub method_stringConf: JStaticMethodID,
     pub method_stringConf_ret: ReturnType,
+    pub method_strToMapSplit: JStaticMethodID,
+    pub method_strToMapSplit_ret: ReturnType,
     pub method_getEngineName: JStaticMethodID,
     pub method_getEngineName_ret: ReturnType,
 }
@@ -757,6 +759,12 @@ impl<'a> JniBridge<'a> {
                 "(Ljava/lang/String;)Ljava/lang/String;",
             )?,
             method_stringConf_ret: ReturnType::Object,
+            method_strToMapSplit: env.get_static_method_id(
+                class,
+                "strToMapSplit",
+                "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)[Ljava/lang/String;",
+            )?,
+            method_strToMapSplit_ret: ReturnType::Object,
             method_getEngineName: env.get_static_method_id(
                 class,
                 "getEngineName",
diff --git a/native-engine/datafusion-ext-functions/Cargo.toml b/native-engine/datafusion-ext-functions/Cargo.toml
index a5eb5ded5..ad3af0f48 100644
--- a/native-engine/datafusion-ext-functions/Cargo.toml
+++ b/native-engine/datafusion-ext-functions/Cargo.toml
@@ -29,11 +29,13 @@ arrow = { workspace = true }
 auron-jni-bridge = { workspace = true }
 datafusion = { workspace = true }
 datafusion-ext-commons = { workspace = true }
+jni = { workspace = true }
 itertools = { workspace = true }
 log = { workspace = true }
 num = { workspace = true }
 paste = { workspace = true }
+regex = "1.12.3"
 serde_json = { workspace = true }
 sonic-rs = { workspace = true }
 chrono = "0.4.44"
diff --git a/native-engine/datafusion-ext-functions/src/lib.rs b/native-engine/datafusion-ext-functions/src/lib.rs
index ae80ef2df..7eb1f63a6 100644
--- a/native-engine/datafusion-ext-functions/src/lib.rs
+++ b/native-engine/datafusion-ext-functions/src/lib.rs
@@ -67,6 +67,7 @@ pub fn create_auron_ext_function(
         "Spark_MapConcat" => Arc::new(spark_map::map_concat),
         "Spark_MapFromArrays" => Arc::new(spark_map::map_from_arrays),
         "Spark_MapFromEntries" => Arc::new(spark_map::map_from_entries),
+        "Spark_StrToMap" => Arc::new(spark_map::str_to_map),
         "Spark_StringSpace" => Arc::new(spark_strings::string_space),
         "Spark_StringRepeat" => Arc::new(spark_strings::string_repeat),
         "Spark_StringSplit" => Arc::new(spark_strings::string_split),
diff --git a/native-engine/datafusion-ext-functions/src/spark_map.rs b/native-engine/datafusion-ext-functions/src/spark_map.rs
index aa3cbda16..b078bbba6 100644
--- a/native-engine/datafusion-ext-functions/src/spark_map.rs
+++ b/native-engine/datafusion-ext-functions/src/spark_map.rs
@@ -19,17 +19,26 @@ use std::{
 };
 
 use arrow::{
-    array::{Array, ArrayRef, ListArray, MapArray, StructArray, new_empty_array},
+    array::{Array, ArrayRef, ListArray, MapArray, StringArray, StructArray, new_empty_array},
     buffer::{NullBuffer, OffsetBuffer, ScalarBuffer},
     datatypes::{DataType, Field, Fields},
 };
+use auron_jni_bridge::{
+    is_jni_bridge_inited, jni_call_static, jni_get_string, jni_map_error_with_env, jni_new_string,
+};
 use datafusion::{
-    common::{Result, ScalarValue},
+    common::{Result, ScalarValue, cast::as_string_array},
     logical_expr::ColumnarValue,
 };
 use datafusion_ext_commons::{
     df_execution_err, downcast_any, scalar_value::compacted_scalar_value_from_array,
 };
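+// The JNI bridge supplies Java-regex splitting when running inside the JVM; the
+// `regex` crate backs the pure-Rust fallback used when the bridge is not
+// initialized (e.g. in unit tests).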
+use jni::{
+    JNIEnv,
+    objects::{JObject, JString},
+    sys::{jarray, jobjectArray},
+};
+use regex::Regex;
 
 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
 enum MapKeyDedupPolicy {
@@ -269,6 +278,271 @@ fn parse_map_key_dedup_policy(args: &[ColumnarValue], idx: usize) -> Result<MapKeyDedupPolicy> {
+fn get_or_compile_test_regex(
+    cache: &mut HashMap<String, Regex>,
+    pattern: &str,
+    arg_name: &str,
+) -> Result<Regex> {
+    if let Some(regex) = cache.get(pattern) {
+        return Ok(regex.clone());
+    }
+
+    let regex = Regex::new(pattern).map_err(|err| {
+        datafusion::error::DataFusionError::Execution(format!(
+            "str_to_map {arg_name} arg must be a valid regex in unit-test fallback mode: {err}"
+        ))
+    })?;
+    cache.insert(pattern.to_owned(), regex.clone());
+    Ok(regex)
+}
+
+fn jobject_to_string(env: &JNIEnv<'_>, value: JObject<'_>) -> Result<String> {
+    let auto_local = env.auto_local(value);
+    let jstring: JString<'_> = auto_local.as_obj().into();
+    jni_get_string!(jstring)
+}
+
+fn jobject_to_optional_string(env: &JNIEnv<'_>, value: JObject<'_>) -> Result<Option<String>> {
+    if value.is_null() {
+        Ok(None)
+    } else {
+        jobject_to_string(env, value).map(Some)
+    }
+}
+
+fn java_str_to_map_split(
+    text: &str,
+    pair_delim: &str,
+    key_value_delim: &str,
+) -> Result<Vec<(String, Option<String>)>> {
+    let text = jni_new_string!(text)?;
+    let pair_delim = jni_new_string!(pair_delim)?;
+    let key_value_delim = jni_new_string!(key_value_delim)?;
+    let flattened = jni_call_static!(
+        JniBridge.strToMapSplit(text.as_obj(), pair_delim.as_obj(), key_value_delim.as_obj())
+            -> JObject
+    )?;
+
+    auron_jni_bridge::jni_bridge::THREAD_JNIENV.with(|env| {
+        let flattened: jobjectArray = flattened.as_obj().into_raw() as jobjectArray;
+        let len = jni_map_error_with_env!(env, env.get_array_length(flattened as jarray))? as usize;
+        if len % 2 != 0 {
+            return df_execution_err!(
+                "str_to_map internal error: Java split returned an odd number of fields",
+            );
+        }
+
+        let mut out = Vec::with_capacity(len / 2);
+        for idx in (0..len).step_by(2) {
+            let key_obj =
+                jni_map_error_with_env!(env, env.get_object_array_element(flattened, idx as i32))?;
+            let value_obj = jni_map_error_with_env!(
+                env,
+                env.get_object_array_element(flattened, (idx + 1) as i32)
+            )?;
+            let key = jobject_to_string(env, key_obj)?;
+            let value = jobject_to_optional_string(env, value_obj)?;
+            out.push((key, value));
+        }
+        Ok(out)
+    })
+}
+
+fn fallback_str_to_map_split(
+    text: &str,
+    pair_delim: &str,
+    key_value_delim: &str,
+    pair_regex_cache: &mut HashMap<String, Regex>,
+    key_value_regex_cache: &mut HashMap<String, Regex>,
+) -> Result<Vec<(String, Option<String>)>> {
+    let pair_regex = get_or_compile_test_regex(pair_regex_cache, pair_delim, "pairDelim")?;
+    let key_value_regex =
+        get_or_compile_test_regex(key_value_regex_cache, key_value_delim, "keyValueDelim")?;
+
+    Ok(pair_regex
+        .split(text)
+        .map(|kv_entry| {
+            let mut kv_parts = key_value_regex.splitn(kv_entry, 2);
+            let key = kv_parts.next().unwrap_or_default().to_owned();
+            let value = kv_parts.next().map(ToOwned::to_owned);
+            (key, value)
+        })
+        .collect())
+}
+
+fn str_to_map_split(
+    text: &str,
+    pair_delim: &str,
+    key_value_delim: &str,
+    pair_regex_cache: &mut HashMap<String, Regex>,
+    key_value_regex_cache: &mut HashMap<String, Regex>,
+) -> Result<Vec<(String, Option<String>)>> {
+    if is_jni_bridge_inited() {
+        java_str_to_map_split(text, pair_delim, key_value_delim)
+    } else {
+        fallback_str_to_map_split(
+            text,
+            pair_delim,
+            key_value_delim,
+            pair_regex_cache,
+            key_value_regex_cache,
+        )
+    }
+}
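+
+// On either path, str_to_map_split("a:1,b", ",", ":") yields
+// [("a", Some("1")), ("b", None)]: pairs are split without a length limit,
+// key/value with limit 2, and a missing value becomes None.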
+
+fn columnar_value_to_string_array(
+    arg: &ColumnarValue,
+    len: usize,
+    arg_name: &str,
+) -> Result<StringArray> {
+    let array = arg.clone().into_array(len)?;
+    match array.data_type() {
+        DataType::Null => Ok(StringArray::from(vec![None::<&str>; array.len()])),
+        DataType::Utf8 => Ok(as_string_array(&array)?.clone()),
+        data_type => {
+            df_execution_err!("str_to_map {arg_name} arg must be string, found {data_type:?}")
+        }
+    }
+}
+
+/// Creates a map after splitting text into key/value pairs using Java regex
+/// delimiters.
+///
+/// This follows Spark StringToMap semantics:
+/// - null in any argument => null result
+/// - pairDelim is applied as Pattern.compile(pairDelim).split(text, -1)
+/// - keyValueDelim is applied as Pattern.compile(keyValueDelim).split(entry, 2)
+/// - missing value => null
+/// - duplicate keys follow spark.sql.mapKeyDedupPolicy
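+///
+/// Examples (illustrative, mirroring the tests below):
+/// - str_to_map('a:1,b:2', ',', ':')       => {"a" -> "1", "b" -> "2"}
+/// - str_to_map('a:1:2,b', ',', ':')       => {"a" -> "1:2", "b" -> null}
+/// - str_to_map('a::1,,b:::2', ',+', ':+') => {"a" -> "1", "b" -> "2"}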
+pub fn str_to_map(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    if args.len() < 3 || args.len() > 4 {
+        return df_execution_err!("str_to_map requires 3 or 4 arguments");
+    }
+
+    let dedup_policy = parse_map_key_dedup_policy(args, 3)?;
+    let num_rows = args
+        .iter()
+        .filter_map(|arg| match arg {
+            ColumnarValue::Array(array) => Some(array.len()),
+            ColumnarValue::Scalar(_) => None,
+        })
+        .filter(|&len| len != 1)
+        .max()
+        .unwrap_or(1);
+
+    if args.iter().any(|arg| match arg {
+        ColumnarValue::Array(array) => array.len() != 1 && array.len() != num_rows,
+        ColumnarValue::Scalar(_) => false,
+    }) {
+        return df_execution_err!("all arguments of str_to_map must have the same length");
+    }
+
+    let text_array = columnar_value_to_string_array(&args[0], num_rows, "text")?;
+    let pair_delim_array = columnar_value_to_string_array(&args[1], num_rows, "pairDelim")?;
+    let key_value_delim_array =
+        columnar_value_to_string_array(&args[2], num_rows, "keyValueDelim")?;
+
+    let key_field = Arc::new(Field::new("key", DataType::Utf8, false));
+    let value_field = Arc::new(Field::new("value", DataType::Utf8, true));
+    let entries_field = Arc::new(Field::new(
+        "entries",
+        DataType::Struct(Fields::from(vec![
+            key_field.as_ref().clone(),
+            value_field.as_ref().clone(),
+        ])),
+        false,
+    ));
+
+    let mut pair_regex_cache = HashMap::new();
+    let mut key_value_regex_cache = HashMap::new();
+
+    let mut all_keys = Vec::new();
+    let mut all_values = Vec::new();
+    let mut offsets = Vec::with_capacity(num_rows + 1);
+    let mut valids = Vec::with_capacity(num_rows);
+    let mut next_offset = 0i32;
+
+    offsets.push(next_offset);
+
+    for row_idx in 0..num_rows {
+        if text_array.is_null(row_idx)
+            || pair_delim_array.is_null(row_idx)
+            || key_value_delim_array.is_null(row_idx)
+        {
+            valids.push(false);
+            offsets.push(next_offset);
+            continue;
+        }
+
+        let text = text_array.value(row_idx);
+        let pair_delim = pair_delim_array.value(row_idx);
+        let key_value_delim = key_value_delim_array.value(row_idx);
+
+        let split_entries = str_to_map_split(
+            text,
+            pair_delim,
+            key_value_delim,
+            &mut pair_regex_cache,
+            &mut key_value_regex_cache,
+        )?;
+        let mut row_entries: Vec<(String, Option<String>)> =
+            Vec::with_capacity(split_entries.len());
+        let mut row_key_to_index: HashMap<String, usize> = HashMap::new();
+
+        for (key, value) in split_entries {
+            if let Some(idx) = row_key_to_index.get(&key).copied() {
+                match dedup_policy {
+                    MapKeyDedupPolicy::Exception => {
+                        return df_execution_err!("str_to_map duplicate key found: {key}");
+                    }
+                    MapKeyDedupPolicy::LastWin => {
+                        row_entries[idx].1 = value;
+                    }
+                }
+            } else {
+                row_key_to_index.insert(key.clone(), row_entries.len());
+                row_entries.push((key, value));
+            }
+        }
+
+        valids.push(true);
+        next_offset += row_entries.len() as i32;
+        offsets.push(next_offset);
+
+        for (key, value) in row_entries {
+            all_keys.push(ScalarValue::Utf8(Some(key)));
+            all_values.push(ScalarValue::Utf8(value));
+        }
+    }
+
+    let keys = if all_keys.is_empty() {
+        new_empty_array(key_field.data_type())
+    } else {
+        ScalarValue::iter_to_array(all_keys.into_iter())?
+    };
+
+    let values = if all_values.is_empty() {
+        new_empty_array(value_field.data_type())
+    } else {
+        ScalarValue::iter_to_array(all_values.into_iter())?
+    };
+
+    let entries = StructArray::from(vec![(key_field, keys), (value_field, values)]);
+    let nulls = if valids.iter().all(|valid| *valid) {
+        None
+    } else {
+        Some(NullBuffer::from(valids))
+    };
+
+    Ok(ColumnarValue::Array(Arc::new(MapArray::new(
+        entries_field,
+        OffsetBuffer::new(ScalarBuffer::from(offsets)),
+        entries,
+        nulls,
+        false,
+    ))))
+}
+
 /// Returns a map created from the given array of entries.
 ///
 /// This follows Spark semantics:
@@ -1142,4 +1416,101 @@ mod test {
         assert_eq!(&actual, &expected);
         Ok(())
     }
+
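+    // Note: these tests run without an initialized JNI bridge, so they exercise
+    // the pure-Rust regex fallback rather than the Java regex path.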
+    #[test]
+    fn test_str_to_map() -> Result<()> {
+        let text = Arc::new(StringArray::from(vec![
+            Some("a:1,b:2"),
+            Some("a:1:2,b"),
+            None,
+        ])) as ArrayRef;
+
+        let actual = str_to_map(&[
+            ColumnarValue::Array(text),
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string()))),
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some(":".to_string()))),
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some("EXCEPTION".to_string()))),
+        ])?
+        .into_array(3)?;
+
+        let expected = Arc::new(build_string_string_map_array(vec![
+            Some(vec![("a", Some("1")), ("b", Some("2"))]),
+            Some(vec![("a", Some("1:2")), ("b", None)]),
+            None,
+        ])) as ArrayRef;
+
+        assert_eq!(&actual, &expected);
+        Ok(())
+    }
+
+    #[test]
+    fn test_str_to_map_regex_delims() -> Result<()> {
+        let text = Arc::new(StringArray::from(vec![Some("a::1,,b:::2")])) as ArrayRef;
+
+        let actual = str_to_map(&[
+            ColumnarValue::Array(text),
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some(",+".to_string()))),
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some(":+".to_string()))),
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some("EXCEPTION".to_string()))),
+        ])?
+        .into_array(1)?;
+
+        let expected = Arc::new(build_string_string_map_array(vec![Some(vec![
+            ("a", Some("1")),
+            ("b", Some("2")),
+        ])])) as ArrayRef;
+
+        assert_eq!(&actual, &expected);
+        Ok(())
+    }
+
+    #[test]
+    fn test_str_to_map_null_scalar_propagation() -> Result<()> {
+        let actual = str_to_map(&[
+            ColumnarValue::Scalar(ScalarValue::Utf8(None)),
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string()))),
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some(":".to_string()))),
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some("EXCEPTION".to_string()))),
+        ])?
+        .into_array(1)?;
+
+        assert!(actual.is_null(0));
+        Ok(())
+    }
+
+    #[test]
+    fn test_str_to_map_duplicate_keys() {
+        let text = Arc::new(StringArray::from(vec![Some("a:1,a:2")])) as ArrayRef;
+
+        let err = str_to_map(&[
+            ColumnarValue::Array(text),
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string()))),
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some(":".to_string()))),
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some("EXCEPTION".to_string()))),
+        ])
+        .expect_err("str_to_map should fail when duplicate keys exist");
+
+        assert!(err.to_string().contains("duplicate key"));
+    }
+
+    #[test]
+    fn test_str_to_map_last_win() -> Result<()> {
+        let text = Arc::new(StringArray::from(vec![Some("a:1,b:2,a:3")])) as ArrayRef;
+
+        let actual = str_to_map(&[
+            ColumnarValue::Array(text),
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string()))),
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some(":".to_string()))),
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some("LAST_WIN".to_string()))),
+        ])?
+        .into_array(1)?;
+
+        let expected = Arc::new(build_string_string_map_array(vec![Some(vec![
+            ("a", Some("3")),
+            ("b", Some("2")),
+        ])])) as ArrayRef;
+
+        assert_eq!(&actual, &expected);
+        Ok(())
+    }
 }
diff --git a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala
index 8dadd4d56..a4ba7f55a 100644
--- a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala
+++ b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala
@@ -461,6 +461,67 @@ class AuronFunctionSuite extends AuronQueryTest with BaseAuronSQLSuite {
     }
   }
 
+  test("str_to_map function") {
+    withTable("t1") {
+      sql("create table t1(c1 string) using parquet")
+      sql("""
+          |insert into t1 values
+          |  ('a:1,b:2'),
+          |  ('a:1:2,b'),
+          |  (null)
+          |""".stripMargin)
+      checkSparkAnswerAndOperator("select str_to_map(c1) from t1")
+    }
+  }
+
+  test("str_to_map regex delimiters") {
+    withTable("t1") {
+      sql("create table t1(c1 string) using parquet")
+      sql("insert into t1 values ('a::1,,b:::2')")
+      checkSparkAnswerAndOperator("select str_to_map(c1, ',+', ':+') from t1")
+    }
+  }
+
+  test("str_to_map Java regex delimiters") {
+    withTable("t1") {
+      sql("create table t1(c1 string) using parquet")
+      sql("insert into t1 values ('a:1,b:2,c:3')")
+      checkSparkAnswerAndOperator("select str_to_map(c1, ',(?=b:|c:)', ':') from t1")
+    }
+  }
+
+  test("str_to_map duplicate keys") {
+    withTable("t1") {
+      sql("create table t1(c1 string) using parquet")
+      sql("insert into t1 values ('a:1,a:2')")
+      val df = sql("select str_to_map(c1) from t1")
+      val err = intercept[Exception] {
+        df.collect()
+      }
+      val plan = stripAQEPlan(df.queryExecution.executedPlan)
+      plan
+        .collectFirst { case op if !isNativeOrPassThrough(op) => op }
+        .foreach { op =>
+          fail(s"""
+               |Found non-native operator: ${op.nodeName}
+               |plan:
+               |${plan}""".stripMargin)
+        }
+      assert(allCauseMessages(err).toLowerCase.contains("duplicate key"))
+    }
+  }
+
+  test("str_to_map last win dedup policy") {
+    withTable("t1") {
+      sql("create table t1(c1 string) using parquet")
+      sql("insert into t1 values ('a:1,b:2,a:3')")
+      withSQLConf(
+        SQLConf.MAP_KEY_DEDUP_POLICY.key -> SQLConf.MapKeyDedupPolicy.LAST_WIN.toString) {
+        checkSparkAnswerAndOperator("select str_to_map(c1) from t1")
+      }
+    }
+  }
+
   test("acosh null propagation") {
     withTable("t1") {
       sql("create table t1(c1 double) using parquet")
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
index 750aaa524..94a428b35 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
@@ -945,6 +945,14 @@ object NativeConverters extends Logging {
       buildExtScalarFunction("Spark_XxHash64", children, LongType)
     case e: MapFromArrays =>
       buildExtScalarFunction("Spark_MapFromArrays", e.children, e.dataType)
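+    // Pass the session's MAP_KEY_DEDUP_POLICY to the native side as a literal
+    // fourth argument so duplicate-key handling matches Spark exactly.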
+    case e: StringToMap =>
+      buildExtScalarFunction(
+        "Spark_StrToMap",
+        e.text :: e.pairDelim :: e.keyValueDelim :: Literal
+          .create(
+            SQLConf.get.getConf(SQLConf.MAP_KEY_DEDUP_POLICY).toString,
+            StringType) :: Nil,
+        e.dataType)
     case e: Greatest =>
       buildScalarFunction(pb.ScalarFunction.Greatest, e.children, e.dataType)
     case e: Pow =>