Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 46 additions & 8 deletions datafusion/functions-nested/src/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use arrow::datatypes::{
DataType::{FixedSizeList, LargeList, LargeListView, List, ListView, Null},
Field,
};
use datafusion_common::config::Dialect;
use datafusion_common::cast::as_large_list_array;
use datafusion_common::cast::as_list_array;
use datafusion_common::cast::{
Expand Down Expand Up @@ -170,7 +171,11 @@ impl ScalarUDFImpl for ArrayElement {
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
make_scalar_function(array_element_inner)(&args.args)
let negative_from_end_indexing =
args.config_options.sql_parser.dialect != Dialect::PostgreSQL;
make_scalar_function(move |arrays| {
array_element_inner(arrays, negative_from_end_indexing)
})(&args.args)
}

fn aliases(&self) -> &[String] {
Expand All @@ -189,20 +194,23 @@ impl ScalarUDFImpl for ArrayElement {
///
/// For example:
/// > array_element(\[1, 2, 3], 2) -> 2
fn array_element_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
fn array_element_inner(
args: &[ArrayRef],
negative_from_end_indexing: bool,
) -> Result<ArrayRef> {
let [array, indexes] = take_function_args("array_element", args)?;

match &array.data_type() {
Null => Ok(Arc::new(NullArray::new(array.len()))),
List(_) => {
let array = as_list_array(&array)?;
let indexes = as_int64_array(&indexes)?;
general_array_element::<i32>(array, indexes)
general_array_element::<i32>(array, indexes, negative_from_end_indexing)
}
LargeList(_) => {
let array = as_large_list_array(&array)?;
let indexes = as_int64_array(&indexes)?;
general_array_element::<i64>(array, indexes)
general_array_element::<i64>(array, indexes, negative_from_end_indexing)
}
arg_type => {
exec_err!("array_element does not support type {arg_type}")
Expand All @@ -213,6 +221,7 @@ fn array_element_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
fn general_array_element<O: OffsetSizeTrait>(
array: &GenericListArray<O>,
indexes: &Int64Array,
negative_from_end_indexing: bool,
) -> Result<ArrayRef>
where
i64: TryInto<O>,
Expand All @@ -229,7 +238,11 @@ where
let mut mutable =
MutableArrayData::with_capacities(vec![&original_data], true, capacity);

fn adjusted_array_index<O: OffsetSizeTrait>(index: i64, len: O) -> Result<Option<O>>
fn adjusted_array_index<O: OffsetSizeTrait>(
index: i64,
len: O,
negative_from_end_indexing: bool,
) -> Result<Option<O>>
where
i64: TryInto<O>,
{
Expand All @@ -238,7 +251,12 @@ where
})?;
// 0 ~ len - 1
let adjusted_zero_index = if index < O::usize_as(0) {
index + len
if negative_from_end_indexing {
index + len
} else {
// PostgreSQL does not support negative array subscripts.
return Ok(None);
}
} else {
index - O::usize_as(1)
};
Expand All @@ -262,7 +280,11 @@ where
continue;
}

let index = adjusted_array_index::<O>(indexes.value(row_index), len)?;
let index = adjusted_array_index::<O>(
indexes.value(row_index),
len,
negative_from_end_indexing,
)?;

if let Some(index) = index {
let start = start.as_usize() + index.as_usize();
Expand Down Expand Up @@ -1179,7 +1201,7 @@ mod tests {
let list_array = ListArray::new(field, offsets, values, Some(nulls));
let indexes = Int64Array::from(vec![1, 1, 1]);

let result = general_array_element(&list_array, &indexes)?;
let result = general_array_element(&list_array, &indexes, true)?;

let expected = [
"+--------+",
Expand All @@ -1198,6 +1220,22 @@ mod tests {
Ok(())
}

#[test]
fn test_array_element_postgres_negative_index_returns_null() -> Result<()> {
let values = Arc::new(Int32Array::from(vec![10, 20, 30]));
let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 3]));
let field = Arc::new(Field::new("item", DataType::Int32, true));
let list_array = ListArray::new(field, offsets, values, None);
let indexes = Int64Array::from(vec![-1]);

let result = general_array_element(&list_array, &indexes, false)?;

assert_eq!(result.len(), 1);
assert!(result.is_null(0));

Ok(())
}

#[test]
fn test_array_any_null_handling() -> Result<()> {
let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
Expand Down