Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions datafusion/functions-nested/src/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,8 @@ where
let end = offset_window[1];
let len = end - start;

// array is null
if array.is_null(row_index) {
// array or index is null
if array.is_null(row_index) || indexes.is_null(row_index) {
mutable.extend_nulls(1);
continue;
}
Expand Down Expand Up @@ -1107,7 +1107,7 @@ mod tests {
};
use arrow::array::{ListArray, RecordBatch};
use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer};
use arrow::datatypes::{DataType, Field};
use arrow::datatypes::{DataType, Field, Int32Type};
use datafusion_common::{Column, DFSchema, Result, assert_batches_eq};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::{Expr, ExprSchemable};
Expand Down Expand Up @@ -1198,6 +1198,26 @@ mod tests {
Ok(())
}

/// Regression test: a null index must yield a null result even when the
/// value slot behind the null holds a non-zero number.
///
/// The index array is built from a raw value buffer plus a separate validity
/// buffer, so row 1 is marked null while its underlying value is still `1`.
/// The expected output is null for that row rather than the element at that
/// position of the list.
#[test]
fn test_array_element_null_index_with_non_zero_buffer_returns_null() -> Result<()> {
// Three non-null single-level lists: [1, 2, 3], [4], [5].
let list_array = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(1), Some(2), Some(3)]),
Some(vec![Some(4)]),
Some(vec![Some(5)]),
]);
// Every value slot is 1; the validity buffer marks only the middle row as null.
// Constructing via `new` (not `from(Vec<Option<_>>)`) keeps the non-zero value
// behind the null, which is the case this test is named for.
let indexes = Int64Array::new(
ScalarBuffer::from(vec![1, 1, 1]),
Some(NullBuffer::from(vec![true, false, true])),
);

let result = general_array_element(&list_array, &indexes)?;
// Rows 0 and 2 pick the first element of their list; row 1 is null.
let expected = Int32Array::from(vec![Some(1), None, Some(5)]);

assert_eq!(result.as_primitive::<Int32Type>(), &expected);

Ok(())
}

#[test]
fn test_array_any_null_handling() -> Result<()> {
let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
Expand Down
130 changes: 89 additions & 41 deletions datafusion/functions-nested/src/remove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
use crate::utils;
use crate::utils::make_scalar_function;
use arrow::array::{
Array, ArrayRef, Capacities, GenericListArray, MutableArrayData, OffsetSizeTrait,
cast::AsArray, make_array,
Array, ArrayRef, Capacities, GenericListArray, MutableArrayData, NullBufferBuilder,
OffsetSizeTrait, cast::AsArray, make_array,
};
use arrow::buffer::{NullBuffer, OffsetBuffer};
use arrow::buffer::OffsetBuffer;
use arrow::datatypes::{DataType, FieldRef};
use datafusion_common::cast::as_int64_array;
use datafusion_common::utils::ListCoercion;
Expand Down Expand Up @@ -210,7 +210,9 @@ impl ScalarUDFImpl for ArrayRemoveN {
&self,
args: datafusion_expr::ReturnFieldArgs,
) -> Result<FieldRef> {
Ok(Arc::clone(&args.arg_fields[0]))
let array_field = args.arg_fields[0].as_ref().clone();
let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
Ok(Arc::new(array_field.with_nullable(nullable)))
Comment on lines +213 to +215
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like we want the same change to `return_field_from_args` for `array_remove` and `array_remove_all`? Looking around, it seems like `array_any_match` has a similar bug. Can you fix these, either as part of this PR or in a separate PR?

}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
Expand Down Expand Up @@ -319,28 +321,28 @@ impl ScalarUDFImpl for ArrayRemoveAll {
fn array_remove_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let [array, element] = take_function_args("array_remove", args)?;

let arr_n = vec![1; array.len()];
let arr_n = vec![Some(1); array.len()];
array_remove_internal(array, element, &arr_n)
}

fn array_remove_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let [array, element, max] = take_function_args("array_remove_n", args)?;

let arr_n = as_int64_array(max)?.values().to_vec();
let arr_n = as_int64_array(max)?.iter().collect::<Vec<_>>();
array_remove_internal(array, element, &arr_n)
}

fn array_remove_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let [array, element] = take_function_args("array_remove_all", args)?;

let arr_n = vec![i64::MAX; array.len()];
let arr_n = vec![Some(i64::MAX); array.len()];
array_remove_internal(array, element, &arr_n)
}

fn array_remove_internal(
array: &ArrayRef,
element_array: &ArrayRef,
arr_n: &[i64],
arr_n: &[Option<i64>],
) -> Result<ArrayRef> {
match array.data_type() {
DataType::List(_) => {
Expand Down Expand Up @@ -377,7 +379,7 @@ fn array_remove_internal(
fn general_remove<OffsetSize: OffsetSizeTrait>(
list_array: &GenericListArray<OffsetSize>,
element_array: &ArrayRef,
arr_n: &[i64],
arr_n: &[Option<i64>],
) -> Result<ArrayRef> {
let list_field = match list_array.data_type() {
DataType::List(field) | DataType::LargeList(field) => field,
Expand All @@ -398,20 +400,23 @@ fn general_remove<OffsetSize: OffsetSizeTrait>(
false,
Capacities::Array(original_data.len()),
);

// Pre-compute combined null bitmap
let nulls = NullBuffer::union(list_array.nulls(), element_array.nulls());
let mut valid = NullBufferBuilder::new(list_array.len());

for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() {
if nulls.as_ref().is_some_and(|nulls| nulls.is_null(row_index)) {
if list_array.is_null(row_index) || element_array.is_null(row_index) {
offsets.push(offsets[row_index]);
valid.append_null();
continue;
}

let Some(n) = arr_n[row_index] else {
offsets.push(offsets[row_index]);
valid.append_null();
continue;
};

let start = offset_window[0].to_usize().unwrap();
let end = offset_window[1].to_usize().unwrap();
// n is the number of elements to remove in this row
let n = arr_n[row_index];

// compare each element in the list, `false` means the element matches and should be removed
let eq_array = utils::compare_element_to_list(
Expand All @@ -427,6 +432,7 @@ fn general_remove<OffsetSize: OffsetSizeTrait>(
if num_to_remove == 0 {
mutable.extend(0, start, end);
offsets.push(offsets[row_index] + OffsetSize::usize_as(end - start));
valid.append_non_null();
continue;
}

Expand Down Expand Up @@ -457,23 +463,26 @@ fn general_remove<OffsetSize: OffsetSizeTrait>(
}

offsets.push(offsets[row_index] + OffsetSize::usize_as(copied));
valid.append_non_null();
}

let new_values = make_array(mutable.freeze());
Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
Arc::clone(list_field),
OffsetBuffer::new(offsets.into()),
new_values,
nulls,
valid.finish(),
)?))
}

#[cfg(test)]
mod tests {
use crate::remove::{ArrayRemove, ArrayRemoveAll, ArrayRemoveN};
use arrow::array::{
Array, ArrayRef, AsArray, GenericListArray, ListArray, OffsetSizeTrait,
Array, ArrayRef, AsArray, GenericListArray, Int32Array, Int64Array, ListArray,
OffsetSizeTrait,
};
use arrow::buffer::{NullBuffer, ScalarBuffer};
use arrow::datatypes::{DataType, Field, Int32Type};
use datafusion_common::ScalarValue;
use datafusion_expr::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl};
Expand Down Expand Up @@ -512,30 +521,47 @@ mod tests {
fn test_array_remove_n_nullability() {
for nullability in [true, false] {
for item_nullability in [true, false] {
let input_field = Arc::new(Field::new(
"num",
DataType::new_list(DataType::Int32, item_nullability),
nullability,
));
let args_fields = vec![
Arc::clone(&input_field),
Arc::new(Field::new("a", DataType::Int32, false)),
Arc::new(Field::new("b", DataType::Int64, false)),
];
let scalar_args = vec![
None,
Some(&ScalarValue::Int32(Some(1))),
Some(&ScalarValue::Int64(Some(1))),
];

let result = ArrayRemoveN::new()
.return_field_from_args(ReturnFieldArgs {
arg_fields: &args_fields,
scalar_arguments: &scalar_args,
})
.unwrap();

assert_eq!(result, input_field);
for element_nullability in [true, false] {
for count_nullability in [true, false] {
Comment on lines +524 to +525
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could potentially clean up the deep nesting with `iproduct!`:

    for (array_nullable, item_nullable, element_nullable, count_nullable) in
        iproduct!(bools, bools, bools, bools)

But feel free to ignore if you think that's too clever.

let input_field = Arc::new(Field::new(
"num",
DataType::new_list(DataType::Int32, item_nullability),
nullability,
));
let args_fields = vec![
Arc::clone(&input_field),
Arc::new(Field::new(
"a",
DataType::Int32,
element_nullability,
)),
Arc::new(Field::new("b", DataType::Int64, count_nullability)),
];
let scalar_args = vec![
None,
Some(&ScalarValue::Int32(Some(1))),
Some(&ScalarValue::Int64(Some(1))),
];

let result = ArrayRemoveN::new()
.return_field_from_args(ReturnFieldArgs {
arg_fields: &args_fields,
scalar_arguments: &scalar_args,
})
.unwrap();

let expected_nullable =
nullability || element_nullability || count_nullability;
let expected = Arc::new(
input_field
.as_ref()
.clone()
.with_nullable(expected_nullable),
);

assert_eq!(result, expected);
}
}
}
}
}
Expand Down Expand Up @@ -734,6 +760,28 @@ mod tests {
assert_array_remove_n(input_list, expected_list, element_to_remove, 2);
}

/// Checks that `array_remove_n_inner` produces a null output row when the
/// count argument is null for that row, even though the value slot behind
/// the null holds a non-zero number.
#[test]
fn test_array_remove_n_null_count_returns_null() {
// Two non-null input lists: [1, 2, 2] and [4, 2].
let array: ArrayRef =
Arc::new(ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(1), Some(2), Some(2)]),
Some(vec![Some(4), Some(2)]),
]));
// The element to remove is 2 for both rows.
let element: ArrayRef = Arc::new(Int32Array::from(vec![2, 2]));
// Both count value slots are 1, but the validity buffer marks row 1 as null.
// Building from raw buffers keeps the non-zero value behind the null.
let max: ArrayRef = Arc::new(Int64Array::new(
ScalarBuffer::from(vec![1, 1]),
Some(NullBuffer::from(vec![true, false])),
));

let result = super::array_remove_n_inner(&[array, element, max]).unwrap();
// Row 0: one occurrence of 2 is removed, leaving [1, 2].
// Row 1: null count, so the whole output row is null.
let expected = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(1), Some(2)]),
None,
]);

assert_eq!(result.as_list::<i32>(), &expected);
}

fn assert_array_remove_n(
input_list: ArrayRef,
expected_list: GenericListArray<i32>,
Expand Down
Loading
Loading