diff --git a/rust/arrow/benches/filter_kernels.rs b/rust/arrow/benches/filter_kernels.rs index 1348238b0747c..9ad46dc37a567 100644 --- a/rust/arrow/benches/filter_kernels.rs +++ b/rust/arrow/benches/filter_kernels.rs @@ -14,128 +14,141 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +extern crate arrow; + +use arrow::{compute::Filter, util::test_util::seedable_rng}; +use rand::{ + distributions::{Alphanumeric, Standard}, + prelude::Distribution, + Rng, +}; use arrow::array::*; -use arrow::compute::{filter, FilterContext}; +use arrow::compute::{build_filter, filter}; use arrow::datatypes::ArrowNumericType; +use arrow::datatypes::{Float32Type, UInt8Type}; + use criterion::{criterion_group, criterion_main, Criterion}; -fn create_primitive_array(size: usize, value_fn: F) -> PrimitiveArray +fn create_primitive_array(size: usize, null_density: f32) -> PrimitiveArray where T: ArrowNumericType, - F: Fn(usize) -> T::Native, + Standard: Distribution, { + // use random numbers to avoid spurious compiler optimizations wrt to branching + let mut rng = seedable_rng(); let mut builder = PrimitiveArray::::builder(size); - for i in 0..size { - builder.append_value(value_fn(i)).unwrap(); + + for _ in 0..size { + if rng.gen::() < null_density { + builder.append_null().unwrap(); + } else { + builder.append_value(rng.gen()).unwrap(); + } } builder.finish() } -fn create_u8_array_with_nulls(size: usize) -> UInt8Array { - let mut builder = UInt8Builder::new(size); - for i in 0..size { - if i % 2 == 0 { - builder.append_value(1).unwrap(); - } else { +fn create_string_array(size: usize, null_density: f32) -> StringArray { + // use random numbers to avoid spurious compiler optimizations wrt to branching + let mut rng = seedable_rng(); + let mut builder = StringBuilder::new(size); + + for _ in 0..size { + if rng.gen::() < null_density { builder.append_null().unwrap(); + } else { + let value = (&mut 
rng) + .sample_iter(&Alphanumeric) + .take(10) + .collect::(); + builder.append_value(&value).unwrap(); } } builder.finish() } -fn create_bool_array(size: usize, value_fn: F) -> BooleanArray -where - F: Fn(usize) -> bool, -{ +fn create_bool_array(size: usize, trues_density: f32) -> BooleanArray { + let mut rng = seedable_rng(); let mut builder = BooleanBuilder::new(size); - for i in 0..size { - builder.append_value(value_fn(i)).unwrap(); + for _ in 0..size { + let value = rng.gen::() < trues_density; + builder.append_value(value).unwrap(); } builder.finish() } -fn bench_filter_u8(data_array: &UInt8Array, filter_array: &BooleanArray) { - filter( - criterion::black_box(data_array), - criterion::black_box(filter_array), - ) - .unwrap(); +fn bench_filter(data_array: &UInt8Array, filter_array: &BooleanArray) { + criterion::black_box(filter(data_array, filter_array).unwrap()); } -// fn bench_filter_f32(data_array: &Float32Array, filter_array: &BooleanArray) { -// filter(criterion::black_box(data_array), criterion::black_box(filter_array)).unwrap(); -// } - -fn bench_filter_context_u8(data_array: &UInt8Array, filter_context: &FilterContext) { - filter_context - .filter(criterion::black_box(data_array)) - .unwrap(); -} - -fn bench_filter_context_f32(data_array: &Float32Array, filter_context: &FilterContext) { - filter_context - .filter(criterion::black_box(data_array)) - .unwrap(); +fn bench_built_filter<'a>(filter: &Filter<'a>, data: &impl Array) { + criterion::black_box(filter(&data.data())); } fn add_benchmark(c: &mut Criterion) { let size = 65536; - let filter_array = create_bool_array(size, |i| matches!(i % 2, 0)); - let sparse_filter_array = create_bool_array(size, |i| matches!(i % 8000, 0)); - let dense_filter_array = create_bool_array(size, |i| !matches!(i % 8000, 0)); + let filter_array = create_bool_array(size, 0.5); + let sparse_filter_array = create_bool_array(size, 1.0 - 1.0 / 8000.0); + let dense_filter_array = create_bool_array(size, 1.0 / 8000.0); - let 
filter_context = FilterContext::new(&filter_array).unwrap(); - let sparse_filter_context = FilterContext::new(&sparse_filter_array).unwrap(); - let dense_filter_context = FilterContext::new(&dense_filter_array).unwrap(); + let filter = build_filter(&filter_array).unwrap(); + let sparse_filter = build_filter(&sparse_filter_array).unwrap(); + let dense_filter = build_filter(&dense_filter_array).unwrap(); + + let data_array = create_primitive_array::(size, 0.0); - let data_array = create_primitive_array(size, |i| match i % 2 { - 0 => 1, - _ => 0, - }); c.bench_function("filter u8 low selectivity", |b| { - b.iter(|| bench_filter_u8(&data_array, &filter_array)) + b.iter(|| bench_filter(&data_array, &filter_array)) }); c.bench_function("filter u8 high selectivity", |b| { - b.iter(|| bench_filter_u8(&data_array, &sparse_filter_array)) + b.iter(|| bench_filter(&data_array, &sparse_filter_array)) }); c.bench_function("filter u8 very low selectivity", |b| { - b.iter(|| bench_filter_u8(&data_array, &dense_filter_array)) + b.iter(|| bench_filter(&data_array, &dense_filter_array)) }); c.bench_function("filter context u8 low selectivity", |b| { - b.iter(|| bench_filter_context_u8(&data_array, &filter_context)) + b.iter(|| bench_built_filter(&filter, &data_array)) }); c.bench_function("filter context u8 high selectivity", |b| { - b.iter(|| bench_filter_context_u8(&data_array, &sparse_filter_context)) + b.iter(|| bench_built_filter(&sparse_filter, &data_array)) }); c.bench_function("filter context u8 very low selectivity", |b| { - b.iter(|| bench_filter_context_u8(&data_array, &dense_filter_context)) + b.iter(|| bench_built_filter(&dense_filter, &data_array)) }); - let data_array = create_u8_array_with_nulls(size); + let data_array = create_primitive_array::(size, 0.5); c.bench_function("filter context u8 w NULLs low selectivity", |b| { - b.iter(|| bench_filter_context_u8(&data_array, &filter_context)) + b.iter(|| bench_built_filter(&filter, &data_array)) }); 
c.bench_function("filter context u8 w NULLs high selectivity", |b| { - b.iter(|| bench_filter_context_u8(&data_array, &sparse_filter_context)) + b.iter(|| bench_built_filter(&sparse_filter, &data_array)) }); c.bench_function("filter context u8 w NULLs very low selectivity", |b| { - b.iter(|| bench_filter_context_u8(&data_array, &dense_filter_context)) + b.iter(|| bench_built_filter(&dense_filter, &data_array)) }); - let data_array = create_primitive_array(size, |i| match i % 2 { - 0 => 1.0, - _ => 0.0, - }); + let data_array = create_primitive_array::(size, 0.5); c.bench_function("filter context f32 low selectivity", |b| { - b.iter(|| bench_filter_context_f32(&data_array, &filter_context)) + b.iter(|| bench_built_filter(&filter, &data_array)) }); c.bench_function("filter context f32 high selectivity", |b| { - b.iter(|| bench_filter_context_f32(&data_array, &sparse_filter_context)) + b.iter(|| bench_built_filter(&sparse_filter, &data_array)) }); c.bench_function("filter context f32 very low selectivity", |b| { - b.iter(|| bench_filter_context_f32(&data_array, &dense_filter_context)) + b.iter(|| bench_built_filter(&dense_filter, &data_array)) + }); + + let data_array = create_string_array(size, 0.5); + c.bench_function("filter context string low selectivity", |b| { + b.iter(|| bench_built_filter(&filter, &data_array)) + }); + c.bench_function("filter context string high selectivity", |b| { + b.iter(|| bench_built_filter(&sparse_filter, &data_array)) + }); + c.bench_function("filter context string very low selectivity", |b| { + b.iter(|| bench_built_filter(&dense_filter, &data_array)) }); } diff --git a/rust/arrow/src/array/mod.rs b/rust/arrow/src/array/mod.rs index d8cfb46449f30..dd1fa0f57eeea 100644 --- a/rust/arrow/src/array/mod.rs +++ b/rust/arrow/src/array/mod.rs @@ -99,6 +99,7 @@ mod iterator; mod null; mod ord; mod raw_pointer; +mod transform; use crate::datatypes::*; @@ -249,6 +250,8 @@ pub type DurationMillisecondBuilder = PrimitiveBuilder; pub type 
DurationMicrosecondBuilder = PrimitiveBuilder; pub type DurationNanosecondBuilder = PrimitiveBuilder; +pub use self::transform::MutableArrayData; + // --------------------- Array Iterator --------------------- pub use self::iterator::*; diff --git a/rust/arrow/src/array/transform/boolean.rs b/rust/arrow/src/array/transform/boolean.rs new file mode 100644 index 0000000000000..889b99be88ecd --- /dev/null +++ b/rust/arrow/src/array/transform/boolean.rs @@ -0,0 +1,40 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::array::ArrayData; + +use super::{ + Extend, _MutableArrayData, + utils::{reserve_for_bits, set_bits}, +}; + +pub(super) fn build_extend(array: &ArrayData) -> Extend { + let values = array.buffers()[0].data(); + Box::new( + move |mutable: &mut _MutableArrayData, start: usize, len: usize| { + let buffer = &mut mutable.buffers[0]; + reserve_for_bits(buffer, mutable.len + len); + set_bits( + &mut buffer.data_mut(), + values, + mutable.len, + array.offset() + start, + len, + ); + }, + ) +} diff --git a/rust/arrow/src/array/transform/list.rs b/rust/arrow/src/array/transform/list.rs new file mode 100644 index 0000000000000..ff4df85464314 --- /dev/null +++ b/rust/arrow/src/array/transform/list.rs @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::{ + array::{ArrayData, OffsetSizeTrait}, + datatypes::ToByteSlice, +}; + +use super::{Extend, _MutableArrayData, utils::extend_offsets}; + +pub(super) fn build_extend(array: &ArrayData) -> Extend { + let offsets = array.buffer::(0); + if array.null_count() == 0 { + // fast case where we can copy regions without nullability checks + Box::new( + move |mutable: &mut _MutableArrayData, start: usize, len: usize| { + let mutable_offsets = mutable.buffer::(0); + let last_offset = mutable_offsets[mutable_offsets.len() - 1]; + // offsets + extend_offsets::( + &mut mutable.buffers[0], + last_offset, + &offsets[start..start + len + 1], + ); + + mutable.child_data[0].extend( + offsets[start].to_usize().unwrap(), + offsets[start + len].to_usize().unwrap(), + ) + }, + ) + } else { + // nulls present: append item by item, ignoring null entries + Box::new( + move |mutable: &mut _MutableArrayData, start: usize, len: usize| { + let mutable_offsets = mutable.buffer::(0); + let mut last_offset = mutable_offsets[mutable_offsets.len() - 1]; + + let buffer = &mut mutable.buffers[0]; + let delta_len = array.len() - array.null_count(); + buffer.reserve(buffer.len() + delta_len * std::mem::size_of::()); + + let child = &mut mutable.child_data[0]; + (start..start + len).for_each(|i| { + if array.is_valid(i) { + // compute the new offset + last_offset = last_offset + offsets[i + 1] - offsets[i]; + + // append value + child.extend( + offsets[i].to_usize().unwrap(), + offsets[i + 1].to_usize().unwrap(), + ); + } + // append offset + buffer.extend_from_slice(last_offset.to_byte_slice()); + }) + }, + ) + } +} diff --git a/rust/arrow/src/array/transform/mod.rs b/rust/arrow/src/array/transform/mod.rs new file mode 100644 index 0000000000000..96362ab028453 --- /dev/null +++ b/rust/arrow/src/array/transform/mod.rs @@ -0,0 +1,536 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{mem::size_of, sync::Arc}; + +use crate::{buffer::MutableBuffer, datatypes::DataType, util::bit_util}; + +use super::{ArrayData, ArrayDataRef}; + +mod boolean; +mod list; +mod primitive; +mod utils; +mod variable_size; + +type ExtendNullBits<'a> = Box; +// function that extends `[start..start+len]` to the mutable array. +// this is dynamic because different data_types influence how buffers and children are extended. +type Extend<'a> = Box; + +/// A mutable [ArrayData] that knows how to freeze itself into an [ArrayData]. +/// This is just a data container.
+#[derive(Debug)] +struct _MutableArrayData<'a> { + pub data_type: DataType, + pub null_count: usize, + + pub len: usize, + pub null_buffer: MutableBuffer, + + pub buffers: Vec, + pub child_data: Vec>, +} + +impl<'a> _MutableArrayData<'a> { + fn freeze(self, dictionary: Option) -> ArrayData { + let mut buffers = Vec::with_capacity(self.buffers.len()); + for buffer in self.buffers { + buffers.push(buffer.freeze()); + } + + let child_data = match self.data_type { + DataType::Dictionary(_, _) => vec![dictionary.unwrap()], + _ => { + let mut child_data = Vec::with_capacity(self.child_data.len()); + for child in self.child_data { + child_data.push(Arc::new(child.freeze())); + } + child_data + } + }; + ArrayData::new( + self.data_type, + self.len, + Some(self.null_count), + if self.null_count > 0 { + Some(self.null_buffer.freeze()) + } else { + None + }, + 0, + buffers, + child_data, + ) + } + + /// Returns the buffer `buffer` as a slice of type `T`. When the expected buffer is bit-packed, + /// the slice is not offset. + #[inline] + pub(super) fn buffer(&self, buffer: usize) -> &[T] { + let values = unsafe { self.buffers[buffer].data().align_to::() }; + if !values.0.is_empty() || !values.2.is_empty() { + // this is unreachable because + unreachable!("The buffer is not byte-aligned with its interpretation") + }; + &values.1 + } +} + +fn build_extend_nulls(array: &ArrayData) -> ExtendNullBits { + if let Some(bitmap) = array.null_bitmap() { + let bytes = bitmap.bits.data(); + Box::new(move |mutable, start, len| { + utils::reserve_for_bits(&mut mutable.null_buffer, mutable.len + len); + mutable.null_count += utils::set_bits( + mutable.null_buffer.data_mut(), + bytes, + mutable.len, + array.offset() + start, + len, + ); + }) + } else { + Box::new(|_, _, _| {}) + } +} + +/// Struct to efficiently and interactively create an [ArrayData] from an existing [ArrayData] by +/// copying chunks. 
+/// The main use case of this struct is to perform unary operations to arrays of arbitrary types, such as `filter` and `take`. +/// # Example: +/// +/// ``` +/// use std::sync::Arc; +/// use arrow::{array::{Int32Array, Array, MutableArrayData}}; +/// +/// let array = Int32Array::from(vec![1, 2, 3, 4, 5]).data(); +/// // Create a new `MutableArrayData` from an array and with a capacity. +/// // Capacity here is equivalent to `Vec::with_capacity` +/// let mut mutable = MutableArrayData::new(&array, 4); +/// mutable.extend(1, 3); // extend from the slice [1..3], [2,3] +/// mutable.extend(0, 3); // extend from the slice [0..3], [1,2,3] +/// // `.freeze()` to convert `MutableArrayData` into a `ArrayData`. +/// let new_array = Int32Array::from(Arc::new(mutable.freeze())); +/// assert_eq!(Int32Array::from(vec![2, 3, 1, 2, 3]), new_array); +/// ``` +pub struct MutableArrayData<'a> { + // The attributes in [_MutableArrayData] cannot be in [MutableArrayData] due to + // mutability invariants (interior mutability): + // [MutableArrayData] contains a function that can only mutate [_MutableArrayData], not + // [MutableArrayData] itself + data: _MutableArrayData<'a>, + + // the child data of the `Array` in Dictionary arrays. + // This is not stored in `_MutableArrayData` because these values are constant and only needed + // at the end, when freezing [_MutableArrayData]. + dictionary: Option, + + // the function used to extend values. This function's lifetime is bound to the array + // because it reads values from it. + extend_values: Extend<'a>, + // the function used to extend nulls. This function's lifetime is bound to the array + // because it reads nulls from it. + extend_nulls: ExtendNullBits<'a>, +} + +impl<'a> std::fmt::Debug for MutableArrayData<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + // ignores the closures.
+ f.debug_struct("MutableArrayData") + .field("data", &self.data) + .finish() + } +} + +impl<'a> MutableArrayData<'a> { + /// returns a new [MutableArrayData] with capacity to `capacity` slots and specialized to create an + /// [ArrayData] from `array` + pub fn new(array: &'a ArrayData, capacity: usize) -> Self { + let data_type = array.data_type(); + use crate::datatypes::*; + let extend_values = match &data_type { + DataType::Boolean => boolean::build_extend(array), + DataType::UInt8 => primitive::build_extend::(array), + DataType::UInt16 => primitive::build_extend::(array), + DataType::UInt32 => primitive::build_extend::(array), + DataType::UInt64 => primitive::build_extend::(array), + DataType::Int8 => primitive::build_extend::(array), + DataType::Int16 => primitive::build_extend::(array), + DataType::Int32 => primitive::build_extend::(array), + DataType::Int64 => primitive::build_extend::(array), + DataType::Float32 => primitive::build_extend::(array), + DataType::Float64 => primitive::build_extend::(array), + DataType::Date32(_) + | DataType::Time32(_) + | DataType::Interval(IntervalUnit::YearMonth) => { + primitive::build_extend::(array) + } + DataType::Date64(_) + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Duration(_) + | DataType::Interval(IntervalUnit::DayTime) => { + primitive::build_extend::(array) + } + DataType::Utf8 | DataType::Binary => { + variable_size::build_extend::(array) + } + DataType::LargeUtf8 | DataType::LargeBinary => { + variable_size::build_extend::(array) + } + DataType::List(_) => list::build_extend::(array), + DataType::LargeList(_) => list::build_extend::(array), + DataType::Dictionary(child_data_type, _) => match child_data_type.as_ref() { + DataType::UInt8 => primitive::build_extend::(array), + DataType::UInt16 => primitive::build_extend::(array), + DataType::UInt32 => primitive::build_extend::(array), + DataType::UInt64 => primitive::build_extend::(array), + DataType::Int8 => 
primitive::build_extend::(array), + DataType::Int16 => primitive::build_extend::(array), + DataType::Int32 => primitive::build_extend::(array), + DataType::Int64 => primitive::build_extend::(array), + _ => unreachable!(), + }, + DataType::Float16 => unreachable!(), + /* + DataType::Null => {} + DataType::FixedSizeBinary(_) => {} + DataType::FixedSizeList(_, _) => {} + DataType::Struct(_) => {} + DataType::Union(_) => {} + */ + _ => { + todo!("Take and filter operations still not supported for this datatype") + } + }; + + let buffers = match &data_type { + DataType::Boolean => { + let bytes = bit_util::ceil(capacity, 8); + let buffer = MutableBuffer::new(bytes).with_bitset(bytes, false); + vec![buffer] + } + DataType::UInt8 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::UInt16 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::UInt32 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::UInt64 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::Int8 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::Int16 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::Int32 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::Int64 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::Float32 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::Float64 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::Date32(_) | DataType::Time32(_) => { + vec![MutableBuffer::new(capacity * size_of::())] + } + DataType::Date64(_) + | DataType::Time64(_) + | DataType::Duration(_) + | DataType::Timestamp(_, _) => { + vec![MutableBuffer::new(capacity * size_of::())] + } + DataType::Interval(IntervalUnit::YearMonth) => { + vec![MutableBuffer::new(capacity * size_of::())] + } + DataType::Interval(IntervalUnit::DayTime) => { + vec![MutableBuffer::new(capacity * size_of::())] + } + DataType::Utf8 | DataType::Binary => { + let mut buffer = MutableBuffer::new((1 + 
capacity) * size_of::()); + buffer.extend_from_slice(&[0i32].to_byte_slice()); + vec![buffer, MutableBuffer::new(capacity * size_of::())] + } + DataType::LargeUtf8 | DataType::LargeBinary => { + let mut buffer = MutableBuffer::new((1 + capacity) * size_of::()); + buffer.extend_from_slice(&[0i64].to_byte_slice()); + vec![buffer, MutableBuffer::new(capacity * size_of::())] + } + DataType::List(_) => { + // offset buffer always starts with a zero + let mut buffer = MutableBuffer::new((1 + capacity) * size_of::()); + buffer.extend_from_slice(0i32.to_byte_slice()); + vec![buffer] + } + DataType::LargeList(_) => { + // offset buffer always starts with a zero + let mut buffer = MutableBuffer::new((1 + capacity) * size_of::()); + buffer.extend_from_slice(&[0i64].to_byte_slice()); + vec![buffer] + } + DataType::Dictionary(child_data_type, _) => match child_data_type.as_ref() { + DataType::UInt8 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::UInt16 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::UInt32 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::UInt64 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::Int8 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::Int16 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::Int32 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::Int64 => vec![MutableBuffer::new(capacity * size_of::())], + _ => unreachable!(), + }, + DataType::Float16 => unreachable!(), + _ => { + todo!("Take and filter operations still not supported for this datatype") + } + }; + + let child_data = match &data_type { + DataType::Null + | DataType::Boolean + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + | DataType::Date32(_) + | DataType::Date64(_) + | DataType::Time32(_) + | DataType::Time64(_) + | 
DataType::Duration(_) + | DataType::Timestamp(_, _) + | DataType::Utf8 + | DataType::Binary + | DataType::LargeUtf8 + | DataType::LargeBinary + | DataType::Interval(_) + | DataType::FixedSizeBinary(_) => vec![], + DataType::List(_) | DataType::LargeList(_) => { + vec![MutableArrayData::new(&array.child_data()[0], capacity)] + } + // the dictionary type just appends keys and clones the values. + DataType::Dictionary(_, _) => vec![], + DataType::Float16 => unreachable!(), + _ => { + todo!("Take and filter operations still not supported for this datatype") + } + }; + + let dictionary = match &data_type { + DataType::Dictionary(_, _) => Some(array.child_data()[0].clone()), + _ => None, + }; + + let extend_nulls = build_extend_nulls(array); + + let null_bytes = bit_util::ceil(capacity, 8); + let null_buffer = MutableBuffer::new(null_bytes).with_bitset(null_bytes, false); + + let data = _MutableArrayData { + data_type: data_type.clone(), + len: 0, + null_count: 0, + null_buffer, + buffers, + child_data, + }; + Self { + data, + dictionary, + extend_values: Box::new(extend_values), + extend_nulls, + } + } + + /// Extends this [MutableArrayData] with elements from the bounded [ArrayData] in the range + /// `[start, end)`, i.e. `end - start` elements beginning at `start`. + /// # Panic + /// This function panics if the range is out of bounds, i.e. if `end > array.len()`. + pub fn extend(&mut self, start: usize, end: usize) { + let len = end - start; + (self.extend_nulls)(&mut self.data, start, len); + (self.extend_values)(&mut self.data, start, len); + self.data.len += len; + } + + /// Creates a [ArrayData] from the pushed regions up to this point, consuming `self`.
+ pub fn freeze(self) -> ArrayData { + self.data.freeze(self.dictionary) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::array::{ + Array, ArrayDataRef, BooleanArray, DictionaryArray, Int16Array, Int16Type, + Int64Builder, ListBuilder, PrimitiveBuilder, StringArray, + StringDictionaryBuilder, UInt8Array, + }; + use crate::{array::ListArray, error::Result}; + + /// tests extending from a primitive array w/o offset nor nulls + #[test] + fn test_primitive() { + let b = UInt8Array::from(vec![Some(1), Some(2), Some(3)]).data(); + let mut a = MutableArrayData::new(&b, 3); + a.extend(0, 2); + let result = a.freeze(); + let array = UInt8Array::from(Arc::new(result)); + let expected = UInt8Array::from(vec![Some(1), Some(2)]); + assert_eq!(array, expected); + } + + /// tests extending from a primitive array with offset w/o nulls + #[test] + fn test_primitive_offset() { + let b = UInt8Array::from(vec![Some(1), Some(2), Some(3)]); + let b = b.slice(1, 2).data(); + let mut a = MutableArrayData::new(&b, 2); + a.extend(0, 2); + let result = a.freeze(); + let array = UInt8Array::from(Arc::new(result)); + let expected = UInt8Array::from(vec![Some(2), Some(3)]); + assert_eq!(array, expected); + } + + /// tests extending from a primitive array with offset and nulls + #[test] + fn test_primitive_null_offset() { + let b = UInt8Array::from(vec![Some(1), None, Some(3)]); + let b = b.slice(1, 2).data(); + let mut a = MutableArrayData::new(&b, 2); + a.extend(0, 2); + let result = a.freeze(); + let array = UInt8Array::from(Arc::new(result)); + let expected = UInt8Array::from(vec![None, Some(3)]); + assert_eq!(array, expected); + } + + #[test] + fn test_list_null_offset() -> Result<()> { + let int_builder = Int64Builder::new(24); + let mut builder = ListBuilder::::new(int_builder); + builder.values().append_slice(&[1, 2, 3])?; + builder.append(true)?; + builder.values().append_slice(&[4, 5])?; + builder.append(true)?; + builder.values().append_slice(&[6, 7, 8])?; +
builder.append(true)?; + let array = builder.finish().data(); + + let mut mutable = MutableArrayData::new(&array, 0); + mutable.extend(0, 1); + + let result = mutable.freeze(); + let array = ListArray::from(Arc::new(result)); + + let int_builder = Int64Builder::new(24); + let mut builder = ListBuilder::::new(int_builder); + builder.values().append_slice(&[1, 2, 3])?; + builder.append(true)?; + let expected = builder.finish(); + + assert_eq!(array, expected); + + Ok(()) + } + + /// tests extending from a variable-sized (strings and binary) array w/o offset with nulls + #[test] + fn test_variable_sized_nulls() { + let array = + StringArray::from(vec![Some("a"), Some("bc"), None, Some("defh")]).data(); + + let mut mutable = MutableArrayData::new(&array, 0); + + mutable.extend(1, 3); + + let result = mutable.freeze(); + let result = StringArray::from(Arc::new(result)); + + let expected = StringArray::from(vec![Some("bc"), None]); + assert_eq!(result, expected); + } + + /// tests extending from a variable-sized (strings and binary) array + /// with an offset and nulls + #[test] + fn test_variable_sized_offsets() { + let array = + StringArray::from(vec![Some("a"), Some("bc"), None, Some("defh")]).data(); + let array = array.slice(1, 3); + + let mut mutable = MutableArrayData::new(&array, 0); + + mutable.extend(0, 3); + + let result = mutable.freeze(); + let result = StringArray::from(Arc::new(result)); + + let expected = StringArray::from(vec![Some("bc"), None, Some("defh")]); + assert_eq!(result, expected); + } + + #[test] + fn test_bool() { + let array = + BooleanArray::from(vec![Some(false), Some(true), None, Some(false)]).data(); + + let mut mutable = MutableArrayData::new(&array, 0); + + mutable.extend(1, 3); + + let result = mutable.freeze(); + let result = BooleanArray::from(Arc::new(result)); + + let expected = BooleanArray::from(vec![Some(true), None]); + assert_eq!(result, expected); + } + + fn create_dictionary_array(values: &[&str], keys: &[Option<&str>]) ->
ArrayDataRef { + let values = StringArray::from(values.to_vec()); + let mut builder = StringDictionaryBuilder::new_with_dictionary( + PrimitiveBuilder::::new(3), + &values, + ) + .unwrap(); + for key in keys { + if let Some(v) = key { + builder.append(v).unwrap(); + } else { + builder.append_null().unwrap() + } + } + builder.finish().data() + } + + #[test] + fn test_dictionary() { + // (a, b, c), (0, 1, null, 2) => (a, b, null, c) + let array = create_dictionary_array( + &["a", "b", "c"], + &[Some("a"), Some("b"), None, Some("c")], + ); + + let mut mutable = MutableArrayData::new(&array, 0); + + mutable.extend(1, 3); + + let result = mutable.freeze(); + let result = DictionaryArray::from(Arc::new(result)); + + let expected = Int16Array::from(vec![Some(1), None]); + assert_eq!(result.keys(), &expected); + } +} diff --git a/rust/arrow/src/array/transform/primitive.rs b/rust/arrow/src/array/transform/primitive.rs new file mode 100644 index 0000000000000..d2b44f28d4276 --- /dev/null +++ b/rust/arrow/src/array/transform/primitive.rs @@ -0,0 +1,35 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
+ +use std::mem::size_of; + +use crate::{array::ArrayData, datatypes::ArrowNativeType}; + +use super::{Extend, _MutableArrayData}; + +pub(super) fn build_extend(array: &ArrayData) -> Extend { + let values = &array.buffers()[0].data()[array.offset() * size_of::()..]; + Box::new( + move |mutable: &mut _MutableArrayData, start: usize, len: usize| { + let start = start * size_of::(); + let len = len * size_of::(); + let bytes = &values[start..start + len]; + let buffer = &mut mutable.buffers[0]; + buffer.extend_from_slice(bytes); + }, + ) +} diff --git a/rust/arrow/src/array/transform/utils.rs b/rust/arrow/src/array/transform/utils.rs new file mode 100644 index 0000000000000..df9ce2453be14 --- /dev/null +++ b/rust/arrow/src/array/transform/utils.rs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{ + array::OffsetSizeTrait, buffer::MutableBuffer, datatypes::ToByteSlice, util::bit_util, +}; + +/// extends the `buffer` to be able to hold `len` bits, setting all bits of the new size to zero. 
+///
+/// `len` is counted in bits and rounded up to whole bytes; the buffer is left untouched when
+/// it already holds at least that many bytes.
+#[inline]
+pub(super) fn reserve_for_bits(buffer: &mut MutableBuffer, len: usize) {
+    let needed_bytes = bit_util::ceil(len, 8);
+    if buffer.len() < needed_bytes {
+        buffer.extend(needed_bytes - buffer.len());
+    }
+}
+
+/// sets all bits on `write_data` on the range `[offset_write..offset_write+len]` to be equal to the
+/// bits on `data` on the range `[offset_read..offset_read+len]`
+///
+/// All offsets and `len` are bit positions. Bits that are unset in `data` are explicitly
+/// cleared in `write_data`, so the result does not depend on what the destination held before.
+///
+/// Returns the number of unset bits found in the copied range (the null count of the range
+/// when the bits are validity bits).
+pub(super) fn set_bits(
+    write_data: &mut [u8],
+    data: &[u8],
+    offset_write: usize,
+    offset_read: usize,
+    len: usize,
+) -> usize {
+    let mut unset_count = 0;
+    for i in 0..len {
+        if bit_util::get_bit(data, offset_read + i) {
+            bit_util::set_bit(write_data, offset_write + i);
+        } else {
+            // clear rather than skip: the destination may hold stale set bits
+            bit_util::unset_bit(write_data, offset_write + i);
+            unset_count += 1;
+        }
+    }
+    unset_count
+}
+
+/// Appends one new offset to `buffer` for every consecutive pair in `offsets`.
+///
+/// Each appended value is `last_offset` plus the lengths (`offsets[i + 1] - offsets[i]`)
+/// accumulated so far, so the new offsets continue from `last_offset` instead of repeating the
+/// source positions. Nothing is appended when `offsets` has fewer than two entries.
+pub(super) fn extend_offsets<T: OffsetSizeTrait>(
+    buffer: &mut MutableBuffer,
+    mut last_offset: T,
+    offsets: &[T],
+) {
+    buffer.reserve(buffer.len() + offsets.len() * std::mem::size_of::<T>());
+    offsets.windows(2).for_each(|pair| {
+        // compute the new offset
+        let length = pair[1] - pair[0];
+        last_offset = last_offset + length;
+        buffer.extend_from_slice(last_offset.to_byte_slice());
+    });
+}
diff --git a/rust/arrow/src/array/transform/variable_size.rs b/rust/arrow/src/array/transform/variable_size.rs
new file mode 100644
index 0000000000000..6e7c80a97e1a9
--- /dev/null
+++ b/rust/arrow/src/array/transform/variable_size.rs
@@ -0,0 +1,93 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.
You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::{
+    array::{ArrayData, OffsetSizeTrait},
+    buffer::MutableBuffer,
+    datatypes::ToByteSlice,
+};
+
+use super::{Extend, _MutableArrayData, utils::extend_offsets};
+
+/// Appends to `buffer` the bytes of `values` that belong to the `len` entries starting at
+/// entry `start`, i.e. `values[offsets[start]..offsets[start + len]]`.
+fn extend_offset_values<T: OffsetSizeTrait>(
+    buffer: &mut MutableBuffer,
+    offsets: &[T],
+    values: &[u8],
+    start: usize,
+    len: usize,
+) {
+    let start_values = offsets[start].to_usize().unwrap();
+    let end_values = offsets[start + len].to_usize().unwrap();
+    let new_values = &values[start_values..end_values];
+    buffer.extend_from_slice(new_values);
+}
+
+/// Builds the [`Extend`] closure for an array with an offsets buffer (buffer 0) and a values
+/// buffer (buffer 1).
+///
+/// The returned closure receives `start` and `len` counted in entries. It appends `len` new
+/// offsets to `mutable.buffers[0]`, continuing from the last offset already stored there, and
+/// the matching value bytes to `mutable.buffers[1]`. Null entries contribute no bytes and
+/// repeat the previous offset.
+pub(super) fn build_extend<T: OffsetSizeTrait>(array: &ArrayData) -> Extend {
+    let offsets = array.buffer::<T>(0);
+    // Offsets are absolute positions into the values buffer, so the values buffer must not be
+    // shifted by the array's element offset (which is counted in entries, not bytes).
+    let values = array.buffers()[1].data();
+    if array.null_count() == 0 {
+        // fast case where we can copy regions without null issues
+        Box::new(
+            move |mutable: &mut _MutableArrayData, start: usize, len: usize| {
+                let mutable_offsets = mutable.buffer::<T>(0);
+                let last_offset = mutable_offsets[mutable_offsets.len() - 1];
+                // offsets
+                let buffer = &mut mutable.buffers[0];
+                extend_offsets::<T>(
+                    buffer,
+                    last_offset,
+                    &offsets[start..start + len + 1],
+                );
+                // values
+                let buffer = &mut mutable.buffers[1];
+                extend_offset_values::<T>(buffer, offsets, values, start, len);
+            },
+        )
+    } else {
+        Box::new(
+            move |mutable: &mut _MutableArrayData, start: usize, len: usize| {
+                let mutable_offsets = mutable.buffer::<T>(0);
+                let mut last_offset = mutable_offsets[mutable_offsets.len() - 1];
+
+                // nulls present: append item by item, ignoring null entries
+                let (offset_buffer, values_buffer) = mutable.buffers.split_at_mut(1);
+                let offset_buffer = &mut offset_buffer[0];
+                let values_buffer = &mut values_buffer[0];
+                // exactly `len` offsets are appended below
+                offset_buffer
+                    .reserve(offset_buffer.len() + len * std::mem::size_of::<T>());
+
+                (start..start + len).for_each(|i| {
+                    if array.is_valid(i) {
+                        // compute the new offset
+                        let length = offsets[i + 1] - offsets[i];
+                        last_offset = last_offset + length;
+
+                        // append value, addressed the same way as in the fast case
+                        let value_start = offsets[i].to_usize().unwrap();
+                        let value_end = offsets[i + 1].to_usize().unwrap();
+                        values_buffer.extend_from_slice(&values[value_start..value_end]);
+                    }
+                    // offsets are always present
+                    offset_buffer.extend_from_slice(last_offset.to_byte_slice());
+                })
+            },
+        )
+    }
+}
diff --git a/rust/arrow/src/buffer.rs b/rust/arrow/src/buffer.rs
index 9975ccc7f201f..9217f0d4530cb 100644
--- a/rust/arrow/src/buffer.rs
+++ b/rust/arrow/src/buffer.rs
@@ -888,6 +888,19 @@ impl MutableBuffer {
         }
         self.len += bytes.len();
     }
+
+    /// Extends the buffer's length by `len` bytes, growing its capacity if needed.
+    ///
+    /// NOTE(review): the new bytes are meant to read as `0u8`, but nothing here writes them;
+    /// that only holds if the memory handed out by allocation/`reserve` is already zeroed
+    /// and has not been written past `self.len` before. Confirm, or zero the region explicitly.
+    pub fn extend(&mut self, len: usize) {
+        let remaining_capacity = self.capacity - self.len;
+        if len > remaining_capacity {
+            self.reserve(self.len + len);
+        }
+        self.len += len;
+    }
 }
 
 impl Drop for MutableBuffer {
diff --git a/rust/arrow/src/compute/kernels/filter.rs b/rust/arrow/src/compute/kernels/filter.rs
index eb8d3397cfcd1..ccdd37c32e892 100644
--- a/rust/arrow/src/compute/kernels/filter.rs
+++ b/rust/arrow/src/compute/kernels/filter.rs
@@ -18,798 +18,122 @@
 //! Defines miscellaneous array kernels.
use crate::array::*; -use crate::datatypes::*; -use crate::error::{ArrowError, Result}; +use crate::error::Result; use crate::record_batch::RecordBatch; -use crate::{ - bitmap::Bitmap, - buffer::{Buffer, MutableBuffer}, - util::bit_util, -}; -use std::{mem, sync::Arc}; - -/// trait for copying filtered null bitmap bits -trait CopyNullBit { - fn copy_null_bit(&mut self, source_index: usize); - fn copy_null_bits(&mut self, source_index: usize, count: usize); - fn null_count(&self) -> usize; - fn null_buffer(&mut self) -> Buffer; -} - -/// no-op null bitmap copy implementation, -/// used when the filtered data array doesn't have a null bitmap -struct NullBitNoop {} - -impl NullBitNoop { - fn new() -> Self { - NullBitNoop {} - } -} - -impl CopyNullBit for NullBitNoop { - #[inline] - fn copy_null_bit(&mut self, _source_index: usize) { - // do nothing - } - - #[inline] - fn copy_null_bits(&mut self, _source_index: usize, _count: usize) { - // do nothing - } - - fn null_count(&self) -> usize { - 0 - } - - fn null_buffer(&mut self) -> Buffer { - Buffer::from([0u8; 0]) - } -} - -/// null bitmap copy implementation, -/// used when the filtered data array has a null bitmap -struct NullBitSetter<'a> { - target_buffer: MutableBuffer, - source_bytes: &'a [u8], - target_index: usize, - null_count: usize, -} - -impl<'a> NullBitSetter<'a> { - fn new(null_bitmap: &'a Bitmap) -> Self { - let null_bytes = null_bitmap.buffer_ref().data(); - // create null bitmap buffer with same length and initialize null bitmap buffer to 1s - let null_buffer = - MutableBuffer::new(null_bytes.len()).with_bitset(null_bytes.len(), true); - NullBitSetter { - source_bytes: null_bytes, - target_buffer: null_buffer, - target_index: 0, - null_count: 0, - } - } -} - -impl<'a> CopyNullBit for NullBitSetter<'a> { - #[inline] - fn copy_null_bit(&mut self, source_index: usize) { - if !bit_util::get_bit(self.source_bytes, source_index) { - bit_util::unset_bit(self.target_buffer.data_mut(), self.target_index); - 
self.null_count += 1; - } - self.target_index += 1; - } - - #[inline] - fn copy_null_bits(&mut self, source_index: usize, count: usize) { - for i in 0..count { - self.copy_null_bit(source_index + i); - } - } - - fn null_count(&self) -> usize { - self.null_count - } - - fn null_buffer(&mut self) -> Buffer { - self.target_buffer.resize(self.target_index); - // use mem::replace to detach self.target_buffer from self so that it can be returned - let target_buffer = mem::replace(&mut self.target_buffer, MutableBuffer::new(0)); - target_buffer.freeze() - } -} - -fn get_null_bit_setter<'a>(data_array: &'a impl Array) -> Box { - if let Some(null_bitmap) = data_array.data_ref().null_bitmap() { - // only return an actual null bit copy implementation if null_bitmap is set - Box::new(NullBitSetter::new(null_bitmap)) - } else { - // otherwise return a no-op copy null bit implementation - // for improved performance when the filtered array doesn't contain NULLs - Box::new(NullBitNoop::new()) - } -} - -// transmute filter array to u64 -// - optimize filtering with highly selective filters by skipping entire batches of 64 filter bits -// - if the data array being filtered doesn't have a null bitmap, no time is wasted to copy a null bitmap -fn filter_array_impl( - filter_context: &FilterContext, - data_array: &impl Array, - array_type: DataType, - value_size: usize, -) -> Result { - if filter_context.filter_len > data_array.len() { - return Err(ArrowError::ComputeError( - "Filter array cannot be larger than data array".to_string(), - )); - } - let filtered_count = filter_context.filtered_count; - let filter_mask = &filter_context.filter_mask; - let filter_u64 = &filter_context.filter_u64; - let data_bytes = data_array.data_ref().buffers()[0].data(); - let mut target_buffer = MutableBuffer::new(filtered_count * value_size); - target_buffer.resize(filtered_count * value_size); - let target_bytes = target_buffer.data_mut(); - let mut target_byte_index: usize = 0; - let mut 
null_bit_setter = get_null_bit_setter(data_array); - let null_bit_setter = null_bit_setter.as_mut(); - let all_ones_batch = !0u64; - let data_array_offset = data_array.offset(); - - for (i, filter_batch) in filter_u64.iter().enumerate() { - // foreach u64 batch - let filter_batch = *filter_batch; - if filter_batch == 0 { - // if batch == 0, all items are filtered out, so skip entire batch - continue; - } else if filter_batch == all_ones_batch { - // if batch == all 1s: copy all 64 values in one go - let data_index = (i * 64) + data_array_offset; - null_bit_setter.copy_null_bits(data_index, 64); - let data_byte_index = data_index * value_size; - let data_len = value_size * 64; - target_bytes[target_byte_index..(target_byte_index + data_len)] - .copy_from_slice( - &data_bytes[data_byte_index..(data_byte_index + data_len)], - ); - target_byte_index += data_len; - continue; - } - for (j, filter_mask) in filter_mask.iter().enumerate() { - // foreach bit in batch: - if (filter_batch & *filter_mask) != 0 { - let data_index = (i * 64) + j + data_array_offset; - null_bit_setter.copy_null_bit(data_index); - // if filter bit == 1: copy data value bytes - let data_byte_index = data_index * value_size; - target_bytes[target_byte_index..(target_byte_index + value_size)] - .copy_from_slice( - &data_bytes[data_byte_index..(data_byte_index + value_size)], - ); - target_byte_index += value_size; +use std::sync::Arc; + +/// Function that can filter arbitrary arrays +pub type Filter<'a> = Box ArrayData + 'a>; + +/// Returns a vector of slices (start, end) denoting the sizes that are +/// required to be copied from any array when `filter` is applied +fn compute_slices(filter: &BooleanArray) -> (Vec<(usize, usize)>, usize) { + let mut slices = Vec::with_capacity(filter.len()); + let mut filter_count = 0; // the number of selected items. 
+ let mut len = 0; // the len of the region of selection + let mut start = 0; // the start of the region of selection + let mut on_region = false; // whether it is in a region of selection + let all_ones = !0u64; + + let buffer = filter.values(); + let bit_chunks = buffer.bit_chunks(filter.offset(), filter.len()); + let iter = bit_chunks.iter(); + let chunks = iter.len(); + bit_chunks.iter().enumerate().for_each(|(i, mask)| { + if mask == 0 { + if on_region { + slices.push((start, start + len)); + filter_count += len; + len = 0; + on_region = false; } - } - } - - let mut array_data_builder = ArrayDataBuilder::new(array_type) - .len(filtered_count) - .add_buffer(target_buffer.freeze()); - if null_bit_setter.null_count() > 0 { - array_data_builder = array_data_builder - .null_count(null_bit_setter.null_count()) - .null_bit_buffer(null_bit_setter.null_buffer()); - } - - Ok(array_data_builder) -} - -/// FilterContext can be used to improve performance when -/// filtering multiple data arrays with the same filter array. -#[derive(Debug)] -pub struct FilterContext { - filter_u64: Vec, - filter_len: usize, - filtered_count: usize, - filter_mask: Vec, -} - -macro_rules! filter_primitive_array { - ($context:expr, $array:expr, $array_type:ident) => {{ - let input_array = $array.as_any().downcast_ref::<$array_type>().unwrap(); - let output_array = $context.filter_primitive_array(input_array)?; - Ok(Arc::new(output_array)) - }}; -} - -macro_rules! filter_dictionary_array { - ($context:expr, $array:expr, $array_type:ident) => {{ - let input_array = $array.as_any().downcast_ref::<$array_type>().unwrap(); - let output_array = $context.filter_dictionary_array(input_array)?; - Ok(Arc::new(output_array)) - }}; -} - -macro_rules! 
filter_primitive_item_list_array { - ($context:expr, $array:expr, $item_type:ident, $list_type:ident, $list_builder_type:ident) => {{ - let input_array = $array.as_any().downcast_ref::<$list_type>().unwrap(); - let values_builder = PrimitiveBuilder::<$item_type>::new($context.filtered_count); - let mut builder = $list_builder_type::new(values_builder); - for i in 0..$context.filter_u64.len() { - // foreach u64 batch - let filter_batch = $context.filter_u64[i]; - if filter_batch == 0 { - // if batch == 0, all items are filtered out, so skip entire batch - continue; + } else if mask == all_ones { + if !on_region { + start = i * 64; + on_region = true; } - for j in 0..64 { - // foreach bit in batch: - if (filter_batch & $context.filter_mask[j]) != 0 { - let data_index = (i * 64) + j; - if input_array.is_null(data_index) { - builder.append(false)?; - } else { - let this_inner_list = input_array.value(data_index); - let inner_list = this_inner_list - .as_any() - .downcast_ref::>() - .unwrap(); - for k in 0..inner_list.len() { - if inner_list.is_null(k) { - builder.values().append_null()?; - } else { - builder.values().append_value(inner_list.value(k))?; - } - } - builder.append(true)?; + len += 64; + } else { + (0..64).for_each(|j| { + if (mask & (1 << j)) != 0 { + if !on_region { + start = i * 64 + j; + on_region = true; } + len += 1; + } else if on_region { + slices.push((start, start + len)); + filter_count += len; + len = 0; + on_region = false; } - } + }) } - Ok(Arc::new(builder.finish())) - }}; -} - -macro_rules! 
filter_non_primitive_item_list_array { - ($context:expr, $array:expr, $item_array_type:ident, $item_builder:ident, $list_type:ident, $list_builder_type:ident) => {{ - let input_array = $array.as_any().downcast_ref::<$list_type>().unwrap(); - let values_builder = $item_builder::new($context.filtered_count); - let mut builder = $list_builder_type::new(values_builder); - for i in 0..$context.filter_u64.len() { - // foreach u64 batch - let filter_batch = $context.filter_u64[i]; - if filter_batch == 0 { - // if batch == 0, all items are filtered out, so skip entire batch - continue; - } - for j in 0..64 { - // foreach bit in batch: - if (filter_batch & $context.filter_mask[j]) != 0 { - let data_index = (i * 64) + j; - if input_array.is_null(data_index) { - builder.append(false)?; - } else { - let this_inner_list = input_array.value(data_index); - let inner_list = this_inner_list - .as_any() - .downcast_ref::<$item_array_type>() - .unwrap(); - for k in 0..inner_list.len() { - if inner_list.is_null(k) { - builder.values().append_null()?; - } else { - builder.values().append_value(inner_list.value(k))?; - } - } - builder.append(true)?; - } - } + }); + let mask = bit_chunks.remainder_bits(); + + (0..bit_chunks.remainder_len()).for_each(|j| { + if (mask & (1 << j)) != 0 { + if !on_region { + start = chunks * 64 + j; + on_region = true; } + len += 1; + } else if on_region { + slices.push((start, start + len)); + filter_count += len; + len = 0; + on_region = false; } - Ok(Arc::new(builder.finish())) - }}; + }); + if on_region { + slices.push((start, start + len)); + filter_count += len; + }; + // invariant: filter_count is the sum of the lens of all slices + (slices, filter_count) } -impl FilterContext { - /// Returns a new instance of FilterContext - pub fn new(filter_array: &BooleanArray) -> Result { - if filter_array.offset() > 0 { - return Err(ArrowError::ComputeError( - "Filter array cannot have offset > 0".to_string(), - )); - } - let filter_mask: Vec = (0..64).map(|x| 
1u64 << x).collect(); - let filter_buffer = &filter_array.data_ref().buffers()[0]; - let filtered_count = filter_buffer.count_set_bits_offset(0, filter_array.len()); - - let filter_bytes = filter_buffer.data(); - - // transmute filter_bytes to &[u64] - let mut u64_buffer = MutableBuffer::new(filter_bytes.len()); - // add to the resulting len so is is a multiple of the size of u64 - let pad_addional_len = (8 - filter_bytes.len() % 8) % 8; - u64_buffer.extend_from_slice(filter_bytes); - u64_buffer.extend_from_slice(&vec![0; pad_addional_len]); - let mut filter_u64 = u64_buffer.typed_data_mut::().to_owned(); - - // mask of any bits outside of the given len - if filter_array.len() % 64 != 0 { - let last_idx = filter_u64.len() - 1; - let mask = u64::MAX >> (64 - filter_array.len() % 64); - filter_u64[last_idx] &= mask; - } - - Ok(FilterContext { - filter_u64, - filter_len: filter_array.len(), - filtered_count, - filter_mask, - }) - } - - /// Returns a new array, containing only the elements matching the filter - pub fn filter(&self, array: &Array) -> Result { - match array.data_type() { - DataType::UInt8 => filter_primitive_array!(self, array, UInt8Array), - DataType::UInt16 => filter_primitive_array!(self, array, UInt16Array), - DataType::UInt32 => filter_primitive_array!(self, array, UInt32Array), - DataType::UInt64 => filter_primitive_array!(self, array, UInt64Array), - DataType::Int8 => filter_primitive_array!(self, array, Int8Array), - DataType::Int16 => filter_primitive_array!(self, array, Int16Array), - DataType::Int32 => filter_primitive_array!(self, array, Int32Array), - DataType::Int64 => filter_primitive_array!(self, array, Int64Array), - DataType::Float32 => filter_primitive_array!(self, array, Float32Array), - DataType::Float64 => filter_primitive_array!(self, array, Float64Array), - DataType::Boolean => { - let input_array = array.as_any().downcast_ref::().unwrap(); - let mut builder = BooleanArray::builder(self.filtered_count); - for i in 
0..self.filter_u64.len() { - // foreach u64 batch - let filter_batch = self.filter_u64[i]; - if filter_batch == 0 { - // if batch == 0, all items are filtered out, so skip entire batch - continue; - } - for j in 0..64 { - // foreach bit in batch: - if (filter_batch & self.filter_mask[j]) != 0 { - let data_index = (i * 64) + j; - if input_array.is_null(data_index) { - builder.append_null()?; - } else { - builder.append_value(input_array.value(data_index))?; - } - } - } - } - Ok(Arc::new(builder.finish())) - }, - DataType::Date32(_) => filter_primitive_array!(self, array, Date32Array), - DataType::Date64(_) => filter_primitive_array!(self, array, Date64Array), - DataType::Time32(TimeUnit::Second) => { - filter_primitive_array!(self, array, Time32SecondArray) - } - DataType::Time32(TimeUnit::Millisecond) => { - filter_primitive_array!(self, array, Time32MillisecondArray) - } - DataType::Time64(TimeUnit::Microsecond) => { - filter_primitive_array!(self, array, Time64MicrosecondArray) - } - DataType::Time64(TimeUnit::Nanosecond) => { - filter_primitive_array!(self, array, Time64NanosecondArray) - } - DataType::Duration(TimeUnit::Second) => { - filter_primitive_array!(self, array, DurationSecondArray) - } - DataType::Duration(TimeUnit::Millisecond) => { - filter_primitive_array!(self, array, DurationMillisecondArray) - } - DataType::Duration(TimeUnit::Microsecond) => { - filter_primitive_array!(self, array, DurationMicrosecondArray) - } - DataType::Duration(TimeUnit::Nanosecond) => { - filter_primitive_array!(self, array, DurationNanosecondArray) - } - DataType::Timestamp(TimeUnit::Second, _) => { - filter_primitive_array!(self, array, TimestampSecondArray) - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - filter_primitive_array!(self, array, TimestampMillisecondArray) - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - filter_primitive_array!(self, array, TimestampMicrosecondArray) - } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - 
filter_primitive_array!(self, array, TimestampNanosecondArray) - } - DataType::Binary => { - let input_array = array.as_any().downcast_ref::().unwrap(); - let mut values: Vec> = Vec::with_capacity(self.filtered_count); - for i in 0..self.filter_u64.len() { - // foreach u64 batch - let filter_batch = self.filter_u64[i]; - if filter_batch == 0 { - // if batch == 0, all items are filtered out, so skip entire batch - continue; - } - for j in 0..64 { - // foreach bit in batch: - if (filter_batch & self.filter_mask[j]) != 0 { - let data_index = (i * 64) + j; - if input_array.is_null(data_index) { - values.push(None) - } else { - values.push(Some(input_array.value(data_index))) - } - } - } - } - Ok(Arc::new(BinaryArray::from(values))) - } - DataType::Utf8 => { - let input_array = array.as_any().downcast_ref::().unwrap(); - let mut values: Vec> = Vec::with_capacity(self.filtered_count); - for i in 0..self.filter_u64.len() { - // foreach u64 batch - let filter_batch = self.filter_u64[i]; - if filter_batch == 0 { - // if batch == 0, all items are filtered out, so skip entire batch - continue; - } - for j in 0..64 { - // foreach bit in batch: - if (filter_batch & self.filter_mask[j]) != 0 { - let data_index = (i * 64) + j; - if input_array.is_null(data_index) { - values.push(None) - } else { - values.push(Some(input_array.value(data_index))) - } - } - } - } - Ok(Arc::new(StringArray::from(values))) - } - DataType::Dictionary(ref key_type, ref value_type) => match (key_type.as_ref(), value_type.as_ref()) { - (key_type, DataType::Utf8) => match key_type { - DataType::UInt8 => filter_dictionary_array!(self, array, UInt8DictionaryArray), - DataType::UInt16 => filter_dictionary_array!(self, array, UInt16DictionaryArray), - DataType::UInt32 => filter_dictionary_array!(self, array, UInt32DictionaryArray), - DataType::UInt64 => filter_dictionary_array!(self, array, UInt64DictionaryArray), - DataType::Int8 => filter_dictionary_array!(self, array, Int8DictionaryArray), - 
DataType::Int16 => filter_dictionary_array!(self, array, Int16DictionaryArray), - DataType::Int32 => filter_dictionary_array!(self, array, Int32DictionaryArray), - DataType::Int64 => filter_dictionary_array!(self, array, Int64DictionaryArray), - other => Err(ArrowError::ComputeError(format!( - "filter not supported for string dictionary with key of type {:?}", - other - ))) - } - (key_type, value_type) => Err(ArrowError::ComputeError(format!( - "filter not supported for Dictionary({:?}, {:?})", - key_type, value_type - ))) - } - DataType::List(dt) => match dt.data_type() { - DataType::UInt8 => { - filter_primitive_item_list_array!(self, array, UInt8Type, ListArray, ListBuilder) - } - DataType::UInt16 => { - filter_primitive_item_list_array!(self, array, UInt16Type, ListArray, ListBuilder) - } - DataType::UInt32 => { - filter_primitive_item_list_array!(self, array, UInt32Type, ListArray, ListBuilder) - } - DataType::UInt64 => { - filter_primitive_item_list_array!(self, array, UInt64Type, ListArray, ListBuilder) - } - DataType::Int8 => filter_primitive_item_list_array!(self, array, Int8Type, ListArray, ListBuilder), - DataType::Int16 => { - filter_primitive_item_list_array!(self, array, Int16Type, ListArray, ListBuilder) - } - DataType::Int32 => { - filter_primitive_item_list_array!(self, array, Int32Type, ListArray, ListBuilder) - } - DataType::Int64 => { - filter_primitive_item_list_array!(self, array, Int64Type, ListArray, ListBuilder) - } - DataType::Float32 => { - filter_primitive_item_list_array!(self, array, Float32Type, ListArray, ListBuilder) - } - DataType::Float64 => { - filter_primitive_item_list_array!(self, array, Float64Type, ListArray, ListBuilder) - } - DataType::Boolean => { - filter_primitive_item_list_array!(self, array, BooleanType, ListArray, ListBuilder) - } - DataType::Date32(_) => { - filter_primitive_item_list_array!(self, array, Date32Type, ListArray, ListBuilder) - } - DataType::Date64(_) => { - filter_primitive_item_list_array!(self, 
array, Date64Type, ListArray, ListBuilder) - } - DataType::Time32(TimeUnit::Second) => { - filter_primitive_item_list_array!(self, array, Time32SecondType, ListArray, ListBuilder) - } - DataType::Time32(TimeUnit::Millisecond) => { - filter_primitive_item_list_array!(self, array, Time32MillisecondType, ListArray, ListBuilder) - } - DataType::Time64(TimeUnit::Microsecond) => { - filter_primitive_item_list_array!(self, array, Time64MicrosecondType, ListArray, ListBuilder) - } - DataType::Time64(TimeUnit::Nanosecond) => { - filter_primitive_item_list_array!(self, array, Time64NanosecondType, ListArray, ListBuilder) - } - DataType::Duration(TimeUnit::Second) => { - filter_primitive_item_list_array!(self, array, DurationSecondType, ListArray, ListBuilder) - } - DataType::Duration(TimeUnit::Millisecond) => { - filter_primitive_item_list_array!(self, array, DurationMillisecondType, ListArray, ListBuilder) - } - DataType::Duration(TimeUnit::Microsecond) => { - filter_primitive_item_list_array!(self, array, DurationMicrosecondType, ListArray, ListBuilder) - } - DataType::Duration(TimeUnit::Nanosecond) => { - filter_primitive_item_list_array!(self, array, DurationNanosecondType, ListArray, ListBuilder) - } - DataType::Timestamp(TimeUnit::Second, _) => { - filter_primitive_item_list_array!(self, array, TimestampSecondType, ListArray, ListBuilder) - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - filter_primitive_item_list_array!(self, array, TimestampMillisecondType, ListArray, ListBuilder) - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - filter_primitive_item_list_array!(self, array, TimestampMicrosecondType, ListArray, ListBuilder) - } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - filter_primitive_item_list_array!(self, array, TimestampNanosecondType, ListArray, ListBuilder) - } - DataType::Binary => filter_non_primitive_item_list_array!( - self, - array, - BinaryArray, - BinaryBuilder, - ListArray, - ListBuilder - ), - DataType::LargeBinary => 
filter_non_primitive_item_list_array!( - self, - array, - LargeBinaryArray, - LargeBinaryBuilder, - ListArray, - ListBuilder - ), - DataType::Utf8 => filter_non_primitive_item_list_array!( - self, - array, - StringArray, - StringBuilder, - ListArray - ,ListBuilder - ), - DataType::LargeUtf8 => filter_non_primitive_item_list_array!( - self, - array, - LargeStringArray, - LargeStringBuilder, - ListArray, - ListBuilder - ), - other => { - Err(ArrowError::ComputeError(format!( - "filter not supported for List({:?})", - other - ))) - } - } - DataType::LargeList(dt) => match dt.data_type() { - DataType::UInt8 => { - filter_primitive_item_list_array!(self, array, UInt8Type, LargeListArray, LargeListBuilder) - } - DataType::UInt16 => { - filter_primitive_item_list_array!(self, array, UInt16Type, LargeListArray, LargeListBuilder) - } - DataType::UInt32 => { - filter_primitive_item_list_array!(self, array, UInt32Type, LargeListArray, LargeListBuilder) - } - DataType::UInt64 => { - filter_primitive_item_list_array!(self, array, UInt64Type, LargeListArray, LargeListBuilder) - } - DataType::Int8 => filter_primitive_item_list_array!(self, array, Int8Type, LargeListArray, LargeListBuilder), - DataType::Int16 => { - filter_primitive_item_list_array!(self, array, Int16Type, LargeListArray, LargeListBuilder) - } - DataType::Int32 => { - filter_primitive_item_list_array!(self, array, Int32Type, LargeListArray, LargeListBuilder) - } - DataType::Int64 => { - filter_primitive_item_list_array!(self, array, Int64Type, LargeListArray, LargeListBuilder) - } - DataType::Float32 => { - filter_primitive_item_list_array!(self, array, Float32Type, LargeListArray, LargeListBuilder) - } - DataType::Float64 => { - filter_primitive_item_list_array!(self, array, Float64Type, LargeListArray, LargeListBuilder) - } - DataType::Boolean => { - filter_primitive_item_list_array!(self, array, BooleanType, LargeListArray, LargeListBuilder) - } - DataType::Date32(_) => { - 
filter_primitive_item_list_array!(self, array, Date32Type, LargeListArray, LargeListBuilder) - } - DataType::Date64(_) => { - filter_primitive_item_list_array!(self, array, Date64Type, LargeListArray, LargeListBuilder) - } - DataType::Time32(TimeUnit::Second) => { - filter_primitive_item_list_array!(self, array, Time32SecondType, LargeListArray, LargeListBuilder) - } - DataType::Time32(TimeUnit::Millisecond) => { - filter_primitive_item_list_array!(self, array, Time32MillisecondType, LargeListArray, LargeListBuilder) - } - DataType::Time64(TimeUnit::Microsecond) => { - filter_primitive_item_list_array!(self, array, Time64MicrosecondType, LargeListArray, LargeListBuilder) - } - DataType::Time64(TimeUnit::Nanosecond) => { - filter_primitive_item_list_array!(self, array, Time64NanosecondType, LargeListArray, LargeListBuilder) - } - DataType::Duration(TimeUnit::Second) => { - filter_primitive_item_list_array!(self, array, DurationSecondType, LargeListArray, LargeListBuilder) - } - DataType::Duration(TimeUnit::Millisecond) => { - filter_primitive_item_list_array!(self, array, DurationMillisecondType, LargeListArray, LargeListBuilder) - } - DataType::Duration(TimeUnit::Microsecond) => { - filter_primitive_item_list_array!(self, array, DurationMicrosecondType, LargeListArray, LargeListBuilder) - } - DataType::Duration(TimeUnit::Nanosecond) => { - filter_primitive_item_list_array!(self, array, DurationNanosecondType, LargeListArray, LargeListBuilder) - } - DataType::Timestamp(TimeUnit::Second, _) => { - filter_primitive_item_list_array!(self, array, TimestampSecondType, LargeListArray, LargeListBuilder) - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - filter_primitive_item_list_array!(self, array, TimestampMillisecondType, LargeListArray, LargeListBuilder) - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - filter_primitive_item_list_array!(self, array, TimestampMicrosecondType, LargeListArray, LargeListBuilder) - } - 
DataType::Timestamp(TimeUnit::Nanosecond, _) => { - filter_primitive_item_list_array!(self, array, TimestampNanosecondType, LargeListArray, LargeListBuilder) - } - DataType::Binary => filter_non_primitive_item_list_array!( - self, - array, - BinaryArray, - BinaryBuilder, - LargeListArray, - LargeListBuilder - ), - DataType::LargeBinary => filter_non_primitive_item_list_array!( - self, - array, - LargeBinaryArray, - LargeBinaryBuilder, - LargeListArray, - LargeListBuilder - ), - DataType::Utf8 => filter_non_primitive_item_list_array!( - self, - array, - StringArray, - StringBuilder, - LargeListArray, - LargeListBuilder - ), - DataType::LargeUtf8 => filter_non_primitive_item_list_array!( - self, - array, - LargeStringArray, - LargeStringBuilder, - LargeListArray, - LargeListBuilder - ), - other => { - Err(ArrowError::ComputeError(format!( - "filter not supported for LargeList({:?})", - other - ))) - } - } - other => Err(ArrowError::ComputeError(format!( - "filter not supported for {:?}", - other - ))), - } - } - - /// Returns a new PrimitiveArray containing only those values from the array passed as the data_array parameter, - /// selected by the BooleanArray passed as the filter_array parameter - pub fn filter_primitive_array( - &self, - data_array: &PrimitiveArray, - ) -> Result> - where - T: ArrowNumericType, - { - let array_type = T::DATA_TYPE; - let value_size = mem::size_of::(); - let array_data_builder = - filter_array_impl(self, data_array, array_type, value_size)?; - let data = array_data_builder.build(); - Ok(PrimitiveArray::::from(data)) - } - - /// Returns a new DictionaryArray containing only those keys from the array passed as the data_array parameter, - /// selected by the BooleanArray passed as the filter_array parameter. The values are cloned from the data_array. 
- pub fn filter_dictionary_array( - &self, - data_array: &DictionaryArray, - ) -> Result> - where - T: ArrowNumericType, - { - let array_type = data_array.data_type().clone(); - let value_size = mem::size_of::(); - let mut array_data_builder = - filter_array_impl(self, data_array, array_type, value_size)?; - // copy dictionary values from input array - array_data_builder = - array_data_builder.add_child_data(data_array.values().data()); - let data = array_data_builder.build(); - Ok(DictionaryArray::::from(data)) - } +/// Returns a function with pre-computed values that can be used to filter arbitrary arrays, +/// thereby improving performance when filtering multiple arrays +pub fn build_filter<'a>(filter: &'a BooleanArray) -> Result> { + let (chunks, filter_count) = compute_slices(filter); + Ok(Box::new(move |array: &ArrayData| { + let mut mutable = MutableArrayData::new(array, filter_count); + chunks + .iter() + .for_each(|(start, end)| mutable.extend(*start, *end)); + mutable.freeze() + })) } /// Returns a new array, containing only the elements matching the filter. pub fn filter(array: &Array, filter: &BooleanArray) -> Result { - FilterContext::new(filter)?.filter(array) -} - -/// Returns a new PrimitiveArray containing only those values from the array passed as the data_array parameter, -/// selected by the BooleanArray passed as the filter_array parameter -pub fn filter_primitive_array( - data_array: &PrimitiveArray, - filter_array: &BooleanArray, -) -> Result> -where - T: ArrowNumericType, -{ - FilterContext::new(filter_array)?.filter_primitive_array(data_array) -} - -/// Returns a new DictionaryArray containing only those keys from the array passed as the data_array parameter, -/// selected by the BooleanArray passed as the filter_array parameter. The values are cloned from the data_array. 
-pub fn filter_dictionary_array( - data_array: &DictionaryArray, - filter_array: &BooleanArray, -) -> Result> -where - T: ArrowNumericType, -{ - FilterContext::new(filter_array)?.filter_dictionary_array(data_array) + Ok(make_array(Arc::new((build_filter(filter)?)(&array.data())))) } /// Returns a new RecordBatch with arrays containing only values matching the filter. -/// The same FilterContext is re-used when filtering arrays in the RecordBatch for better performance. pub fn filter_record_batch( record_batch: &RecordBatch, filter_array: &BooleanArray, ) -> Result { - let filter_context = FilterContext::new(filter_array)?; + let filter = build_filter(filter_array)?; let filtered_arrays = record_batch .columns() .iter() - .map(|a| filter_context.filter(a.as_ref())) - .collect::>>()?; + .map(|a| make_array(Arc::new(filter(&a.data())))) + .collect(); RecordBatch::try_new(record_batch.schema(), filtered_arrays) } #[cfg(test)] mod tests { use super::*; - use crate::buffer::Buffer; use crate::datatypes::ToByteSlice; + use crate::{ + buffer::Buffer, + datatypes::{DataType, Field}, + }; macro_rules! 
def_temporal_test { ($test:ident, $array_type: ident, $data: expr) => { @@ -974,7 +298,7 @@ mod tests { } #[test] - fn test_filter_string_array() { + fn test_filter_string_array_simple() { let a = StringArray::from(vec!["hello", " ", "world", "!"]); let b = BooleanArray::from(vec![true, false, true, false]); let c = filter(&a, &b).unwrap(); @@ -1096,36 +420,25 @@ mod tests { // a = [[0, 1, 2], [3, 4, 5], [6, 7], null] let a = LargeListArray::from(list_data); let b = BooleanArray::from(vec![false, true, false, true]); - let c = filter(&a, &b).unwrap(); - let d = c - .as_ref() - .as_any() - .downcast_ref::() - .unwrap(); + let result = filter(&a, &b).unwrap(); - assert_eq!(DataType::Int32, d.value_type()); + // expected: [[3, 4, 5], null] + let value_data = ArrayData::builder(DataType::Int32) + .len(3) + .add_buffer(Buffer::from(&[3, 4, 5].to_byte_slice())) + .build(); - // result should be [[3, 4, 5], null] - assert_eq!(2, d.len()); - assert_eq!(1, d.null_count()); - assert_eq!(true, d.is_null(1)); + let value_offsets = Buffer::from(&[0i64, 3, 3].to_byte_slice()); + + let list_data_type = + DataType::LargeList(Box::new(Field::new("item", DataType::Int32, false))); + let expected = ArrayData::builder(list_data_type) + .len(2) + .add_buffer(value_offsets) + .add_child_data(value_data) + .null_bit_buffer(Buffer::from([0b00000001])) + .build(); - assert_eq!(0, d.value_offset(0)); - assert_eq!(3, d.value_length(0)); - assert_eq!(3, d.value_offset(1)); - assert_eq!(0, d.value_length(1)); - assert_eq!( - Buffer::from(&[3, 4, 5].to_byte_slice()), - d.values().data().buffers()[0].clone() - ); - assert_eq!( - Buffer::from(&[0i64, 3, 3].to_byte_slice()), - d.data().buffers()[0].clone() - ); - let inner_list = d.value(0); - let inner_list = inner_list.as_any().downcast_ref::().unwrap(); - assert_eq!(3, inner_list.len()); - assert_eq!(0, inner_list.null_count()); - assert_eq!(inner_list, &Int32Array::from(vec![3, 4, 5])); + assert_eq!(&make_array(expected), &result); } }