Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 209 additions & 19 deletions datafusion/functions-nested/src/position.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ use datafusion_common::cast::{
use datafusion_common::{Result, exec_err, utils::take_function_args};
use itertools::Itertools;

use crate::utils::{compare_element_to_list, make_scalar_function};
use crate::utils::{compare_element_to_list_fixed, make_scalar_function};

make_udf_expr_and_func!(
ArrayPosition,
Expand Down Expand Up @@ -209,9 +209,15 @@ fn resolve_start_from(
Some(ColumnarValue::Scalar(ScalarValue::Int64(Some(v)))) => {
Ok(vec![v - 1; num_rows])
}
Some(ColumnarValue::Scalar(s)) if s.is_null() => {
exec_err!("array_position index cannot contain nulls")
}
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically this case was already caught by the arm below, but I added a dedicated arm so that it produces a more specific error message.

Some(ColumnarValue::Scalar(s)) => {
exec_err!("array_position expected Int64 for start_from, got {s}")
}
Some(ColumnarValue::Array(a)) if a.null_count() > 0 => {
exec_err!("array_position index cannot contain nulls")
}
Some(ColumnarValue::Array(a)) => {
Ok(as_int64_array(a)?.values().iter().map(|&x| x - 1).collect())
}
Expand Down Expand Up @@ -306,11 +312,11 @@ fn general_position_dispatch<O: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<Ar
crate::utils::check_datatypes("array_position", &[haystack.values(), needle])?;

let arr_from = if args.len() == 3 {
as_int64_array(&args[2])?
.values()
.iter()
.map(|&x| x - 1)
.collect::<Vec<_>>()
let arr_from = as_int64_array(&args[2])?;
if arr_from.null_count() > 0 {
return exec_err!("array_position index cannot contain nulls");
}
arr_from.values().iter().map(|&x| x - 1).collect::<Vec<_>>()
} else {
vec![0; haystack.len()]
};
Expand All @@ -321,23 +327,27 @@ fn general_position_dispatch<O: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<Ar
}
}

generic_position::<O>(haystack, needle, &arr_from)
if needle.data_type().is_list() {
generic_position::<O, true>(haystack, needle, &arr_from)
} else {
generic_position::<O, false>(haystack, needle, &arr_from)
}
}

fn generic_position<O: OffsetSizeTrait>(
fn generic_position<O: OffsetSizeTrait, const IS_NESTED: bool>(
haystack: &GenericListArray<O>,
needle: &ArrayRef,
arr_from: &[i64], // 0-indexed
) -> Result<ArrayRef> {
let mut data = Vec::with_capacity(haystack.len());

for (row_index, (row, &from)) in haystack.iter().zip(arr_from.iter()).enumerate() {
let from = from as usize;

if let Some(row) = row {
let eq_array = compare_element_to_list(&row, needle, row_index, true)?;
let eq_array =
compare_element_to_list_fixed::<IS_NESTED>(&row, needle, row_index)?;

// Collect `true`s in 1-indexed positions
let from = from as usize;
let index = eq_array
.iter()
.skip(from)
Expand All @@ -363,7 +373,7 @@ make_udf_expr_and_func!(

#[user_doc(
doc_section(label = "Array Functions"),
description = "Searches for an element in the array, returns all occurrences.",
description = "Returns the positions of all occurrences of an element in the array. Returns an empty list `[]` if not found. Comparisons are done using `IS NOT DISTINCT FROM` semantics, so NULL is considered to match NULL. Only returns NULL if the array to search itself is NULL.",
syntax_example = "array_positions(array, element)",
sql_example = r#"```sql
> select array_positions([1, 2, 2, 3, 1, 4], 2);
Expand Down Expand Up @@ -476,14 +486,24 @@ fn try_array_positions_scalar(args: &[ColumnarValue]) -> Result<Option<ColumnarV
fn array_positions_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let [haystack, needle] = take_function_args("array_positions", args)?;

match &haystack.data_type() {
List(_) => general_positions::<i32>(as_list_array(&haystack)?, needle),
LargeList(_) => general_positions::<i64>(as_large_list_array(&haystack)?, needle),
dt => exec_err!("array_positions does not support type '{dt}'"),
match (haystack.data_type(), needle.data_type().is_list()) {
(List(_), true) => {
general_positions::<i32, true>(as_list_array(&haystack)?, needle)
}
(LargeList(_), true) => {
general_positions::<i64, true>(as_large_list_array(&haystack)?, needle)
}
(List(_), false) => {
general_positions::<i32, false>(as_list_array(&haystack)?, needle)
}
(LargeList(_), false) => {
general_positions::<i64, false>(as_large_list_array(&haystack)?, needle)
}
(dt, _) => exec_err!("array_positions does not support type '{dt}'"),
}
}

fn general_positions<O: OffsetSizeTrait>(
fn general_positions<O: OffsetSizeTrait, const IS_NESTED: bool>(
haystack: &GenericListArray<O>,
needle: &ArrayRef,
) -> Result<ArrayRef> {
Expand All @@ -492,7 +512,8 @@ fn general_positions<O: OffsetSizeTrait>(

for (row_index, row) in haystack.iter().enumerate() {
if let Some(row) = row {
let eq_array = compare_element_to_list(&row, needle, row_index, true)?;
let eq_array =
compare_element_to_list_fixed::<IS_NESTED>(&row, needle, row_index)?;

// Collect `true`s in 1-indexed positions
let indexes = eq_array
Expand Down Expand Up @@ -591,7 +612,7 @@ fn array_positions_scalar<O: OffsetSizeTrait>(
#[cfg(test)]
mod tests {
use super::*;
use arrow::array::AsArray;
use arrow::array::{AsArray, Int32Array, new_empty_array};
use arrow::datatypes::Int32Type;
use datafusion_common::config::ConfigOptions;

Expand Down Expand Up @@ -750,4 +771,173 @@ mod tests {

Ok(())
}

#[test]
fn test_nested_non_empty_null() -> Result<()> {
    // Verifies NULL matching for nested (list-of-list) haystacks when the
    // null list slots are NOT zero-length: every inner list slot, including
    // the null ones, spans one child value (7). Matching must therefore be
    // driven by the validity bitmaps rather than by the underlying values.
    //
    // Haystack             Needle   array_position   array_positionS
    // [[7]]                [null]   null             []
    // [[7]]                null     null             []
    // [[7], null]          [null]   null             []
    // [[7], null]          null     2                [2]
    // [[7], [null]]        [null]   2                [2]
    // [[7], [null], null]  [null]   2                [2]

    // Nulls are not zero sized and have underlying value of 7

    // [[7], [7], [7], null, [7], null, [7], [null], [7], [null], null]
    let inner = Arc::new(ListArray::new(
        Field::new_list_field(DataType::Int32, true).into(),
        OffsetBuffer::from_lengths(vec![1; 11]),
        Arc::new(Int32Array::new(
            vec![7; 11].into(),
            // Child validity: slots 7 and 9 are the `null` inside `[null]`
            Some(
                vec![
                    true, true, true, true, true, true, true, false, true, false,
                    true,
                ]
                .into(),
            ),
        )),
        // List validity: slots 3, 5 and 10 are null lists
        Some(
            vec![
                true, true, true, false, true, false, true, true, true, true, false,
            ]
            .into(),
        ),
    ));

    // [[[7]], [[7]], [[7], null], [[7], null], [[7], [null]], [[7], [null], null]]
    let haystack: Arc<dyn Array> = Arc::new(ListArray::new(
        Field::new_list_field(inner.data_type().clone(), true).into(),
        OffsetBuffer::from_lengths(vec![1, 1, 2, 2, 2, 3]),
        inner,
        None,
    ));

    // [[null], null, [null], null, [null], [null]]
    let needle: Arc<dyn Array> = Arc::new(ListArray::new(
        Field::new_list_field(DataType::Int32, true).into(),
        OffsetBuffer::from_lengths(vec![1; 6]),
        Arc::new(Int32Array::new(
            vec![7; 6].into(),
            Some(vec![false; 6].into()),
        )),
        Some(vec![true, false, true, false, true, true].into()),
    ));

    // Use the real batch size for both `number_rows` and `into_array` so the
    // invocation is self-consistent with the six-row inputs.
    let num_rows = haystack.len();

    let output = ArrayPosition::new()
        .invoke_with_args(ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(Arc::clone(&haystack)),
                ColumnarValue::Array(Arc::clone(&needle)),
            ],
            arg_fields: vec![],
            number_rows: num_rows,
            return_field: Arc::new(Field::new("", DataType::Null, true)),
            config_options: Arc::new(ConfigOptions::default()),
        })?
        .into_array(num_rows)?;
    // [null, null, null, 2, 2, 2]
    let expected: Arc<dyn Array> = Arc::new(UInt64Array::from(vec![
        None,
        None,
        None,
        Some(2),
        Some(2),
        Some(2),
    ]));
    assert_eq!(&output, &expected);

    let output = ArrayPositions::new()
        .invoke_with_args(ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(haystack), ColumnarValue::Array(needle)],
            arg_fields: vec![],
            number_rows: num_rows,
            return_field: Arc::new(Field::new("", DataType::Null, true)),
            config_options: Arc::new(ConfigOptions::default()),
        })?
        .into_array(num_rows)?;
    // [[], [], [], [2], [2], [2]]
    let expected: Arc<dyn Array> =
        Arc::new(ListArray::from_iter_primitive::<UInt64Type, _, _>(vec![
            Some(vec![]),
            Some(vec![]),
            Some(vec![]),
            Some(vec![Some(2)]),
            Some(vec![Some(2)]),
            Some(vec![Some(2)]),
        ]));
    assert_eq!(&output, &expected);

    Ok(())
}

#[test]
fn test_nested_empty_list() -> Result<()> {
    // Verifies that an empty list `[]` and a null list are told apart when
    // searching a nested haystack: both occupy zero child values here, so
    // only the validity bitmap distinguishes them.
    //
    // Haystack          Needle   array_position   array_positionS
    // [[]]              null     null             []
    // [[7], []]         []       2                [2]
    // [[7], null, []]   []       3                [3]

    // [[], [7], [], [7], null, []]
    let inner = Arc::new(ListArray::new(
        Field::new_list_field(DataType::Int32, true).into(),
        OffsetBuffer::from_lengths(vec![0, 1, 0, 1, 0, 0]),
        Arc::new(Int32Array::from(vec![7, 7])),
        // Slot 4 is the only null list; the other zero-length slots are `[]`
        Some(vec![true, true, true, true, false, true].into()),
    ));

    // [[[]], [[7], []], [[7], null, []]]
    let haystack: Arc<dyn Array> = Arc::new(ListArray::new(
        Field::new_list_field(inner.data_type().clone(), true).into(),
        OffsetBuffer::from_lengths(vec![1, 2, 3]),
        inner,
        None,
    ));

    // [null, [], []]
    let needle: Arc<dyn Array> = Arc::new(ListArray::new(
        Field::new_list_field(DataType::Int32, true).into(),
        OffsetBuffer::from_lengths(vec![0, 0, 0]),
        Arc::new(new_empty_array(&DataType::Int32)),
        Some(vec![false, true, true].into()),
    ));

    // Use the real batch size for both `number_rows` and `into_array` so the
    // invocation is self-consistent with the three-row inputs.
    let num_rows = haystack.len();

    let output = ArrayPosition::new()
        .invoke_with_args(ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(Arc::clone(&haystack)),
                ColumnarValue::Array(Arc::clone(&needle)),
            ],
            arg_fields: vec![],
            number_rows: num_rows,
            return_field: Arc::new(Field::new("", DataType::Null, true)),
            config_options: Arc::new(ConfigOptions::default()),
        })?
        .into_array(num_rows)?;
    // [null, 2, 3]
    let expected: Arc<dyn Array> =
        Arc::new(UInt64Array::from(vec![None, Some(2), Some(3)]));
    assert_eq!(&output, &expected);

    let output = ArrayPositions::new()
        .invoke_with_args(ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(haystack), ColumnarValue::Array(needle)],
            arg_fields: vec![],
            number_rows: num_rows,
            return_field: Arc::new(Field::new("", DataType::Null, true)),
            config_options: Arc::new(ConfigOptions::default()),
        })?
        .into_array(num_rows)?;
    // [[], [2], [3]]
    let expected: Arc<dyn Array> =
        Arc::new(ListArray::from_iter_primitive::<UInt64Type, _, _>(vec![
            Some(vec![]),
            Some(vec![Some(2)]),
            Some(vec![Some(3)]),
        ]));
    assert_eq!(&output, &expected);

    Ok(())
}
}
26 changes: 26 additions & 0 deletions datafusion/functions-nested/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@

use std::sync::Arc;

use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, Field, Fields};

use arrow::array::{
Array, ArrayRef, BooleanArray, GenericListArray, OffsetSizeTrait, Scalar,
make_comparator,
};
use arrow::buffer::OffsetBuffer;
use datafusion_common::cast::{
Expand Down Expand Up @@ -220,6 +222,30 @@ pub(crate) fn compare_element_to_list(
Ok(res)
}

/// Given a `haystack` array, and a specific value from `needle` selected by
/// `needle_element_index`, return a `BooleanArray` based on whether the elements
/// in `haystack` match the `needle` value using `IS NOT DISTINCT FROM` semantics.
/// - Allows NULL = NULL to be considered true
pub(crate) fn compare_element_to_list_fixed<const IS_LIST: bool>(
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a new version because I didn't want to affect array_remove/array_replace yet; I'll tackle them in follow-ups.

(This is also why I omitted the eq parameter: for now this function is only used by the position functions anyway.)

haystack: &dyn Array,
needle: &dyn Array,
needle_element_index: usize,
) -> Result<BooleanArray> {
if IS_LIST {
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I figured it would be better not to do this datatype check inside the function, since it sits in the hot loop (it is called for every row), so I pulled the check out into a const generic.

// arrow_ord::cmp::eq does not support ListArray, so we resort to make_comparator
let cmp = make_comparator(haystack, needle, SortOptions::default())?;
let res = (0..haystack.len())
.map(|i| cmp(i, needle_element_index).is_eq())
.collect::<BooleanArray>();
Ok(res)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the main fix: we now use a comparator, which handles nulls on both sides properly (null = null is true).

} else {
let needle = needle.slice(needle_element_index, 1);
let needle_value = Scalar::new(needle);
// use not_distinct so we can compare NULL
Ok(arrow_ord::cmp::not_distinct(&haystack, &needle_value)?)
}
}

/// Returns the length of each array dimension
pub(crate) fn compute_array_dims(
arr: Option<ArrayRef>,
Expand Down
Loading
Loading