Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 69 additions & 10 deletions datafusion/functions-nested/src/resize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,20 @@ use arrow::array::{
};
use arrow::buffer::OffsetBuffer;
use arrow::datatypes::DataType;
use arrow::datatypes::{ArrowNativeType, Field};
use arrow::datatypes::Field;
use arrow::datatypes::{
DataType::{LargeList, List},
FieldRef,
};
use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array};
use datafusion_common::utils::ListCoercion;
use datafusion_common::{Result, ScalarValue, exec_err, internal_datafusion_err};
use datafusion_common::{Result, ScalarValue, exec_err};
use datafusion_expr::{
ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
};
use datafusion_macros::user_doc;
use std::sync::Arc;
use std::{mem::size_of, sync::Arc};

make_udf_expr_and_func!(
ArrayResize,
Expand Down Expand Up @@ -206,18 +206,20 @@ fn general_list_resize<O: OffsetSizeTrait + TryInto<i64>>(
if array.is_null(row_index) {
continue;
}
let target_count = count_array.value(row_index).to_usize().ok_or_else(|| {
internal_datafusion_err!("array_resize: failed to convert size to usize")
})?;
let target_count = target_count::<O>(count_array, row_index)?;
output_values_len =
output_values_len.checked_add(target_count).ok_or_else(|| {
internal_datafusion_err!("array_resize: output size overflow")
datafusion_common::DataFusionError::Execution(
"array_resize: target size too large".to_string(),
)
})?;
let current_len = (offset_window[1] - offset_window[0]).to_usize().unwrap();
if target_count > current_len {
max_extra = max_extra.max(target_count - current_len);
}
}
validate_value_capacity(&data_type, output_values_len)?;
validate_value_capacity(&data_type, max_extra)?;

// The fast path is valid when at least one row grows and every row would
// use the same fill value.
Expand Down Expand Up @@ -315,9 +317,7 @@ where
}
null_builder.append_non_null();

let count = count_array.value(row_index).to_usize().ok_or_else(|| {
internal_datafusion_err!("array_resize: failed to convert size to usize")
})?;
let count = target_count::<O>(count_array, row_index)?;
let count = O::usize_as(count);
let start = offset_window[0];
if start + count > offset_window[1] {
Expand All @@ -341,3 +341,62 @@ where
null_builder.finish(),
)?))
}

fn target_count<O: OffsetSizeTrait>(
count_array: &Int64Array,
row_index: usize,
) -> Result<usize> {
let count = count_array.value(row_index);
if count < 0 {
return exec_err!("array_resize: size must be non-negative");
}

let count = count as usize;
if O::from_usize(count).is_none() {
return exec_err!("array_resize: target size too large");
}

Ok(count)
}

fn validate_value_capacity(data_type: &DataType, len: usize) -> Result<()> {
let width = minimum_value_width(data_type);
let Some(byte_len) = len.checked_mul(width) else {
return exec_err!("array_resize: target size too large");
};
if byte_len >= isize::MAX as usize {
return exec_err!("array_resize: target size too large");
}

Ok(())
}

fn minimum_value_width(data_type: &DataType) -> usize {
match data_type {
DataType::Boolean | DataType::Null => 1,
DataType::Utf8 | DataType::Binary | List(_) => size_of::<i32>(),
DataType::LargeUtf8 | DataType::LargeBinary | LargeList(_) => size_of::<i64>(),
_ => data_type.primitive_width().unwrap_or(1).max(1),
}
}

#[cfg(test)]
mod tests {
use super::*;
use arrow::array::ListArray;
use arrow::datatypes::Int64Type;
use datafusion_common::assert_contains;

#[test]
fn array_resize_rejects_target_count_overflow() {
let list = Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
Some(vec![Some(1)]),
])) as ArrayRef;
let count = Arc::new(Int64Array::from(vec![i64::MAX])) as ArrayRef;
let fill = Arc::new(Int64Array::from(vec![0])) as ArrayRef;

let err = array_resize_inner(&[list, count, fill]).unwrap_err();

assert_contains!(err.to_string(), "array_resize: target size too large");
}
}
3 changes: 3 additions & 0 deletions datafusion/sqllogictest/test_files/array/array_resize.slt
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ select array_resize(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 5, 4);
query error
select array_resize(make_array(1, 2, 3), -5, 2);

query error DataFusion error: Execution error: array_resize: target size too large
select array_resize(make_array(1), 9223372036854775807, 0);

# array_resize scalar function #5
query ?
select array_resize(make_array(1.1, 2.2, 3.3), 10, 9.9);
Expand Down