From cdfedb447bf26edb88f893e73daf1aec0dbe246b Mon Sep 17 00:00:00 2001 From: Abhisheklearn12 Date: Sun, 19 Apr 2026 17:26:50 +0530 Subject: [PATCH 1/3] feat(arrow-cast): fast path for Dictionary->View cast for large types and cross cast --- arrow-cast/src/cast/dictionary.rs | 94 +++++++++++++++++- arrow-cast/src/cast/mod.rs | 152 ++++++++++++++++++++++++++++++ 2 files changed, 243 insertions(+), 3 deletions(-) diff --git a/arrow-cast/src/cast/dictionary.rs b/arrow-cast/src/cast/dictionary.rs index 0092bc0c87dd..f46dc5c926eb 100644 --- a/arrow-cast/src/cast/dictionary.rs +++ b/arrow-cast/src/cast/dictionary.rs @@ -36,9 +36,6 @@ pub(crate) fn dictionary_cast( } // `unpack_dictionary` can handle Utf8View/BinaryView types, but incurs unnecessary data // copy of the value buffer. Fast path which avoids copying underlying values buffer. - // TODO: handle LargeUtf8/LargeBinary -> View (need to check offsets can fit) - // TODO: handle cross types (String -> BinaryView, Binary -> StringView) - // (need to validate utf8?) (Utf8, Utf8View) => view_from_dict_values::( array.keys(), array.values().as_string::(), @@ -47,6 +44,35 @@ pub(crate) fn dictionary_cast( array.keys(), array.values().as_binary::(), ), + // LargeUtf8/LargeBinary -> View: fast path only when i64 offsets fit in u32 (buffer < 4GiB). + // If the buffer is too large, fall back to the general path. + (LargeUtf8, Utf8View) => { + let values = array.values().as_string::(); + if values.values().len() < u32::MAX as usize { + view_from_dict_values::(array.keys(), values) + } else { + unpack_dictionary(array, to_type, cast_options) + } + } + (LargeBinary, BinaryView) => { + let values = array.values().as_binary::(); + if values.values().len() < u32::MAX as usize { + view_from_dict_values::(array.keys(), values) + } else { + unpack_dictionary(array, to_type, cast_options) + } + } + // Cross casts: Utf8 -> BinaryView is always zero-copy safe (valid UTF-8 is valid binary). + (Utf8, BinaryView) => view_from_dict_values::( + array.keys(), + array.values().as_string::(), + ), + // Cross cast: Binary -> Utf8View requires UTF-8 validation of the dictionary values. + (Binary, Utf8View) => binary_dict_to_string_view::( + array.keys(), + array.values().as_binary::(), + cast_options, + ), _ => unpack_dictionary(array, to_type, cast_options), } } @@ -108,6 +134,68 @@ fn dictionary_to_dictionary_cast( Ok(new_array) } +/// Cast `Dict` to `Utf8View`, validating UTF-8 for each dictionary value. +/// +/// Fast path when all values are valid UTF-8: reuses the values buffer without copying. +/// When some values are invalid and `cast_options.safe` is true, rows pointing to those +/// values become null. When `cast_options.safe` is false, returns an error immediately. +fn binary_dict_to_string_view( + keys: &PrimitiveArray, + values: &GenericByteArray, + cast_options: &CastOptions, +) -> Result { + match GenericStringArray::::try_from_binary(values.clone()) { + Ok(_) => { + // All dictionary values are valid UTF-8: reuse the buffer zero-copy. + view_from_dict_values::(keys, values) + } + Err(e) => { + if !cast_options.safe { + return Err(ArrowError::CastError(format!( + "Cannot cast binary dictionary to Utf8View: {e}" + ))); + } + // safe=true: validate each dictionary value individually so we can nullify + // only the rows whose key points to an invalid UTF-8 value. + let valid: Vec = (0..values.len()) + .map(|i| !values.is_null(i) && std::str::from_utf8(values.value(i)).is_ok()) + .collect(); + + let value_buffer = values.values(); + let value_offsets = values.value_offsets(); + let mut builder = StringViewBuilder::with_capacity(keys.len()); + builder.append_block(value_buffer.clone()); + + for key in keys.iter() { + match key { + Some(v) => { + let idx = v.to_usize().ok_or_else(|| { + ArrowError::ComputeError("Invalid dictionary index".to_string()) + })?; + if valid[idx] { + // Safety: + // (1) idx is a valid index into value_offsets (Arrow invariant) + // (2) offsets are monotonically increasing, so end >= offset + // (3) the slice [offset..end] is within the buffer + // (4) the bytes are valid UTF-8 (checked above for valid[idx]) + unsafe { + let offset = value_offsets.get_unchecked(idx).as_usize(); + let end = value_offsets.get_unchecked(idx + 1).as_usize(); + let length = end - offset; + builder.append_view_unchecked(0, offset as u32, length as u32); + } + } else { + builder.append_null(); + } + } + None => builder.append_null(), + } + } + Ok(Arc::new(builder.finish())) + } + } +} + fn view_from_dict_values( keys: &PrimitiveArray, values: &GenericByteArray, diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs index 5f08dcbfc138..3337a9e2e33a 100644 --- a/arrow-cast/src/cast/mod.rs +++ b/arrow-cast/src/cast/mod.rs @@ -7257,6 +7257,158 @@ mod tests { assert_eq!(casted_binary_array.as_ref(), &binary_view_array); } + #[test] + fn test_dict_large_utf8_to_utf8view() { + // Dict -> Utf8View fast path (offsets fit in u32) + let values = LargeStringArray::from(vec![ + Some("hello"), + Some("large payload over 12 bytes"), + Some("hello"), + ]); + let keys = Int8Array::from_iter([Some(0), Some(1), None, Some(0), Some(1)]); + let dict_array = DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); + + assert!(can_cast_types(dict_array.data_type(), &DataType::Utf8View)); + let casted = cast(&dict_array, &DataType::Utf8View).unwrap(); + assert_eq!(casted.data_type(), &DataType::Utf8View); + + let expected = StringViewArray::from(vec![ + Some("hello"), + Some("large payload over 12 bytes"), + None, + Some("hello"), + Some("large payload over 12 bytes"), + ]); + assert_eq!(casted.as_ref(), &expected); + } + + #[test] + fn test_dict_large_binary_to_binary_view() { + // Dict -> BinaryView fast path (offsets fit in u32) + let mut builder = GenericBinaryBuilder::::new(); + builder.append_value(b"hello"); + builder.append_value(b"world"); + let values = builder.finish(); + + let keys = Int8Array::from_iter([Some(0), Some(1), None, Some(0)]); + let dict_array = DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); + + assert!(can_cast_types( + dict_array.data_type(), + &DataType::BinaryView + )); + let casted = cast(&dict_array, &DataType::BinaryView).unwrap(); + assert_eq!(casted.data_type(), &DataType::BinaryView); + + let expected = BinaryViewArray::from_iter(vec![ + Some(b"hello".as_slice()), + Some(b"world".as_slice()), + None, + Some(b"hello".as_slice()), + ]); + assert_eq!(casted.as_ref(), &expected); + } + + #[test] + fn test_dict_utf8_to_binary_view() { + // Dict -> BinaryView cross cast: UTF-8 strings are always valid binary + let values = StringArray::from(VIEW_TEST_DATA.to_vec()); + let keys = Int8Array::from_iter([Some(1), Some(0), None, Some(3), None, Some(1), Some(4)]); + let dict_array = DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); + + assert!(can_cast_types( + dict_array.data_type(), + &DataType::BinaryView + )); + let casted = cast(&dict_array, &DataType::BinaryView).unwrap(); + assert_eq!(casted.data_type(), &DataType::BinaryView); + + let expected = BinaryViewArray::from_iter(vec![ + VIEW_TEST_DATA[1], + VIEW_TEST_DATA[0], + None, + VIEW_TEST_DATA[3], + None, + VIEW_TEST_DATA[1], + VIEW_TEST_DATA[4], + ]); + assert_eq!(casted.as_ref(), &expected); + } + + #[test] + fn test_dict_binary_to_utf8view_valid() { + // Dict -> Utf8View cross cast: all values are valid UTF-8 + let values = BinaryArray::from_iter_values([b"hello".as_slice(), b"world", b"foo"]); + let keys = Int8Array::from_iter([Some(0), Some(1), None, Some(0), Some(2)]); + let dict_array = DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); + + assert!(can_cast_types(dict_array.data_type(), &DataType::Utf8View)); + let casted = cast(&dict_array, &DataType::Utf8View).unwrap(); + assert_eq!(casted.data_type(), &DataType::Utf8View); + + let result: Vec<_> = casted.as_string_view().iter().collect(); + assert_eq!( + result, + vec![ + Some("hello"), + Some("world"), + None, + Some("hello"), + Some("foo") + ] + ); + } + + #[test] + fn test_dict_binary_to_utf8view_invalid_utf8_strict() { + // Dict -> Utf8View with invalid UTF-8: safe=false returns an error + let mut builder = BinaryBuilder::new(); + builder.append_value(b"valid"); + builder.append_value(&[0xFF]); // invalid UTF-8 + builder.append_value(b"also valid"); + let values = builder.finish(); + + let keys = Int8Array::from_iter([Some(0), Some(1), Some(2)]); + let dict_array = DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); + + let strict = CastOptions { + safe: false, + ..Default::default() + }; + let err = cast_with_options(&dict_array, &DataType::Utf8View, &strict).unwrap_err(); + assert!( + matches!(err, ArrowError::CastError(_)), + "expected CastError, got {err:?}" + ); + } + + #[test] + fn test_dict_binary_to_utf8view_invalid_utf8_safe() { + // Dict -> Utf8View with invalid UTF-8: safe=true nullifies affected rows + let mut builder = BinaryBuilder::new(); + builder.append_value(b"valid"); + builder.append_value(&[0xFF]); // invalid UTF-8 - dict index 1 + builder.append_value(b"also valid"); + let values = builder.finish(); + + // keys: 0, 1, 2, 1, 0 -> "valid", INVALID, "also valid", INVALID, "valid" + let keys = Int8Array::from_iter([Some(0), Some(1), Some(2), Some(1), Some(0)]); + let dict_array = DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); + + let safe = CastOptions { + safe: true, + ..Default::default() + }; + let casted = cast_with_options(&dict_array, &DataType::Utf8View, &safe).unwrap(); + assert_eq!(casted.data_type(), &DataType::Utf8View); + + let result: Vec<_> = casted.as_string_view().iter().collect(); + assert_eq!( + result, + vec![Some("valid"), None, Some("also valid"), None, Some("valid")] + ); + } + #[test] fn test_view_to_dict() { let string_view_array = StringViewArray::from_iter(VIEW_TEST_DATA); From ef75fcfa138c3a51be5a7e6ead2af17eb3403a08 Mon Sep 17 00:00:00 2001 From: Abhisheklearn12 Date: Sun, 19 Apr 2026 18:38:27 +0530 Subject: [PATCH 2/3] fix: remove needless borrows in invalid UTF-8 test cases --- arrow-cast/src/cast/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs index 3337a9e2e33a..9cde2e8ee8e3 100644 --- a/arrow-cast/src/cast/mod.rs +++ b/arrow-cast/src/cast/mod.rs @@ -7364,7 +7364,7 @@ mod tests { // Dict -> Utf8View with invalid UTF-8: safe=false returns an error let mut builder = BinaryBuilder::new(); builder.append_value(b"valid"); - builder.append_value(&[0xFF]); // invalid UTF-8 + builder.append_value([0xFF]); // invalid UTF-8 builder.append_value(b"also valid"); let values = builder.finish(); @@ -7387,7 +7387,7 @@ mod tests { // Dict -> Utf8View with invalid UTF-8: safe=true nullifies affected rows let mut builder = BinaryBuilder::new(); builder.append_value(b"valid"); - builder.append_value(&[0xFF]); // invalid UTF-8 - dict index 1 + builder.append_value([0xFF]); // invalid UTF-8 - dict index 1 builder.append_value(b"also valid"); let values = builder.finish(); From fce2027e0eb2021d1739897f8ea59ac3a08b8051 Mon Sep 17 00:00:00 2001 From: Abhisheklearn12 Date: Fri, 24 Apr 2026 00:45:57 +0530 Subject: [PATCH 3/3] fix: propagate InvalidArgumentError from binary_dict_to_string_view --- arrow-cast/src/cast/dictionary.rs | 4 +--- arrow-cast/src/cast/mod.rs | 4 ++-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/arrow-cast/src/cast/dictionary.rs b/arrow-cast/src/cast/dictionary.rs index f46dc5c926eb..3125a60addbd 100644 --- a/arrow-cast/src/cast/dictionary.rs +++ b/arrow-cast/src/cast/dictionary.rs @@ -151,9 +151,7 @@ fn binary_dict_to_string_view( } Err(e) => { if !cast_options.safe { - return Err(ArrowError::CastError(format!( - "Cannot cast binary dictionary to Utf8View: {e}" - ))); + return Err(e); } // safe=true: validate each dictionary value individually so we can nullify // only the rows whose key points to an invalid UTF-8 value. diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs index 9cde2e8ee8e3..752ba0160e06 100644 --- a/arrow-cast/src/cast/mod.rs +++ b/arrow-cast/src/cast/mod.rs @@ -7377,8 +7377,8 @@ mod tests { }; let err = cast_with_options(&dict_array, &DataType::Utf8View, &strict).unwrap_err(); assert!( - matches!(err, ArrowError::CastError(_)), - "expected CastError, got {err:?}" + matches!(err, ArrowError::InvalidArgumentError(_)), + "expected InvalidArgumentError, got {err:?}" ); }