diff --git a/arrow-array/src/array/run_array.rs b/arrow-array/src/array/run_array.rs index 2e378c90fd4..33738d649f7 100644 --- a/arrow-array/src/array/run_array.rs +++ b/arrow-array/src/array/run_array.rs @@ -67,9 +67,9 @@ pub struct RunArray { } impl RunArray { - // calculates the logical length of the array encoded - // by the given run_ends array. - fn logical_len(run_ends: &PrimitiveArray) -> usize { + /// Calculates the logical length of the array encoded + /// by the given run_ends array. + pub fn logical_len(run_ends: &PrimitiveArray) -> usize { let len = run_ends.len(); if len == 0 { return 0; @@ -145,14 +145,15 @@ impl RunArray { } /// Returns index to the physical array for the given index to the logical array. + /// The function does not adjust the input logical index based on `ArrayData::offset`. /// Performs a binary search on the run_ends array for the input index. #[inline] - pub fn get_physical_index(&self, logical_index: usize) -> Option { - if logical_index >= self.len() { + pub fn get_zero_offset_physical_index(&self, logical_index: usize) -> Option { + if logical_index >= Self::logical_len(&self.run_ends) { return None; } let mut st: usize = 0; - let mut en: usize = self.run_ends().len(); + let mut en: usize = self.run_ends.len(); while st + 1 < en { let mid: usize = (st + en) / 2; if logical_index @@ -164,7 +165,7 @@ impl RunArray { // `en` starts with len. The condition `st + 1 < en` ensures // `st` and `en` differs atleast by two. So the value of `mid` // will never be either `st` or `en` - self.run_ends().value_unchecked(mid - 1).as_usize() + self.run_ends.value_unchecked(mid - 1).as_usize() } { en = mid @@ -175,6 +176,17 @@ impl RunArray { Some(st) } + /// Returns index to the physical array for the given index to the logical array. + /// This function adjusts the input logical index based on `ArrayData::offset` + /// Performs a binary search on the run_ends array for the input index. 
+ #[inline] + pub fn get_physical_index(&self, logical_index: usize) -> Option { + if logical_index >= self.len() { + return None; + } + self.get_zero_offset_physical_index(logical_index + self.offset()) + } + /// Returns the physical indices of the input logical indices. Returns error if any of the logical /// index cannot be converted to physical index. The logical indices are sorted and iterated along /// with run_ends array to find matching physical index. The approach used here was chosen over @@ -192,6 +204,10 @@ impl RunArray { { let indices_len = logical_indices.len(); + if indices_len == 0 { + return Ok(vec![]); + } + // `ordered_indices` store index into `logical_indices` and can be used // to iterate `logical_indices` in sorted order. let mut ordered_indices: Vec = (0..indices_len).collect(); @@ -204,12 +220,30 @@ impl RunArray { .unwrap() }); + // Return early if all the logical indices cannot be converted to physical indices. + let largest_logical_index = + logical_indices[*ordered_indices.last().unwrap()].as_usize(); + if largest_logical_index >= self.len() { + return Err(ArrowError::InvalidArgumentError(format!( + "Cannot convert all logical indices to physical indices. The logical index cannot be converted is {largest_logical_index}.", + ))); + } + + // Skip some physical indices based on offset. 
+ let skip_value = if self.offset() > 0 { + self.get_zero_offset_physical_index(self.offset()).unwrap() + } else { + 0 + }; + let mut physical_indices = vec![0; indices_len]; let mut ordered_index = 0_usize; - for (physical_index, run_end) in self.run_ends.values().iter().enumerate() { - // Get the run end index of current physical index - let run_end_value = run_end.as_usize(); + for (physical_index, run_end) in + self.run_ends.values().iter().enumerate().skip(skip_value) + { + // Get the run end index (relative to offset) of current physical index + let run_end_value = run_end.as_usize() - self.offset(); // All the `logical_indices` that are less than current run end index // belongs to current physical index. @@ -552,6 +586,34 @@ mod tests { result } + // Asserts that `logical_array[logical_indices[*]] == physical_array[physical_indices[*]]` + fn compare_logical_and_physical_indices( + logical_indices: &[u32], + logical_array: &[Option], + physical_indices: &[usize], + physical_array: &PrimitiveArray, + ) { + assert_eq!(logical_indices.len(), physical_indices.len()); + + // check value in logical index in the logical_array matches physical index in physical_array + logical_indices + .iter() + .map(|f| f.as_usize()) + .zip(physical_indices.iter()) + .for_each(|(logical_ix, physical_ix)| { + let expected = logical_array[logical_ix]; + match expected { + Some(val) => { + assert!(physical_array.is_valid(*physical_ix)); + let actual = physical_array.value(*physical_ix); + assert_eq!(val, actual); + } + None => { + assert!(physical_array.is_null(*physical_ix)) + } + }; + }); + } #[test] fn test_run_array() { // Construct a value array @@ -824,23 +886,77 @@ mod tests { assert_eq!(logical_indices.len(), physical_indices.len()); // check value in logical index in the input_array matches physical index in typed_run_array - logical_indices - .iter() - .map(|f| f.as_usize()) - .zip(physical_indices.iter()) - .for_each(|(logical_ix, physical_ix)| { - let expected = 
input_array[logical_ix]; - match expected { - Some(val) => { - assert!(physical_values_array.is_valid(*physical_ix)); - let actual = physical_values_array.value(*physical_ix); - assert_eq!(val, actual); - } - None => { - assert!(physical_values_array.is_null(*physical_ix)) - } - }; - }); + compare_logical_and_physical_indices( + &logical_indices, + &input_array, + &physical_indices, + physical_values_array, + ); + } + } + + #[test] + fn test_get_physical_indices_sliced() { + let total_len = 80; + let input_array = build_input_array(total_len); + + // Encode the input_array to run array + let mut builder = + PrimitiveRunBuilder::::with_capacity(input_array.len()); + builder.extend(input_array.iter().copied()); + let run_array = builder.finish(); + let physical_values_array = as_primitive_array::(run_array.values()); + + // test for all slice lengths. + for slice_len in 1..=total_len { + // create an array consisting of all the indices repeated twice and shuffled. + let mut logical_indices: Vec = (0_u32..(slice_len as u32)).collect(); + // add same indices once more + logical_indices.append(&mut logical_indices.clone()); + let mut rng = thread_rng(); + logical_indices.shuffle(&mut rng); + + // test for offset = 0 and slice length = slice_len + // slice the input array using which the run array was built. + let sliced_input_array = &input_array[0..slice_len]; + + // slice the run array + let sliced_run_array: RunArray = + run_array.slice(0, slice_len).into_data().into(); + + // Get physical indices. + let physical_indices = sliced_run_array + .get_physical_indices(&logical_indices) + .unwrap(); + + compare_logical_and_physical_indices( + &logical_indices, + sliced_input_array, + &physical_indices, + physical_values_array, + ); + + // test for offset = total_len - slice_len and slice length = slice_len + // slice the input array using which the run array was built. 
+ let sliced_input_array = &input_array[total_len - slice_len..total_len]; + + // slice the run array + let sliced_run_array: RunArray = run_array + .slice(total_len - slice_len, slice_len) + .into_data() + .into(); + + // Get physical indices + let physical_indices = sliced_run_array + .get_physical_indices(&logical_indices) + .unwrap(); + + compare_logical_and_physical_indices( + &logical_indices, + sliced_input_array, + &physical_indices, + physical_values_array, + ); } } } diff --git a/arrow-data/src/data.rs b/arrow-data/src/data.rs index 709262e8346..c5e32635d1d 100644 --- a/arrow-data/src/data.rs +++ b/arrow-data/src/data.rs @@ -1326,9 +1326,9 @@ impl ArrayData { DataType::RunEndEncoded(run_ends, _values) => { let run_ends_data = self.child_data()[0].clone(); match run_ends.data_type() { - DataType::Int16 => run_ends_data.check_run_ends::(self.len()), - DataType::Int32 => run_ends_data.check_run_ends::(self.len()), - DataType::Int64 => run_ends_data.check_run_ends::(self.len()), + DataType::Int16 => run_ends_data.check_run_ends::(), + DataType::Int32 => run_ends_data.check_run_ends::(), + DataType::Int64 => run_ends_data.check_run_ends::(), _ => unreachable!(), } } @@ -1487,7 +1487,7 @@ impl ArrayData { } /// Validates that each value in run_ends array is positive and strictly increasing. - fn check_run_ends(&self, array_len: usize) -> Result<(), ArrowError> + fn check_run_ends(&self) -> Result<(), ArrowError> where T: ArrowNativeType + TryInto + num::Num + std::fmt::Display, { @@ -1514,9 +1514,10 @@ impl ArrayData { Ok(()) })?; - if prev_value.as_usize() != array_len { + if prev_value.as_usize() < (self.offset + self.len) { return Err(ArrowError::InvalidArgumentError(format!( - "The length of array does not match the last value in the run_ends array. The last value of run_ends array is {prev_value} and length of array is {array_len}." + "The offset + length of array should be less or equal to last value in the run_ends array. 
The last value of run_ends array is {prev_value} and offset + length of array is {}.", + self.offset + self.len ))); } Ok(()) diff --git a/arrow-data/src/equal/mod.rs b/arrow-data/src/equal/mod.rs index aff61e3d37e..871a312ca47 100644 --- a/arrow-data/src/equal/mod.rs +++ b/arrow-data/src/equal/mod.rs @@ -31,6 +31,7 @@ mod fixed_list; mod list; mod null; mod primitive; +mod run; mod structure; mod union; mod utils; @@ -50,6 +51,8 @@ use structure::struct_equal; use union::union_equal; use variable_size::variable_sized_equal; +use self::run::run_equal; + /// Compares the values of two [ArrayData] starting at `lhs_start` and `rhs_start` respectively /// for `len` slots. #[inline] @@ -137,7 +140,7 @@ fn equal_values( }, DataType::Float16 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), DataType::Map(_, _) => list_equal::(lhs, rhs, lhs_start, rhs_start, len), - DataType::RunEndEncoded(_, _) => todo!(), + DataType::RunEndEncoded(_, _) => run_equal(lhs, rhs, lhs_start, rhs_start, len), } } diff --git a/arrow-data/src/equal/run.rs b/arrow-data/src/equal/run.rs new file mode 100644 index 00000000000..ede172c999f --- /dev/null +++ b/arrow-data/src/equal/run.rs @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::data::ArrayData; + +use super::equal_range; + +/// The current implementation of comparison of run array support physical comparison. +/// Comparing run encoded array based on logical indices (`lhs_start`, `rhs_start`) will +/// be time consuming as converting from logical index to physical index cannot be done +/// in constant time. The current comparison compares the underlying physical arrays. +pub(super) fn run_equal( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + if lhs_start != 0 + || rhs_start != 0 + || (lhs.len() != len && rhs.len() != len) + || lhs.offset() > 0 + || rhs.offset() > 0 + { + unimplemented!("Logical comparison for run array not supported.") + } + + if lhs.len() != rhs.len() { + return false; + } + + let lhs_run_ends_array = lhs.child_data().get(0).unwrap(); + let lhs_values_array = lhs.child_data().get(1).unwrap(); + + let rhs_run_ends_array = rhs.child_data().get(0).unwrap(); + let rhs_values_array = rhs.child_data().get(1).unwrap(); + + if lhs_run_ends_array.len() != rhs_run_ends_array.len() { + return false; + } + + if lhs_values_array.len() != rhs_values_array.len() { + return false; + } + + // check run ends array are equal. The length of the physical array + // is used to validate the child arrays. + let run_ends_equal = equal_range( + lhs_run_ends_array, + rhs_run_ends_array, + lhs_start, + rhs_start, + lhs_run_ends_array.len(), + ); + + // if run ends array are not the same return early without validating + // values array. 
+ if !run_ends_equal { + return false; + } + + // check values array are equal + equal_range( + lhs_values_array, + rhs_values_array, + lhs_start, + rhs_start, + rhs_values_array.len(), + ) +} diff --git a/arrow-ipc/regen.sh b/arrow-ipc/regen.sh index 9d384b6b63b..8d8862ccc7f 100755 --- a/arrow-ipc/regen.sh +++ b/arrow-ipc/regen.sh @@ -18,15 +18,13 @@ DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" -# Change to the toplevel Rust directory -pushd $DIR/../../ +# Change to the toplevel `arrow-rs` directory +pushd $DIR/../ echo "Build flatc from source ..." FB_URL="https://github.com/google/flatbuffers" -# https://github.com/google/flatbuffers/pull/6393 -FB_COMMIT="408cf5802415e1dea65fef7489a6c2f3740fb381" -FB_DIR="rust/arrow/.flatbuffers" +FB_DIR="arrow/.flatbuffers" FLATC="$FB_DIR/bazel-bin/flatc" if [ -z $(which bazel) ]; then @@ -44,28 +42,21 @@ else git -C $FB_DIR pull fi -echo "hard reset to $FB_COMMIT" -git -C $FB_DIR reset --hard $FB_COMMIT - pushd $FB_DIR echo "run: bazel build :flatc ..." bazel build :flatc popd -FB_PATCH="rust/arrow/format-0ed34c83.patch" -echo "Patch flatbuffer files with ${FB_PATCH} for cargo doc" -echo "NOTE: the patch MAY need update in case of changes in format/*.fbs" -git apply --check ${FB_PATCH} && git apply ${FB_PATCH} # Execute the code generation: -$FLATC --filename-suffix "" --rust -o rust/arrow/src/ipc/gen/ format/*.fbs +$FLATC --filename-suffix "" --rust -o arrow-ipc/src/gen/ format/*.fbs # Reset changes to format/ git checkout -- format # Now the files are wrongly named so we have to change that. 
popd -pushd $DIR/src/ipc/gen +pushd $DIR/src/gen PREFIX=$(cat <<'HEREDOC' // Licensed to the Apache Software Foundation (ASF) under one @@ -94,9 +85,9 @@ use flatbuffers::EndianScalar; HEREDOC ) -SCHEMA_IMPORT="\nuse crate::ipc::gen::Schema::*;" -SPARSE_TENSOR_IMPORT="\nuse crate::ipc::gen::SparseTensor::*;" -TENSOR_IMPORT="\nuse crate::ipc::gen::Tensor::*;" +SCHEMA_IMPORT="\nuse crate::gen::Schema::*;" +SPARSE_TENSOR_IMPORT="\nuse crate::gen::SparseTensor::*;" +TENSOR_IMPORT="\nuse crate::gen::Tensor::*;" # For flatbuffer(1.12.0+), remove: use crate::${name}::\*; names=("File" "Message" "Schema" "SparseTensor" "Tensor") @@ -119,8 +110,9 @@ for f in `ls *.rs`; do sed -i '' '/} \/\/ pub mod arrow/d' $f sed -i '' '/} \/\/ pub mod apache/d' $f sed -i '' '/} \/\/ pub mod org/d' $f - sed -i '' '/use std::mem;/d' $f - sed -i '' '/use std::cmp::Ordering;/d' $f + sed -i '' '/use core::mem;/d' $f + sed -i '' '/use core::cmp::Ordering;/d' $f + sed -i '' '/use self::flatbuffers::{EndianScalar, Follow};/d' $f # required by flatc 1.12.0+ sed -i '' "/\#\!\[allow(unused_imports, dead_code)\]/d" $f @@ -150,7 +142,7 @@ done # Return back to base directory popd -cargo +stable fmt -- src/ipc/gen/* +cargo +stable fmt -- src/gen/* echo "DONE!" echo "Please run 'cargo doc' and 'cargo test' with nightly and stable, " diff --git a/arrow-ipc/src/convert.rs b/arrow-ipc/src/convert.rs index c5681b0c8f1..aede8a448a0 100644 --- a/arrow-ipc/src/convert.rs +++ b/arrow-ipc/src/convert.rs @@ -364,6 +364,18 @@ pub(crate) fn get_data_type(field: crate::Field, may_be_dictionary: bool) -> Dat DataType::Struct(fields) } + crate::Type::RunEndEncoded => { + let children = field.children().unwrap(); + if children.len() != 2 { + panic!( + "RunEndEncoded type should have exactly two children. 
Found {}", + children.len() + ) + } + let run_ends_field = children.get(0).into(); + let values_field = children.get(1).into(); + DataType::RunEndEncoded(Box::new(run_ends_field), Box::new(values_field)) + } crate::Type::Map => { let map = field.type_as_map().unwrap(); let children = field.children().unwrap(); @@ -710,7 +722,18 @@ pub(crate) fn get_fb_field_type<'a>( children: Some(fbb.create_vector(&children[..])), } } - RunEndEncoded(_, _) => todo!(), + RunEndEncoded(run_ends, values) => { + let run_ends_field = build_field(fbb, run_ends); + let values_field = build_field(fbb, values); + let children = vec![run_ends_field, values_field]; + FBFieldType { + type_type: crate::Type::RunEndEncoded, + type_: crate::RunEndEncodedBuilder::new(fbb) + .finish() + .as_union_value(), + children: Some(fbb.create_vector(&children[..])), + } + } Map(map_field, keys_sorted) => { let child = build_field(fbb, map_field); let mut field_type = crate::MapBuilder::new(fbb); diff --git a/arrow-ipc/src/gen/Schema.rs b/arrow-ipc/src/gen/Schema.rs index 6479bece721..cf3ea0bd4ab 100644 --- a/arrow-ipc/src/gen/Schema.rs +++ b/arrow-ipc/src/gen/Schema.rs @@ -735,13 +735,13 @@ pub const ENUM_MIN_TYPE: u8 = 0; since = "2.0.0", note = "Use associated constants instead. This will no longer be generated in 2021." )] -pub const ENUM_MAX_TYPE: u8 = 21; +pub const ENUM_MAX_TYPE: u8 = 22; #[deprecated( since = "2.0.0", note = "Use associated constants instead. This will no longer be generated in 2021." 
)] #[allow(non_camel_case_types)] -pub const ENUM_VALUES_TYPE: [Type; 22] = [ +pub const ENUM_VALUES_TYPE: [Type; 23] = [ Type::NONE, Type::Null, Type::Int, @@ -764,6 +764,7 @@ pub const ENUM_VALUES_TYPE: [Type; 22] = [ Type::LargeBinary, Type::LargeUtf8, Type::LargeList, + Type::RunEndEncoded, ]; /// ---------------------------------------------------------------------- @@ -796,9 +797,10 @@ impl Type { pub const LargeBinary: Self = Self(19); pub const LargeUtf8: Self = Self(20); pub const LargeList: Self = Self(21); + pub const RunEndEncoded: Self = Self(22); pub const ENUM_MIN: u8 = 0; - pub const ENUM_MAX: u8 = 21; + pub const ENUM_MAX: u8 = 22; pub const ENUM_VALUES: &'static [Self] = &[ Self::NONE, Self::Null, @@ -822,6 +824,7 @@ impl Type { Self::LargeBinary, Self::LargeUtf8, Self::LargeList, + Self::RunEndEncoded, ]; /// Returns the variant's name or "" if unknown. pub fn variant_name(self) -> Option<&'static str> { @@ -848,6 +851,7 @@ impl Type { Self::LargeBinary => Some("LargeBinary"), Self::LargeUtf8 => Some("LargeUtf8"), Self::LargeList => Some("LargeList"), + Self::RunEndEncoded => Some("RunEndEncoded"), _ => None, } } @@ -2646,6 +2650,90 @@ impl core::fmt::Debug for Bool<'_> { ds.finish() } } +pub enum RunEndEncodedOffset {} +#[derive(Copy, Clone, PartialEq)] + +/// Contains two child arrays, run_ends and values. +/// The run_ends child array must be a 16/32/64-bit integer array +/// which encodes the indices at which the run with the value in +/// each corresponding index in the values child array ends. +/// Like list/struct types, the value array can be of any type. 
+pub struct RunEndEncoded<'a> { + pub _tab: flatbuffers::Table<'a>, +} + +impl<'a> flatbuffers::Follow<'a> for RunEndEncoded<'a> { + type Inner = RunEndEncoded<'a>; + #[inline] + unsafe fn follow(buf: &'a [u8], loc: usize) -> Self::Inner { + Self { + _tab: flatbuffers::Table::new(buf, loc), + } + } +} + +impl<'a> RunEndEncoded<'a> { + #[inline] + pub unsafe fn init_from_table(table: flatbuffers::Table<'a>) -> Self { + RunEndEncoded { _tab: table } + } + #[allow(unused_mut)] + pub fn create<'bldr: 'args, 'args: 'mut_bldr, 'mut_bldr>( + _fbb: &'mut_bldr mut flatbuffers::FlatBufferBuilder<'bldr>, + _args: &'args RunEndEncodedArgs, + ) -> flatbuffers::WIPOffset> { + let mut builder = RunEndEncodedBuilder::new(_fbb); + builder.finish() + } +} + +impl flatbuffers::Verifiable for RunEndEncoded<'_> { + #[inline] + fn run_verifier( + v: &mut flatbuffers::Verifier, + pos: usize, + ) -> Result<(), flatbuffers::InvalidFlatbuffer> { + use flatbuffers::Verifiable; + v.visit_table(pos)?.finish(); + Ok(()) + } +} +pub struct RunEndEncodedArgs {} +impl<'a> Default for RunEndEncodedArgs { + #[inline] + fn default() -> Self { + RunEndEncodedArgs {} + } +} + +pub struct RunEndEncodedBuilder<'a: 'b, 'b> { + fbb_: &'b mut flatbuffers::FlatBufferBuilder<'a>, + start_: flatbuffers::WIPOffset, +} +impl<'a: 'b, 'b> RunEndEncodedBuilder<'a, 'b> { + #[inline] + pub fn new( + _fbb: &'b mut flatbuffers::FlatBufferBuilder<'a>, + ) -> RunEndEncodedBuilder<'a, 'b> { + let start = _fbb.start_table(); + RunEndEncodedBuilder { + fbb_: _fbb, + start_: start, + } + } + #[inline] + pub fn finish(self) -> flatbuffers::WIPOffset> { + let o = self.fbb_.end_table(self.start_); + flatbuffers::WIPOffset::new(o.value()) + } +} + +impl core::fmt::Debug for RunEndEncoded<'_> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + let mut ds = f.debug_struct("RunEndEncoded"); + ds.finish() + } +} pub enum DecimalOffset {} #[derive(Copy, Clone, PartialEq)] @@ -4316,6 +4404,21 @@ impl<'a> 
Field<'a> { None } } + + #[inline] + #[allow(non_snake_case)] + pub fn type_as_run_end_encoded(&self) -> Option> { + if self.type_type() == Type::RunEndEncoded { + self.type_().map(|t| { + // Safety: + // Created from a valid Table for this object + // Which contains a valid union in this slot + unsafe { RunEndEncoded::init_from_table(t) } + }) + } else { + None + } + } } impl flatbuffers::Verifiable for Field<'_> { @@ -4351,6 +4454,7 @@ impl flatbuffers::Verifiable for Field<'_> { Type::LargeBinary => v.verify_union_variant::>("Type::LargeBinary", pos), Type::LargeUtf8 => v.verify_union_variant::>("Type::LargeUtf8", pos), Type::LargeList => v.verify_union_variant::>("Type::LargeList", pos), + Type::RunEndEncoded => v.verify_union_variant::>("Type::RunEndEncoded", pos), _ => Ok(()), } })? @@ -4686,6 +4790,16 @@ impl core::fmt::Debug for Field<'_> { ) } } + Type::RunEndEncoded => { + if let Some(x) = self.type_as_run_end_encoded() { + ds.field("type_", &x) + } else { + ds.field( + "type_", + &"InvalidFlatbuffer: Union discriminant does not match value.", + ) + } + } _ => { let x: Option<()> = None; ds.field("type_", &x) diff --git a/arrow-ipc/src/gen/SparseTensor.rs b/arrow-ipc/src/gen/SparseTensor.rs index c5e06c30e03..83fed4873b6 100644 --- a/arrow-ipc/src/gen/SparseTensor.rs +++ b/arrow-ipc/src/gen/SparseTensor.rs @@ -1524,6 +1524,20 @@ impl<'a> SparseTensor<'a> { } } + #[inline] + #[allow(non_snake_case)] + pub fn type_as_run_end_encoded(&self) -> Option> { + if self.type_type() == Type::RunEndEncoded { + let u = self.type_(); + // Safety: + // Created from a valid Table for this object + // Which contains a valid union in this slot + Some(unsafe { RunEndEncoded::init_from_table(u) }) + } else { + None + } + } + #[inline] #[allow(non_snake_case)] pub fn sparseIndex_as_sparse_tensor_index_coo( @@ -1604,6 +1618,7 @@ impl flatbuffers::Verifiable for SparseTensor<'_> { Type::LargeBinary => v.verify_union_variant::>("Type::LargeBinary", pos), Type::LargeUtf8 => 
v.verify_union_variant::>("Type::LargeUtf8", pos), Type::LargeList => v.verify_union_variant::>("Type::LargeList", pos), + Type::RunEndEncoded => v.verify_union_variant::>("Type::RunEndEncoded", pos), _ => Ok(()), } })? @@ -1943,6 +1958,16 @@ impl core::fmt::Debug for SparseTensor<'_> { ) } } + Type::RunEndEncoded => { + if let Some(x) = self.type_as_run_end_encoded() { + ds.field("type_", &x) + } else { + ds.field( + "type_", + &"InvalidFlatbuffer: Union discriminant does not match value.", + ) + } + } _ => { let x: Option<()> = None; ds.field("type_", &x) diff --git a/arrow-ipc/src/gen/Tensor.rs b/arrow-ipc/src/gen/Tensor.rs index 954ebd29012..43133fec036 100644 --- a/arrow-ipc/src/gen/Tensor.rs +++ b/arrow-ipc/src/gen/Tensor.rs @@ -565,6 +565,20 @@ impl<'a> Tensor<'a> { None } } + + #[inline] + #[allow(non_snake_case)] + pub fn type_as_run_end_encoded(&self) -> Option> { + if self.type_type() == Type::RunEndEncoded { + let u = self.type_(); + // Safety: + // Created from a valid Table for this object + // Which contains a valid union in this slot + Some(unsafe { RunEndEncoded::init_from_table(u) }) + } else { + None + } + } } impl flatbuffers::Verifiable for Tensor<'_> { @@ -598,6 +612,7 @@ impl flatbuffers::Verifiable for Tensor<'_> { Type::LargeBinary => v.verify_union_variant::>("Type::LargeBinary", pos), Type::LargeUtf8 => v.verify_union_variant::>("Type::LargeUtf8", pos), Type::LargeList => v.verify_union_variant::>("Type::LargeList", pos), + Type::RunEndEncoded => v.verify_union_variant::>("Type::RunEndEncoded", pos), _ => Ok(()), } })? 
@@ -907,6 +922,16 @@ impl core::fmt::Debug for Tensor<'_> { ) } } + Type::RunEndEncoded => { + if let Some(x) = self.type_as_run_end_encoded() { + ds.field("type_", &x) + } else { + ds.field( + "type_", + &"InvalidFlatbuffer: Union discriminant does not match value.", + ) + } + } _ => { let x: Option<()> = None; ds.field("type_", &x) diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs index 17f521e423a..6842474fb4e 100644 --- a/arrow-ipc/src/reader.rs +++ b/arrow-ipc/src/reader.rs @@ -194,6 +194,50 @@ fn create_array( }; Arc::new(struct_array) } + RunEndEncoded(run_ends_field, values_field) => { + let run_node = nodes.get(node_index); + node_index += 1; + + let run_ends_triple = create_array( + nodes, + run_ends_field, + data, + buffers, + dictionaries_by_id, + node_index, + buffer_index, + compression_codec, + metadata, + )?; + node_index = run_ends_triple.1; + buffer_index = run_ends_triple.2; + + let values_triple = create_array( + nodes, + values_field, + data, + buffers, + dictionaries_by_id, + node_index, + buffer_index, + compression_codec, + metadata, + )?; + node_index = values_triple.1; + buffer_index = values_triple.2; + + let run_array_length = run_node.length() as usize; + let run_array_null_count = run_node.null_count() as usize; + let data = ArrayData::builder(data_type.clone()) + .len(run_array_length) + .null_count(run_array_null_count) + .offset(0) + .add_child_data(run_ends_triple.0.into_data()) + .add_child_data(values_triple.0.into_data()) + .build()?; + + make_array(data) + } // Create dictionary array from RecordBatch Dictionary(_, _) => { let index_node = nodes.get(node_index); @@ -361,6 +405,17 @@ fn skip_field( buffer_index = tuple.1; } } + RunEndEncoded(run_ends_field, values_field) => { + node_index += 1; + + let tuple = skip_field(run_ends_field.data_type(), node_index, buffer_index)?; + node_index = tuple.0; + buffer_index = tuple.1; + + let tuple = skip_field(values_field.data_type(), node_index, buffer_index)?; + 
node_index = tuple.0; + buffer_index = tuple.1; + } Dictionary(_, _) => { node_index += 1; buffer_index += 2; @@ -1189,9 +1244,11 @@ impl RecordBatchReader for StreamReader { #[cfg(test)] mod tests { + use crate::writer::unslice_run_array; + use super::*; - use arrow_array::builder::UnionBuilder; + use arrow_array::builder::{PrimitiveRunBuilder, UnionBuilder}; use arrow_array::types::*; use arrow_buffer::ArrowNativeType; use arrow_data::ArrayDataBuilder; @@ -1227,6 +1284,11 @@ mod tests { ]; let struct_data_type = DataType::Struct(struct_fields); + let run_encoded_data_type = DataType::RunEndEncoded( + Box::new(Field::new("run_ends", DataType::Int16, false)), + Box::new(Field::new("values", DataType::Int32, true)), + ); + // define schema Schema::new(vec![ Field::new("f0", DataType::UInt32, false), @@ -1239,9 +1301,10 @@ mod tests { Field::new("f7", DataType::FixedSizeBinary(3), true), Field::new("f8", fixed_size_list_data_type, false), Field::new("f9", struct_data_type, false), - Field::new("f10", DataType::Boolean, false), - Field::new("f11", dict_data_type, false), - Field::new("f12", DataType::Utf8, false), + Field::new("f10", run_encoded_data_type, false), + Field::new("f11", DataType::Boolean, false), + Field::new("f12", dict_data_type, false), + Field::new("f13", DataType::Utf8, false), ]) } @@ -1296,14 +1359,19 @@ mod tests { .unwrap(); let array9: ArrayRef = Arc::new(StructArray::from(array9)); - let array10 = BooleanArray::from(vec![false, false, true]); + let array10_input = vec![Some(1_i32), None, None]; + let mut array10_builder = PrimitiveRunBuilder::::new(); + array10_builder.extend(array10_input.into_iter()); + let array10 = array10_builder.finish(); + + let array11 = BooleanArray::from(vec![false, false, true]); - let array11_values = StringArray::from(vec!["x", "yy", "zzz"]); - let array11_keys = Int8Array::from_iter_values([1, 1, 2]); - let array11 = - DictionaryArray::::try_new(&array11_keys, &array11_values).unwrap(); + let array12_values = 
StringArray::from(vec!["x", "yy", "zzz"]); + let array12_keys = Int8Array::from_iter_values([1, 1, 2]); + let array12 = + DictionaryArray::::try_new(&array12_keys, &array12_values).unwrap(); - let array12 = StringArray::from(vec!["a", "bb", "ccc"]); + let array13 = StringArray::from(vec!["a", "bb", "ccc"]); // create record batch RecordBatch::try_new( @@ -1322,6 +1390,7 @@ mod tests { Arc::new(array10), Arc::new(array11), Arc::new(array12), + Arc::new(array13), ], ) .unwrap() @@ -1510,6 +1579,43 @@ mod tests { check_union_with_builder(UnionBuilder::new_sparse()); } + #[test] + fn test_roundtrip_stream_run_array_sliced() { + let run_array_1: Int32RunArray = vec!["a", "a", "a", "b", "b", "c", "c", "c"] + .into_iter() + .collect(); + let run_array_1_sliced = run_array_1.slice(2, 5); + + let run_array_2_inupt = vec![Some(1_i32), None, None, Some(2), Some(2)]; + let mut run_array_2_builder = PrimitiveRunBuilder::::new(); + run_array_2_builder.extend(run_array_2_inupt.into_iter()); + let run_array_2 = run_array_2_builder.finish(); + + let schema = Arc::new(Schema::new(vec![ + Field::new( + "run_array_1_sliced", + run_array_1_sliced.data_type().clone(), + false, + ), + Field::new("run_array_2", run_array_2.data_type().clone(), false), + ])); + let input_batch = RecordBatch::try_new( + schema, + vec![Arc::new(run_array_1_sliced.clone()), Arc::new(run_array_2)], + ) + .unwrap(); + let output_batch = roundtrip_ipc_stream(&input_batch); + + // As partial comparison not yet supported for run arrays, the sliced run array + // has to be unsliced before comparing with the output. the second run array + // can be compared as such. 
+ assert_eq!(input_batch.column(1), output_batch.column(1)); + + let run_array_1_unsliced = + unslice_run_array(run_array_1_sliced.into_data()).unwrap(); + assert_eq!(run_array_1_unsliced, output_batch.column(0).into_data()); + } + #[test] fn test_roundtrip_stream_nested_dict() { let xs = vec!["AA", "BB", "AA", "CC", "BB"]; diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs index 8835cb49ffc..f019340154a 100644 --- a/arrow-ipc/src/writer.rs +++ b/arrow-ipc/src/writer.rs @@ -24,14 +24,15 @@ use std::cmp::min; use std::collections::HashMap; use std::io::{BufWriter, Write}; +use arrow_array::types::{Int16Type, Int32Type, Int64Type, RunEndIndexType}; use flatbuffers::FlatBufferBuilder; use arrow_array::builder::BufferBuilder; use arrow_array::cast::*; use arrow_array::*; use arrow_buffer::bit_util; -use arrow_buffer::{Buffer, MutableBuffer}; -use arrow_data::{layout, ArrayData, BufferSpec}; +use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer}; +use arrow_data::{layout, ArrayData, ArrayDataBuilder, BufferSpec}; use arrow_schema::*; use crate::compression::CompressionCodec; @@ -218,6 +219,24 @@ impl IpcDataGenerator { )?; } } + DataType::RunEndEncoded(_, values) => { + if column.data().child_data().len() != 2 { + return Err(ArrowError::InvalidArgumentError(format!( + "The run encoded array should have exactly two child arrays. Found {}", + column.data().child_data().len() + ))); + } + // The run_ends array is not expected to be dictionoary encoded. Hence encode dictionaries + // only for values array. 
+ let values_array = make_array(column.data().child_data()[1].clone()); + self.encode_dictionaries( + values, + &values_array, + encoded_dictionaries, + dictionary_tracker, + write_options, + )?; + } DataType::List(field) => { let list = as_list_array(column); self.encode_dictionaries( @@ -533,6 +552,94 @@ impl IpcDataGenerator { } } +pub(crate) fn unslice_run_array(arr: ArrayData) -> Result { + match arr.data_type() { + DataType::RunEndEncoded(k, _) => match k.data_type() { + DataType::Int16 => Ok(into_zero_offset_run_array( + RunArray::::from(arr), + )? + .into_data()), + DataType::Int32 => Ok(into_zero_offset_run_array( + RunArray::::from(arr), + )? + .into_data()), + DataType::Int64 => Ok(into_zero_offset_run_array( + RunArray::::from(arr), + )? + .into_data()), + d => unreachable!("Unexpected data type {d}"), + }, + d => Err(ArrowError::InvalidArgumentError(format!( + "The given array is not a run array. Data type of given array: {d}" + ))), + } +} + +// Returns a `RunArray` with zero offset and length matching the last value +// in run_ends array. +fn into_zero_offset_run_array( + run_array: RunArray, +) -> Result, ArrowError> { + if run_array.offset() == 0 + && run_array.len() == RunArray::::logical_len(run_array.run_ends()) + { + return Ok(run_array); + } + // The physical index of original run_ends array from which the `ArrayData`is sliced. + let start_physical_index = run_array + .get_zero_offset_physical_index(run_array.offset()) + .unwrap(); + + // The logical length of original run_ends array until which the `ArrayData` is sliced. + let end_logical_index = run_array.offset() + run_array.len() - 1; + // The physical index of original run_ends array until which the `ArrayData`is sliced. + let end_physical_index = run_array + .get_zero_offset_physical_index(end_logical_index) + .unwrap(); + + let physical_length = end_physical_index - start_physical_index + 1; + + // build new run_ends array by subtrating offset from run ends. 
+ let mut builder = BufferBuilder::::new(physical_length); + for ix in start_physical_index..end_physical_index { + let run_end_value = unsafe { + // Safety: + // start_physical_index and end_physical_index are within + // run_ends array bounds. + run_array.run_ends().value_unchecked(ix).as_usize() + }; + let run_end_value = run_end_value - run_array.offset(); + builder.append(R::Native::from_usize(run_end_value).unwrap()); + } + builder.append(R::Native::from_usize(run_array.len()).unwrap()); + let new_run_ends = unsafe { + // Safety: + // The function builds a valid run_ends array and hence need not be validated. + ArrayDataBuilder::new(run_array.run_ends().data_type().clone()) + .len(physical_length) + .null_count(0) + .add_buffer(builder.finish()) + .build_unchecked() + }; + + // build new values by slicing physical indices. + let new_values = run_array + .values() + .slice(start_physical_index, physical_length) + .into_data(); + + let builder = ArrayDataBuilder::new(run_array.data_type().clone()) + .len(run_array.len()) + .add_child_data(new_run_ends) + .add_child_data(new_values); + let array_data = unsafe { + // Safety: + // This function builds a valid run array and hence can skip validation. + builder.build_unchecked() + }; + Ok(array_data.into()) +} + /// Keeps track of dictionaries that have been written, to avoid emitting the same dictionary /// multiple times. Can optionally error if an update to an existing dictionary is attempted, which /// isn't allowed in the `FileWriter`. @@ -968,11 +1075,15 @@ fn write_continuation( /// In V4, null types have no validity bitmap /// In V5 and later, null and union types have no validity bitmap +/// Run end encoded type has no validity bitmap. 
 fn has_validity_bitmap(data_type: &DataType, write_options: &IpcWriteOptions) -> bool {
     if write_options.metadata_version < crate::MetadataVersion::V5 {
         !matches!(data_type, DataType::Null)
     } else {
-        !matches!(data_type, DataType::Null | DataType::Union(_, _, _))
+        !matches!(
+            data_type,
+            DataType::Null | DataType::Union(_, _, _) | DataType::RunEndEncoded(_, _)
+        )
     }
 }
 
@@ -1242,24 +1353,45 @@ fn write_array_data(
         }
     }
 
-    if !matches!(array_data.data_type(), DataType::Dictionary(_, _)) {
-        // recursively write out nested structures
-        for data_ref in array_data.child_data() {
-            // write the nested data (e.g list data)
-            offset = write_array_data(
-                data_ref,
-                buffers,
-                arrow_data,
-                nodes,
-                offset,
-                data_ref.len(),
-                data_ref.null_count(),
-                compression_codec,
-                write_options,
-            )?;
+    match array_data.data_type() {
+        DataType::Dictionary(_, _) => {}
+        DataType::RunEndEncoded(_, _) => {
+            // unslice the run encoded array.
+            let arr = unslice_run_array(array_data.clone())?;
+            // recursively write out nested structures
+            for data_ref in arr.child_data() {
+                // write the nested data (e.g list data)
+                offset = write_array_data(
+                    data_ref,
+                    buffers,
+                    arrow_data,
+                    nodes,
+                    offset,
+                    data_ref.len(),
+                    data_ref.null_count(),
+                    compression_codec,
+                    write_options,
+                )?;
+            }
+        }
+        _ => {
+            // recursively write out nested structures
+            for data_ref in array_data.child_data() {
+                // write the nested data (e.g list data)
+                offset = write_array_data(
+                    data_ref,
+                    buffers,
+                    arrow_data,
+                    nodes,
+                    offset,
+                    data_ref.len(),
+                    data_ref.null_count(),
+                    compression_codec,
+                    write_options,
+                )?;
+            }
+        }
     }
 }
-
     Ok(offset)
 }
@@ -1322,6 +1454,7 @@ mod tests {
     use crate::MetadataVersion;
     use crate::reader::*;
+    use arrow_array::builder::PrimitiveRunBuilder;
     use arrow_array::builder::UnionBuilder;
     use arrow_array::types::*;
     use arrow_schema::DataType;
@@ -1992,4 +2125,62 @@ mod tests {
         let batch2 = reader.next().unwrap().unwrap();
         assert_eq!(batch, batch2);
     }
+
+    #[test]
+    fn test_run_array_unslice() {
+        let total_len = 80;
+        let vals: Vec<Option<i32>> =
+            vec![Some(1), None, Some(2), Some(3), Some(4), None, Some(5)];
+        let repeats: Vec<usize> = vec![3, 4, 1, 2];
+        let mut input_array: Vec<Option<i32>> = Vec::with_capacity(total_len);
+        for ix in 0_usize..32 {
+            let repeat: usize = repeats[ix % repeats.len()];
+            let val: Option<i32> = vals[ix % vals.len()];
+            input_array.resize(input_array.len() + repeat, val);
+        }
+
+        // Encode the input_array to run array
+        let mut builder =
+            PrimitiveRunBuilder::<Int16Type, Int32Type>::with_capacity(input_array.len());
+        builder.extend(input_array.iter().copied());
+        let run_array = builder.finish();
+
+        // test for all slice lengths.
+        for slice_len in 1..=total_len {
+            // test for offset = 0, slice length = slice_len
+            let sliced_run_array: RunArray<Int16Type> =
+                run_array.slice(0, slice_len).into_data().into();
+
+            // Create unsliced run array.
+            let unsliced_run_array =
+                into_zero_offset_run_array(sliced_run_array).unwrap();
+            let typed = unsliced_run_array
+                .downcast::<PrimitiveArray<Int32Type>>()
+                .unwrap();
+            let expected: Vec<Option<i32>> =
+                input_array.iter().take(slice_len).copied().collect();
+            let actual: Vec<Option<i32>> = typed.into_iter().collect();
+            assert_eq!(expected, actual);
+
+            // test for offset = total_len - slice_len, length = slice_len
+            let sliced_run_array: RunArray<Int16Type> = run_array
+                .slice(total_len - slice_len, slice_len)
+                .into_data()
+                .into();
+
+            // Create unsliced run array.
+            let unsliced_run_array =
+                into_zero_offset_run_array(sliced_run_array).unwrap();
+            let typed = unsliced_run_array
+                .downcast::<PrimitiveArray<Int32Type>>()
+                .unwrap();
+            let expected: Vec<Option<i32>> = input_array
+                .iter()
+                .skip(total_len - slice_len)
+                .copied()
+                .collect();
+            let actual: Vec<Option<i32>> = typed.into_iter().collect();
+            assert_eq!(expected, actual);
+        }
+    }
 }
diff --git a/arrow-select/src/take.rs b/arrow-select/src/take.rs
index f8668b56e1d..f8383bbe3d2 100644
--- a/arrow-select/src/take.rs
+++ b/arrow-select/src/take.rs
@@ -832,7 +832,6 @@ macro_rules! primitive_run_take {
 /// for e.g. an input `RunArray{ run_ends = [2,4,6,8], values=[1,2,1,2] }` and `indices=[2,7]`
 /// would be converted to `physical_indices=[1,3]` which will be used to build
 /// output `RunArray{ run_ends=[2], values=[2] }`
-
 fn take_run<T: RunEndIndexType, I: ArrowPrimitiveType>(
     run_array: &RunArray<T>,
     logical_indices: &PrimitiveArray<I>,
diff --git a/format/Schema.fbs b/format/Schema.fbs
index 7ee827b5de8..6337f72ec9d 100644
--- a/format/Schema.fbs
+++ b/format/Schema.fbs
@@ -19,8 +19,9 @@
 
 /// Format Version History.
 /// Version 1.0 - Forward and backwards compatibility guaranteed.
-/// Version 1.1 - Add Decimal256 (No format release).
-/// Version 1.2 (Pending)- Add Interval MONTH_DAY_NANO
+/// Version 1.1 - Add Decimal256.
+/// Version 1.2 - Add Interval MONTH_DAY_NANO
+/// Version 1.3 - Add Run-End Encoded.
 
 namespace org.apache.arrow.flatbuf;
 
@@ -178,6 +179,14 @@ table FixedSizeBinary {
 table Bool {
 }
 
+/// Contains two child arrays, run_ends and values.
+/// The run_ends child array must be a 16/32/64-bit integer array
+/// which encodes the indices at which the run with the value in
+/// each corresponding index in the values child array ends.
+/// Like list/struct types, the value array can be of any type.
+table RunEndEncoded {
+}
+
 /// Exact decimal value represented as an integer value in two's
 /// complement. Currently only 128-bit (16-byte) and 256-bit (32-byte) integers
 /// are used. The representation uses the endianness indicated
@@ -417,6 +426,7 @@ union Type {
   LargeBinary,
   LargeUtf8,
   LargeList,
+  RunEndEncoded,
 }
 
 /// ----------------------------------------------------------------------