diff --git a/datafusion/spark/src/function/array/shuffle.rs b/datafusion/spark/src/function/array/shuffle.rs index abeafd3a9366..9f345b53b89a 100644 --- a/datafusion/spark/src/function/array/shuffle.rs +++ b/datafusion/spark/src/function/array/shuffle.rs @@ -15,21 +15,25 @@ // specific language governing permissions and limitations // under the License. -use crate::function::functions_nested_utils::make_scalar_function; use arrow::array::{ Array, ArrayRef, Capacities, FixedSizeListArray, GenericListArray, MutableArrayData, OffsetSizeTrait, }; use arrow::buffer::OffsetBuffer; +use arrow::datatypes::DataType; use arrow::datatypes::DataType::{FixedSizeList, LargeList, List, Null}; -use arrow::datatypes::{DataType, FieldRef}; +use arrow::datatypes::FieldRef; use datafusion_common::cast::{ as_fixed_size_list_array, as_large_list_array, as_list_array, }; -use datafusion_common::{exec_err, utils::take_function_args, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_common::{exec_err, utils::take_function_args, Result, ScalarValue}; +use datafusion_expr::{ + ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ScalarUDFImpl, + Signature, TypeSignature, Volatility, +}; use rand::rng; -use rand::seq::SliceRandom; +use rand::rngs::StdRng; +use rand::{seq::SliceRandom, Rng, SeedableRng}; use std::any::Any; use std::sync::Arc; @@ -47,7 +51,25 @@ impl Default for SparkShuffle { impl SparkShuffle { pub fn new() -> Self { Self { - signature: Signature::arrays(1, None, Volatility::Volatile), + signature: Signature { + type_signature: TypeSignature::OneOf(vec![ + // Only array argument + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array], + array_coercion: None, + }), + // Array + Index (seed) argument + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Index, + ], + array_coercion: None, + }), + ]), + volatility: Volatility::Volatile, + parameter_names: None, + }, } } } @@ -73,25 +95,63 @@ impl ScalarUDFImpl for SparkShuffle { &self, args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_shuffle_inner)(&args.args) + if args.args.is_empty() { + return exec_err!("shuffle expects at least 1 argument"); + } + if args.args.len() > 2 { + return exec_err!("shuffle expects at most 2 arguments"); + } + + // Extract seed from second argument if present + let seed = if args.args.len() == 2 { + extract_seed(&args.args[1])? + } else { + None + }; + + // Convert arguments to arrays + let arrays = ColumnarValue::values_to_arrays(&args.args[..1])?; + array_shuffle_with_seed(&arrays, seed).map(ColumnarValue::Array) + } +} + +/// Extract seed value from ColumnarValue +fn extract_seed(seed_arg: &ColumnarValue) -> Result> { + match seed_arg { + ColumnarValue::Scalar(scalar) => { + let seed = match scalar { + ScalarValue::Int64(Some(v)) => Some(*v as u64), + ScalarValue::Null => None, + _ => { + return exec_err!( + "shuffle seed must be Int64 type, got '{}'", + scalar.data_type() + ); + } + }; + Ok(seed) + } + ColumnarValue::Array(_) => { + exec_err!("shuffle seed must be a scalar value, not an array") + } } } -/// array_shuffle SQL function -pub fn array_shuffle_inner(arg: &[ArrayRef]) -> Result { +/// array_shuffle SQL function with optional seed +fn array_shuffle_with_seed(arg: &[ArrayRef], seed: Option) -> Result { let [input_array] = take_function_args("shuffle", arg)?; match &input_array.data_type() { List(field) => { let array = as_list_array(input_array)?; - general_array_shuffle::(array, field) + general_array_shuffle::(array, field, seed) } LargeList(field) => { let array = as_large_list_array(input_array)?; - general_array_shuffle::(array, field) + general_array_shuffle::(array, field, seed) } FixedSizeList(field, _) => { let array = as_fixed_size_list_array(input_array)?; - fixed_size_array_shuffle(array, field) + fixed_size_array_shuffle(array, field, seed) } Null => Ok(Arc::clone(input_array)), array_type => exec_err!("shuffle does not support type '{array_type}'."), @@ -101,6 +161,7 @@ pub fn array_shuffle_inner(arg: &[ArrayRef]) -> Result { fn general_array_shuffle( array: &GenericListArray, field: &FieldRef, + seed: Option, ) -> Result { let values = array.values(); let original_data = values.to_data(); @@ -109,7 +170,13 @@ fn general_array_shuffle( let mut nulls = vec![]; let mut mutable = MutableArrayData::with_capacities(vec![&original_data], false, capacity); - let mut rng = rng(); + let mut rng = if let Some(s) = seed { + StdRng::seed_from_u64(s) + } else { + // Use a random seed from the thread-local RNG + let seed = rng().random::(); + StdRng::seed_from_u64(seed) + }; for (row_index, offset_window) in array.offsets().windows(2).enumerate() { // skip the null value @@ -149,6 +216,7 @@ fn general_array_shuffle( fn fixed_size_array_shuffle( array: &FixedSizeListArray, field: &FieldRef, + seed: Option, ) -> Result { let values = array.values(); let original_data = values.to_data(); @@ -157,7 +225,13 @@ fn fixed_size_array_shuffle( let mut mutable = MutableArrayData::with_capacities(vec![&original_data], false, capacity); let value_length = array.value_length() as usize; - let mut rng = rng(); + let mut rng = if let Some(s) = seed { + StdRng::seed_from_u64(s) + } else { + // Use a random seed from the thread-local RNG + let seed = rng().random::(); + StdRng::seed_from_u64(seed) + }; for row_index in 0..array.len() { // skip the null value diff --git a/datafusion/sqllogictest/test_files/spark/array/shuffle.slt b/datafusion/sqllogictest/test_files/spark/array/shuffle.slt index 7614caef666b..35aad58144c9 100644 --- a/datafusion/sqllogictest/test_files/spark/array/shuffle.slt +++ b/datafusion/sqllogictest/test_files/spark/array/shuffle.slt @@ -16,27 +16,16 @@ # under the License. # Test shuffle function with simple arrays -query B -SELECT array_sort(shuffle([1, 2, 3, 4, 5, NULL])) = [NULL,1, 2, 3, 4, 5]; ----- -true - -query B -SELECT shuffle([1, 2, 3, 4, 5, NULL]) != [1, 2, 3, 4, 5, NULL]; +query ? +SELECT shuffle([1, 2, 3, 4, 5, NULL], 1); ---- -true +[1, 4, NULL, 2, 5, 3] # Test shuffle function with string arrays - -query B -SELECT array_sort(shuffle(['a', 'b', 'c', 'd', 'e', 'f'])) = ['a', 'b', 'c', 'd', 'e', 'f']; ----- -true - -query B -SELECT shuffle(['a', 'b', 'c', 'd', 'e', 'f']) != ['a', 'b', 'c', 'd', 'e', 'f'];; +query ? +SELECT shuffle(['a', 'b', 'c', 'd', 'e', 'f'], 1); ---- -true +[a, d, f, b, e, c] # Test shuffle function with empty array query ? @@ -57,15 +46,10 @@ SELECT shuffle(NULL); NULL # Test shuffle function with fixed size list arrays -query B -SELECT array_sort(shuffle(arrow_cast([1, 2, NULL, 3, 4, 5], 'FixedSizeList(6, Int64)'))) = [NULL, 1, 2, 3, 4, 5]; ----- -true - -query B -SELECT shuffle(arrow_cast([1, 2, NULL, 3, 4, 5], 'FixedSizeList(6, Int64)')) != [1, 2, NULL, 3, 4, 5]; +query ? +SELECT shuffle(arrow_cast([1, 2, NULL, 3, 4, 5], 'FixedSizeList(6, Int64)'), 1); ---- -true +[1, 3, 5, 2, 4, NULL] # Test shuffle on table data with different list types statement ok @@ -78,10 +62,10 @@ CREATE TABLE test_shuffle_list_types AS VALUES # Test shuffle with large list from table query ? -SELECT array_sort(shuffle(column1)) FROM test_shuffle_list_types; +SELECT shuffle(column1, 1) FROM test_shuffle_list_types; ---- -[1, 2, 3, 4] -[5, 6, 7, 8, 9] +[1, 4, 3, 2] +[8, 9, 6, 5, 7] [10] NULL [] @@ -96,11 +80,11 @@ CREATE TABLE test_shuffle_fixed_size AS VALUES # Test shuffle with fixed size list from table query ? -SELECT array_sort(shuffle(column1)) FROM test_shuffle_fixed_size; +SELECT shuffle(column1, 1) FROM test_shuffle_fixed_size; ---- [1, 2, 3] -[4, 5, 6] -[NULL, 8, 9] +[4, 6, 5] +[9, NULL, 8] NULL # Clean up