Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 191 additions & 0 deletions datafusion/spark/src/function/array/array_contains.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::{Array, AsArray, BooleanArray, BooleanBufferBuilder};
use arrow::buffer::{BooleanBuffer, NullBuffer};
use arrow::datatypes::DataType;
use datafusion_common::Result;
use datafusion_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_functions_nested::array_has::array_has_udf;
use std::any::Any;
use std::sync::Arc;

/// Spark-compatible `array_contains` function.
///
/// Calls DataFusion's `array_has` and then applies Spark's null semantics:
/// - If the result from `array_has` is `true`, return `true`.
/// - If the result is `false` and the input array row contains any null elements,
/// return `null` (because the element might have been the null).
/// - If the result is `false` and the input array row has no null elements,
/// return `false`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkArrayContains {
    // `(array, element)` signature with immutable volatility; built once in
    // `new()` and returned by reference from `ScalarUDFImpl::signature`.
    signature: Signature,
}

impl Default for SparkArrayContains {
fn default() -> Self {
Self::new()
}
}

impl SparkArrayContains {
    /// Build the UDF with the standard `(array, element)` signature.
    pub fn new() -> Self {
        let signature = Signature::array_and_element(Volatility::Immutable);
        Self { signature }
    }
}

impl ScalarUDFImpl for SparkArrayContains {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "array_contains"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
        Ok(DataType::Boolean)
    }

    /// Delegates membership testing to DataFusion's `array_has`, then patches
    /// the result to Spark's three-valued null semantics (see type docs).
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        // Keep a handle on the array argument so the per-row null elements can
        // be inspected after `args` is consumed by `array_has`.
        let haystack = args.args[0].clone();
        // Capture the batch row count before `args` is moved: if `array_has`
        // returns a `ColumnarValue::Scalar` (e.g. both inputs are constants),
        // materializing it with size 1 would produce an array shorter than
        // the batch. `to_array(number_rows)` broadcasts scalars correctly and
        // is a no-op for arrays.
        let num_rows = args.number_rows;
        let array_has_result = array_has_udf().invoke_with_args(args)?;

        let result_array = array_has_result.to_array(num_rows)?;
        let patched = apply_spark_null_semantics(result_array.as_boolean(), &haystack)?;
        Ok(ColumnarValue::Array(Arc::new(patched)))
    }
}

/// For each row where `array_has` returned `false`, set the output to null
/// if that row's input array contains any null elements.
///
/// This turns DataFusion's two-valued `array_has` answer into Spark's
/// three-valued `array_contains`: a `false` is only definite when the row's
/// list has no null elements — otherwise the sought element may have been one
/// of the nulls, so the answer becomes null.
fn apply_spark_null_semantics(
    result: &BooleanArray,
    haystack_arg: &ColumnarValue,
) -> Result<BooleanArray> {
    // Normalize the haystack to an array so scalar and columnar inputs are
    // handled uniformly; a scalar is broadcast to one row per result row.
    let haystack = match haystack_arg {
        ColumnarValue::Array(arr) => Arc::clone(arr),
        ColumnarValue::Scalar(s) => s.to_array_of_size(result.len())?,
    };

    // An untyped NULL literal has no element nulls to inspect; `array_has`
    // already produced the appropriate result for it.
    if haystack.data_type() == &DataType::Null {
        return Ok(result.clone());
    }

    // If every result is already true or null, nothing to nullify.
    // (`false_count` counts only valid `false` values.)
    if result.false_count() == 0 {
        return Ok(result.clone());
    }

    // Build a per-row bitmap: true if that row's list contains any null element.
    // Works directly on the values' null bitmap + offsets, no per-row array allocation.
    // Each arm first bails out cheaply when the flat values have no nulls at
    // all, so the bitmap is only built when some row can actually be nullified.
    let row_has_nulls = match haystack.data_type() {
        DataType::List(_) => {
            let list = haystack.as_list::<i32>();
            if list.values().null_count() == 0 {
                return Ok(result.clone());
            }
            // `unwrap` is safe: null_count() > 0 implies a null buffer exists.
            build_row_has_nulls(
                list.values().nulls().unwrap().inner(),
                list.offsets().iter().map(|o| *o as usize),
                list.len(),
                list.nulls(),
            )
        }
        DataType::LargeList(_) => {
            let list = haystack.as_list::<i64>();
            if list.values().null_count() == 0 {
                return Ok(result.clone());
            }
            build_row_has_nulls(
                list.values().nulls().unwrap().inner(),
                list.offsets().iter().map(|o| *o as usize),
                list.len(),
                list.nulls(),
            )
        }
        DataType::FixedSizeList(_, _) => {
            let list = haystack.as_fixed_size_list();
            if list.values().null_count() == 0 {
                return Ok(result.clone());
            }
            let vl = list.value_length() as usize;
            // Synthesize offsets 0, vl, 2*vl, … since fixed-size lists carry
            // no offset buffer.
            // NOTE(review): assumes `values()` of a (possibly sliced)
            // FixedSizeListArray starts at this array's row 0 — confirm
            // against the arrow-rs version in use.
            build_row_has_nulls(
                list.values().nulls().unwrap().inner(),
                (0..=list.len()).map(|i| i * vl),
                list.len(),
                list.nulls(),
            )
        }
        // Non-list haystacks (shouldn't occur given the signature): pass
        // `array_has`'s result through unchanged.
        _ => return Ok(result.clone()),
    };

    // A row should be nullified when: result is false AND row has nulls.
    // nullify_mask = !result_values & row_has_nulls
    // new_validity = old_validity & !nullify_mask
    let nullify_mask = &(!result.values()) & &row_has_nulls;
    let old_validity = match result.nulls() {
        Some(n) => n.inner().clone(),
        None => BooleanBuffer::new_set(result.len()),
    };
    let new_validity = &old_validity & &(!&nullify_mask);

    // Values are kept as-is; nullified rows are hidden by validity alone.
    Ok(BooleanArray::new(
        result.values().clone(),
        Some(NullBuffer::new(new_validity)),
    ))
}

/// Build a BooleanBuffer where bit `i` is set if list row `i` contains
/// any null element, using the flat values null bitmap and row offsets.
/// Build a BooleanBuffer where bit `i` is set if list row `i` contains
/// any null element, using the flat values null bitmap and row offsets.
fn build_row_has_nulls<I>(
    values_validity: &BooleanBuffer,
    offsets: I,
    num_rows: usize,
    list_nulls: Option<&NullBuffer>,
) -> BooleanBuffer
where
    I: Iterator<Item = usize>,
{
    // Materialize the row boundaries so every consecutive pair forms a
    // [start, end) window over the flat values bitmap.
    let bounds: Vec<usize> = offsets.collect();

    let mut bits = BooleanBufferBuilder::new(num_rows);
    for pair in bounds.windows(2) {
        let (start, end) = (pair[0], pair[1]);
        let row_len = end - start;
        // A row holds a null iff its validity slice has fewer set bits
        // (valid elements) than the row has elements.
        let valid_count = values_validity.slice(start, row_len).count_set_bits();
        bits.append(valid_count < row_len);
    }

    let row_bits = bits.finish();
    // Rows where the list itself is null must not be reported as "has nulls".
    if let Some(validity) = list_nulls {
        &row_bits & validity.inner()
    } else {
        row_bits
    }
}
15 changes: 14 additions & 1 deletion datafusion/spark/src/function/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

pub mod array_contains;
pub mod repeat;
pub mod shuffle;
pub mod slice;
Expand All @@ -24,6 +25,7 @@ use datafusion_expr::ScalarUDF;
use datafusion_functions::make_udf_function;
use std::sync::Arc;

make_udf_function!(array_contains::SparkArrayContains, spark_array_contains);
make_udf_function!(spark_array::SparkArray, array);
make_udf_function!(shuffle::SparkShuffle, shuffle);
make_udf_function!(repeat::SparkArrayRepeat, array_repeat);
Expand All @@ -32,6 +34,11 @@ make_udf_function!(slice::SparkSlice, slice);
pub mod expr_fn {
use datafusion_functions::export_functions;

export_functions!((
spark_array_contains,
"Returns true if the array contains the element (Spark semantics).",
array element
));
export_functions!((array, "Returns an array with the given elements.", args));
export_functions!((
shuffle,
Expand All @@ -51,5 +58,11 @@ pub mod expr_fn {
}

pub fn functions() -> Vec<Arc<ScalarUDF>> {
vec![array(), shuffle(), array_repeat(), slice()]
vec![
spark_array_contains(),
array(),
shuffle(),
array_repeat(),
slice(),
]
}
140 changes: 140 additions & 0 deletions datafusion/sqllogictest/test_files/spark/array/array_contains.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Tests for Spark-compatible array_contains function.
# Spark semantics: if element is found -> true; if not found and array has nulls -> null; if not found and no nulls -> false.

###
### Scalar tests
###

# Element found in array
query B
SELECT array_contains(array(1, 2, 3), 2);
----
true

# Element not found, no nulls in array
query B
SELECT array_contains(array(1, 2, 3), 4);
----
false

# Element not found, array has null elements -> null
query B
SELECT array_contains(array(1, NULL, 3), 2);
----
NULL

# Element found, array has null elements -> true (nulls don't matter)
query B
SELECT array_contains(array(1, NULL, 3), 1);
----
true

# Element found at the end, array has null elements -> true
query B
SELECT array_contains(array(1, NULL, 3), 3);
----
true

# Null array -> null
query B
SELECT array_contains(NULL, 1);
----
NULL

# Null element -> null
query B
SELECT array_contains(array(1, 2, 3), NULL);
----
NULL

# Empty array, element not found -> false
query B
SELECT array_contains(array(), 1);
----
false

# Array with only nulls, element not found -> null
query B
SELECT array_contains(array(NULL, NULL), 1);
----
NULL

# String array, element found
query B
SELECT array_contains(array('a', 'b', 'c'), 'b');
----
true

# String array, element not found, no nulls
query B
SELECT array_contains(array('a', 'b', 'c'), 'd');
----
false

# String array, element not found, has null
query B
SELECT array_contains(array('a', NULL, 'c'), 'd');
----
NULL

###
### Columnar tests with a table
###

statement ok
CREATE TABLE test_arrays AS VALUES
(1, make_array(1, 2, 3), 10),
(2, make_array(4, NULL, 6), 5),
(3, make_array(7, 8, 9), 10),
(4, NULL, 1),
(5, make_array(10, NULL, NULL), 10);

# Column needle against column array
query IBB
SELECT column1,
array_contains(column2, column3),
array_contains(column2, 10)
FROM test_arrays
ORDER BY column1;
----
1 false false
2 NULL NULL
3 false false
4 NULL NULL
5 true true

statement ok
DROP TABLE test_arrays;

###
### Nested array tests
###

# Nested array element found
query B
SELECT array_contains(array(array(1, 2), array(3, 4)), array(3, 4));
----
true

# Nested array element not found, no nulls
query B
SELECT array_contains(array(array(1, 2), array(3, 4)), array(5, 6));
----
false