Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 89 additions & 3 deletions arrow-cast/src/cast/dictionary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,6 @@ pub(crate) fn dictionary_cast<K: ArrowDictionaryKeyType>(
}
// `unpack_dictionary` can handle Utf8View/BinaryView types, but incurs unnecessary data
// copy of the value buffer. Fast path which avoids copying underlying values buffer.
// TODO: handle LargeUtf8/LargeBinary -> View (need to check offsets can fit)
// TODO: handle cross types (String -> BinaryView, Binary -> StringView)
// (need to validate utf8?)
(Utf8, Utf8View) => view_from_dict_values::<K, Utf8Type, StringViewType>(
array.keys(),
array.values().as_string::<i32>(),
Expand All @@ -47,6 +44,35 @@ pub(crate) fn dictionary_cast<K: ArrowDictionaryKeyType>(
array.keys(),
array.values().as_binary::<i32>(),
),
// LargeUtf8/LargeBinary -> View: fast path only when i64 offsets fit in u32 (buffer < 4GiB).
// If the buffer is too large, fall back to the general path.
(LargeUtf8, Utf8View) => {
let values = array.values().as_string::<i64>();
if values.values().len() < u32::MAX as usize {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check reads a little odd to me as usually this could mean unpack_dictionary may also fail if offsets don't fit?

view_from_dict_values::<K, LargeUtf8Type, StringViewType>(array.keys(), values)
} else {
unpack_dictionary(array, to_type, cast_options)
}
}
(LargeBinary, BinaryView) => {
let values = array.values().as_binary::<i64>();
if values.values().len() < u32::MAX as usize {
view_from_dict_values::<K, LargeBinaryType, BinaryViewType>(array.keys(), values)
} else {
unpack_dictionary(array, to_type, cast_options)
}
}
// Cross casts: Utf8 -> BinaryView is always zero-copy safe (valid UTF-8 is valid binary).
(Utf8, BinaryView) => view_from_dict_values::<K, Utf8Type, BinaryViewType>(
array.keys(),
array.values().as_string::<i32>(),
),
// Cross cast: Binary -> Utf8View requires UTF-8 validation of the dictionary values.
(Binary, Utf8View) => binary_dict_to_string_view::<K>(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel this arm specifically should be benchmarked as it introduces new logic compared to the other arms

array.keys(),
array.values().as_binary::<i32>(),
cast_options,
),
_ => unpack_dictionary(array, to_type, cast_options),
}
}
Expand Down Expand Up @@ -108,6 +134,66 @@ fn dictionary_to_dictionary_cast<K: ArrowDictionaryKeyType>(
Ok(new_array)
}

/// Cast `Dict<K, Binary>` to `Utf8View`, validating UTF-8 for each dictionary value.
///
/// Fast path when all values are valid UTF-8: reuses the values buffer without copying.
/// When some values are invalid and `cast_options.safe` is true, rows pointing to those
/// values become null. When `cast_options.safe` is false, returns an error immediately.
fn binary_dict_to_string_view<K: ArrowDictionaryKeyType>(
keys: &PrimitiveArray<K>,
values: &GenericByteArray<BinaryType>,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError> {
match GenericStringArray::<i32>::try_from_binary(values.clone()) {
Ok(_) => {
// All dictionary values are valid UTF-8: reuse the buffer zero-copy.
view_from_dict_values::<K, BinaryType, StringViewType>(keys, values)
}
Err(e) => {
if !cast_options.safe {
return Err(e);
}
// safe=true: validate each dictionary value individually so we can nullify
// only the rows whose key points to an invalid UTF-8 value.
let valid: Vec<bool> = (0..values.len())
.map(|i| !values.is_null(i) && std::str::from_utf8(values.value(i)).is_ok())
.collect();

let value_buffer = values.values();
let value_offsets = values.value_offsets();
let mut builder = StringViewBuilder::with_capacity(keys.len());
builder.append_block(value_buffer.clone());

for key in keys.iter() {
match key {
Some(v) => {
let idx = v.to_usize().ok_or_else(|| {
ArrowError::ComputeError("Invalid dictionary index".to_string())
})?;
if valid[idx] {
// Safety:
// (1) idx is a valid index into value_offsets (Arrow invariant)
// (2) offsets are monotonically increasing, so end >= offset
// (3) the slice [offset..end] is within the buffer
// (4) the bytes are valid UTF-8 (checked above for valid[idx])
unsafe {
let offset = value_offsets.get_unchecked(idx).as_usize();
let end = value_offsets.get_unchecked(idx + 1).as_usize();
let length = end - offset;
builder.append_view_unchecked(0, offset as u32, length as u32);
}
} else {
builder.append_null();
}
}
None => builder.append_null(),
}
}
Ok(Arc::new(builder.finish()))
}
}
}

fn view_from_dict_values<K: ArrowDictionaryKeyType, V: ByteArrayType, T: ByteViewType>(
keys: &PrimitiveArray<K>,
values: &GenericByteArray<V>,
Expand Down
152 changes: 152 additions & 0 deletions arrow-cast/src/cast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7257,6 +7257,158 @@ mod tests {
assert_eq!(casted_binary_array.as_ref(), &binary_view_array);
}

#[test]
fn test_dict_large_utf8_to_utf8view() {
// Dict<Int8, LargeUtf8> -> Utf8View fast path (offsets fit in u32)
let values = LargeStringArray::from(vec![
Some("hello"),
Some("large payload over 12 bytes"),
Some("hello"),
]);
let keys = Int8Array::from_iter([Some(0), Some(1), None, Some(0), Some(1)]);
let dict_array = DictionaryArray::<Int8Type>::try_new(keys, Arc::new(values)).unwrap();

assert!(can_cast_types(dict_array.data_type(), &DataType::Utf8View));
let casted = cast(&dict_array, &DataType::Utf8View).unwrap();
assert_eq!(casted.data_type(), &DataType::Utf8View);

let expected = StringViewArray::from(vec![
Some("hello"),
Some("large payload over 12 bytes"),
None,
Some("hello"),
Some("large payload over 12 bytes"),
]);
assert_eq!(casted.as_ref(), &expected);
}

#[test]
fn test_dict_large_binary_to_binary_view() {
// Dict<Int8, LargeBinary> -> BinaryView fast path (offsets fit in u32)
let mut builder = GenericBinaryBuilder::<i64>::new();
builder.append_value(b"hello");
builder.append_value(b"world");
let values = builder.finish();

let keys = Int8Array::from_iter([Some(0), Some(1), None, Some(0)]);
let dict_array = DictionaryArray::<Int8Type>::try_new(keys, Arc::new(values)).unwrap();

assert!(can_cast_types(
dict_array.data_type(),
&DataType::BinaryView
));
let casted = cast(&dict_array, &DataType::BinaryView).unwrap();
assert_eq!(casted.data_type(), &DataType::BinaryView);

let expected = BinaryViewArray::from_iter(vec![
Some(b"hello".as_slice()),
Some(b"world".as_slice()),
None,
Some(b"hello".as_slice()),
]);
assert_eq!(casted.as_ref(), &expected);
}

#[test]
fn test_dict_utf8_to_binary_view() {
// Dict<Int8, Utf8> -> BinaryView cross cast: UTF-8 strings are always valid binary
let values = StringArray::from(VIEW_TEST_DATA.to_vec());
let keys = Int8Array::from_iter([Some(1), Some(0), None, Some(3), None, Some(1), Some(4)]);
let dict_array = DictionaryArray::<Int8Type>::try_new(keys, Arc::new(values)).unwrap();

assert!(can_cast_types(
dict_array.data_type(),
&DataType::BinaryView
));
let casted = cast(&dict_array, &DataType::BinaryView).unwrap();
assert_eq!(casted.data_type(), &DataType::BinaryView);

let expected = BinaryViewArray::from_iter(vec![
VIEW_TEST_DATA[1],
VIEW_TEST_DATA[0],
None,
VIEW_TEST_DATA[3],
None,
VIEW_TEST_DATA[1],
VIEW_TEST_DATA[4],
]);
assert_eq!(casted.as_ref(), &expected);
}

#[test]
fn test_dict_binary_to_utf8view_valid() {
// Dict<Int8, Binary> -> Utf8View cross cast: all values are valid UTF-8
let values = BinaryArray::from_iter_values([b"hello".as_slice(), b"world", b"foo"]);
let keys = Int8Array::from_iter([Some(0), Some(1), None, Some(0), Some(2)]);
let dict_array = DictionaryArray::<Int8Type>::try_new(keys, Arc::new(values)).unwrap();

assert!(can_cast_types(dict_array.data_type(), &DataType::Utf8View));
let casted = cast(&dict_array, &DataType::Utf8View).unwrap();
assert_eq!(casted.data_type(), &DataType::Utf8View);

let result: Vec<_> = casted.as_string_view().iter().collect();
assert_eq!(
result,
vec![
Some("hello"),
Some("world"),
None,
Some("hello"),
Some("foo")
]
);
}

#[test]
fn test_dict_binary_to_utf8view_invalid_utf8_strict() {
// Dict<Int8, Binary> -> Utf8View with invalid UTF-8: safe=false returns an error
let mut builder = BinaryBuilder::new();
builder.append_value(b"valid");
builder.append_value([0xFF]); // invalid UTF-8
builder.append_value(b"also valid");
let values = builder.finish();

let keys = Int8Array::from_iter([Some(0), Some(1), Some(2)]);
let dict_array = DictionaryArray::<Int8Type>::try_new(keys, Arc::new(values)).unwrap();

let strict = CastOptions {
safe: false,
..Default::default()
};
let err = cast_with_options(&dict_array, &DataType::Utf8View, &strict).unwrap_err();
assert!(
matches!(err, ArrowError::InvalidArgumentError(_)),
"expected InvalidArgumentError, got {err:?}"
);
}

#[test]
fn test_dict_binary_to_utf8view_invalid_utf8_safe() {
// Dict<Int8, Binary> -> Utf8View with invalid UTF-8: safe=true nullifies affected rows
let mut builder = BinaryBuilder::new();
builder.append_value(b"valid");
builder.append_value([0xFF]); // invalid UTF-8 - dict index 1
builder.append_value(b"also valid");
let values = builder.finish();

// keys: 0, 1, 2, 1, 0 -> "valid", INVALID, "also valid", INVALID, "valid"
let keys = Int8Array::from_iter([Some(0), Some(1), Some(2), Some(1), Some(0)]);
let dict_array = DictionaryArray::<Int8Type>::try_new(keys, Arc::new(values)).unwrap();

let safe = CastOptions {
safe: true,
..Default::default()
};
let casted = cast_with_options(&dict_array, &DataType::Utf8View, &safe).unwrap();
assert_eq!(casted.data_type(), &DataType::Utf8View);

let result: Vec<_> = casted.as_string_view().iter().collect();
assert_eq!(
result,
vec![Some("valid"), None, Some("also valid"), None, Some("valid")]
);
}

#[test]
fn test_view_to_dict() {
let string_view_array = StringViewArray::from_iter(VIEW_TEST_DATA);
Expand Down
Loading