Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 88 additions & 14 deletions datafusion/spark/src/function/array/shuffle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,25 @@
// specific language governing permissions and limitations
// under the License.

use crate::function::functions_nested_utils::make_scalar_function;
use arrow::array::{
Array, ArrayRef, Capacities, FixedSizeListArray, GenericListArray, MutableArrayData,
OffsetSizeTrait,
};
use arrow::buffer::OffsetBuffer;
use arrow::datatypes::DataType;
use arrow::datatypes::DataType::{FixedSizeList, LargeList, List, Null};
use arrow::datatypes::{DataType, FieldRef};
use arrow::datatypes::FieldRef;
use datafusion_common::cast::{
as_fixed_size_list_array, as_large_list_array, as_list_array,
};
use datafusion_common::{exec_err, utils::take_function_args, Result};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use datafusion_common::{exec_err, utils::take_function_args, Result, ScalarValue};
use datafusion_expr::{
ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ScalarUDFImpl,
Signature, TypeSignature, Volatility,
};
use rand::rng;
use rand::seq::SliceRandom;
use rand::rngs::StdRng;
use rand::{seq::SliceRandom, Rng, SeedableRng};
use std::any::Any;
use std::sync::Arc;

Expand All @@ -47,7 +51,25 @@ impl Default for SparkShuffle {
impl SparkShuffle {
pub fn new() -> Self {
Self {
signature: Signature::arrays(1, None, Volatility::Volatile),
signature: Signature {
type_signature: TypeSignature::OneOf(vec![
// Only array argument
TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
arguments: vec![ArrayFunctionArgument::Array],
array_coercion: None,
}),
// Array + Index (seed) argument
TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
arguments: vec![
ArrayFunctionArgument::Array,
ArrayFunctionArgument::Index,
],
array_coercion: None,
}),
]),
volatility: Volatility::Volatile,
parameter_names: None,
},
}
}
}
Expand All @@ -73,25 +95,63 @@ impl ScalarUDFImpl for SparkShuffle {
&self,
args: datafusion_expr::ScalarFunctionArgs,
) -> Result<ColumnarValue> {
make_scalar_function(array_shuffle_inner)(&args.args)
if args.args.is_empty() {
return exec_err!("shuffle expects at least 1 argument");
}
if args.args.len() > 2 {
return exec_err!("shuffle expects at most 2 arguments");
}

// Extract seed from second argument if present
let seed = if args.args.len() == 2 {
extract_seed(&args.args[1])?
} else {
None
};

// Convert arguments to arrays
let arrays = ColumnarValue::values_to_arrays(&args.args[..1])?;
array_shuffle_with_seed(&arrays, seed).map(ColumnarValue::Array)
}
}

/// Extract seed value from ColumnarValue
fn extract_seed(seed_arg: &ColumnarValue) -> Result<Option<u64>> {
match seed_arg {
ColumnarValue::Scalar(scalar) => {
let seed = match scalar {
ScalarValue::Int64(Some(v)) => Some(*v as u64),
ScalarValue::Null => None,
_ => {
return exec_err!(
"shuffle seed must be Int64 type, got '{}'",
scalar.data_type()
);
}
};
Ok(seed)
}
ColumnarValue::Array(_) => {
exec_err!("shuffle seed must be a scalar value, not an array")
}
}
}

/// array_shuffle SQL function
pub fn array_shuffle_inner(arg: &[ArrayRef]) -> Result<ArrayRef> {
/// array_shuffle SQL function with optional seed
fn array_shuffle_with_seed(arg: &[ArrayRef], seed: Option<u64>) -> Result<ArrayRef> {
let [input_array] = take_function_args("shuffle", arg)?;
match &input_array.data_type() {
List(field) => {
let array = as_list_array(input_array)?;
general_array_shuffle::<i32>(array, field)
general_array_shuffle::<i32>(array, field, seed)
}
LargeList(field) => {
let array = as_large_list_array(input_array)?;
general_array_shuffle::<i64>(array, field)
general_array_shuffle::<i64>(array, field, seed)
}
FixedSizeList(field, _) => {
let array = as_fixed_size_list_array(input_array)?;
fixed_size_array_shuffle(array, field)
fixed_size_array_shuffle(array, field, seed)
}
Null => Ok(Arc::clone(input_array)),
array_type => exec_err!("shuffle does not support type '{array_type}'."),
Expand All @@ -101,6 +161,7 @@ pub fn array_shuffle_inner(arg: &[ArrayRef]) -> Result<ArrayRef> {
fn general_array_shuffle<O: OffsetSizeTrait>(
array: &GenericListArray<O>,
field: &FieldRef,
seed: Option<u64>,
) -> Result<ArrayRef> {
let values = array.values();
let original_data = values.to_data();
Expand All @@ -109,7 +170,13 @@ fn general_array_shuffle<O: OffsetSizeTrait>(
let mut nulls = vec![];
let mut mutable =
MutableArrayData::with_capacities(vec![&original_data], false, capacity);
let mut rng = rng();
let mut rng = if let Some(s) = seed {
StdRng::seed_from_u64(s)
} else {
// Use a random seed from the thread-local RNG
let seed = rng().random::<u64>();
StdRng::seed_from_u64(seed)
};

for (row_index, offset_window) in array.offsets().windows(2).enumerate() {
// skip the null value
Expand Down Expand Up @@ -149,6 +216,7 @@ fn general_array_shuffle<O: OffsetSizeTrait>(
fn fixed_size_array_shuffle(
array: &FixedSizeListArray,
field: &FieldRef,
seed: Option<u64>,
) -> Result<ArrayRef> {
let values = array.values();
let original_data = values.to_data();
Expand All @@ -157,7 +225,13 @@ fn fixed_size_array_shuffle(
let mut mutable =
MutableArrayData::with_capacities(vec![&original_data], false, capacity);
let value_length = array.value_length() as usize;
let mut rng = rng();
let mut rng = if let Some(s) = seed {
StdRng::seed_from_u64(s)
} else {
// Use a random seed from the thread-local RNG
let seed = rng().random::<u64>();
StdRng::seed_from_u64(seed)
};

for row_index in 0..array.len() {
// skip the null value
Expand Down
46 changes: 15 additions & 31 deletions datafusion/sqllogictest/test_files/spark/array/shuffle.slt
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,16 @@
# under the License.

# Test shuffle function with simple arrays
query B
SELECT array_sort(shuffle([1, 2, 3, 4, 5, NULL])) = [NULL,1, 2, 3, 4, 5];
----
true

query B
SELECT shuffle([1, 2, 3, 4, 5, NULL]) != [1, 2, 3, 4, 5, NULL];
query ?
SELECT shuffle([1, 2, 3, 4, 5, NULL], 1);
----
true
[1, 4, NULL, 2, 5, 3]

# Test shuffle function with string arrays

query B
SELECT array_sort(shuffle(['a', 'b', 'c', 'd', 'e', 'f'])) = ['a', 'b', 'c', 'd', 'e', 'f'];
----
true

query B
SELECT shuffle(['a', 'b', 'c', 'd', 'e', 'f']) != ['a', 'b', 'c', 'd', 'e', 'f'];;
query ?
SELECT shuffle(['a', 'b', 'c', 'd', 'e', 'f'], 1);
----
true
[a, d, f, b, e, c]

# Test shuffle function with empty array
query ?
Expand All @@ -57,15 +46,10 @@ SELECT shuffle(NULL);
NULL

# Test shuffle function with fixed size list arrays
query B
SELECT array_sort(shuffle(arrow_cast([1, 2, NULL, 3, 4, 5], 'FixedSizeList(6, Int64)'))) = [NULL, 1, 2, 3, 4, 5];
----
true

query B
SELECT shuffle(arrow_cast([1, 2, NULL, 3, 4, 5], 'FixedSizeList(6, Int64)')) != [1, 2, NULL, 3, 4, 5];
query ?
SELECT shuffle(arrow_cast([1, 2, NULL, 3, 4, 5], 'FixedSizeList(6, Int64)'), 1);
----
true
[1, 3, 5, 2, 4, NULL]

# Test shuffle on table data with different list types
statement ok
Expand All @@ -78,10 +62,10 @@ CREATE TABLE test_shuffle_list_types AS VALUES

# Test shuffle with large list from table
query ?
SELECT array_sort(shuffle(column1)) FROM test_shuffle_list_types;
SELECT shuffle(column1, 1) FROM test_shuffle_list_types;
----
[1, 2, 3, 4]
[5, 6, 7, 8, 9]
[1, 4, 3, 2]
[8, 9, 6, 5, 7]
[10]
NULL
[]
Expand All @@ -96,11 +80,11 @@ CREATE TABLE test_shuffle_fixed_size AS VALUES

# Test shuffle with fixed size list from table
query ?
SELECT array_sort(shuffle(column1)) FROM test_shuffle_fixed_size;
SELECT shuffle(column1, 1) FROM test_shuffle_fixed_size;
----
[1, 2, 3]
[4, 5, 6]
[NULL, 8, 9]
[4, 6, 5]
[9, NULL, 8]
NULL

# Clean up
Expand Down