diff --git a/arrow-select/src/filter.rs b/arrow-select/src/filter.rs index c365d0b841be..2af19ff85056 100644 --- a/arrow-select/src/filter.rs +++ b/arrow-select/src/filter.rs @@ -17,13 +17,16 @@ //! Defines filter kernels +use std::ops::AddAssign; use std::sync::Arc; use arrow_array::builder::BooleanBufferBuilder; use arrow_array::cast::AsArray; -use arrow_array::types::{ArrowDictionaryKeyType, ByteArrayType}; +use arrow_array::types::{ + ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, RunEndIndexType, +}; use arrow_array::*; -use arrow_buffer::{bit_util, BooleanBuffer, NullBuffer}; +use arrow_buffer::{bit_util, BooleanBuffer, NullBuffer, RunEndBuffer}; use arrow_buffer::{Buffer, MutableBuffer}; use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator}; use arrow_data::transform::MutableArrayData; @@ -336,6 +339,12 @@ fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result { Ok(Arc::new(filter_bytes(values.as_binary::(), predicate))) } + DataType::RunEndEncoded(_, _) => { + downcast_run_array!{ + values => Ok(Arc::new(filter_run_end_array(values, predicate)?)), + t => unimplemented!("Filter not supported for RunEndEncoded type {:?}", t) + } + } DataType::Dictionary(_, _) => downcast_dictionary_array! { values => Ok(Arc::new(filter_dict(values, predicate))), t => unimplemented!("Filter not supported for dictionary type {:?}", t) @@ -368,6 +377,55 @@ fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result( + re_arr: &RunArray, + pred: &FilterPredicate, +) -> Result, ArrowError> +where + R::Native: Into + From, + R::Native: AddAssign, +{ + let run_ends: &RunEndBuffer = re_arr.run_ends(); + let mut values_filter = BooleanBufferBuilder::new(run_ends.len()); + let mut new_run_ends = vec![R::default_value(); run_ends.len()]; + + let mut start = 0i64; + let mut i = 0; + let filter_values = pred.filter.values(); + let mut count = R::default_value(); + + for end in run_ends.inner().into_iter().map(|i| (*i).into()) { + let mut keep = false; + // in filter_array the predicate array is checked to have the same len as the run end array + // this means the largest value in the run_ends is == to pred.len() + // so we're always within bounds when calling value_unchecked + for pred in (start..end).map(|i| unsafe { filter_values.value_unchecked(i as usize) }) { + count += R::Native::from(pred); + keep |= pred + } + // this is to avoid branching + new_run_ends[i] = count; + i += keep as usize; + + values_filter.append(keep); + start = end; + } + + new_run_ends.truncate(i); + + if values_filter.is_empty() { + new_run_ends.clear(); + } + + let values = re_arr.values(); + let pred = BooleanArray::new(values_filter.finish(), None); + let values = filter(&values, &pred)?; + + let run_ends = PrimitiveArray::::new(new_run_ends.into(), None); + RunArray::try_new(&run_ends, &values) +} + /// Computes a new null mask for `data` based on `predicate` /// /// If the predicate selected no null-rows, returns `None`, otherwise returns @@ -635,6 +693,7 @@ where #[cfg(test)] mod tests { use arrow_array::builder::*; + use arrow_array::cast::as_run_array; use arrow_array::types::*; use rand::distributions::{Alphanumeric, Standard}; use rand::prelude::*; @@ -844,6 +903,78 @@ mod tests { assert_eq!(9, d.value(1)); } + #[test] + fn test_filter_run_end_encoding_array() { + let run_ends = Int64Array::from(vec![2, 3, 8]); + let values = Int64Array::from(vec![7, -2, 9]); + let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray"); + let b = BooleanArray::from(vec![true, false, true, false, true, false, true, false]); + let c = filter(&a, &b).unwrap(); + let actual: &RunArray = as_run_array(&c); + assert_eq!(4, actual.len()); + + let expected = RunArray::try_new( + &Int64Array::from(vec![1, 2, 4]), + &Int64Array::from(vec![7, -2, 9]), + ) + .expect("Failed to make expected RunArray test is broken"); + + assert_eq!(&actual.run_ends().values(), &expected.run_ends().values()); + assert_eq!(actual.values(), expected.values()) + } + + #[test] + fn test_filter_run_end_encoding_array_remove_value() { + let run_ends = Int32Array::from(vec![2, 3, 8, 10]); + let values = Int32Array::from(vec![7, -2, 9, -8]); + let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray"); + let b = BooleanArray::from(vec![ + false, true, false, false, true, false, true, false, false, false, + ]); + let c = filter(&a, &b).unwrap(); + let actual: &RunArray = as_run_array(&c); + assert_eq!(3, actual.len()); + + let expected = + RunArray::try_new(&Int32Array::from(vec![1, 3]), &Int32Array::from(vec![7, 9])) + .expect("Failed to make expected RunArray test is broken"); + + assert_eq!(&actual.run_ends().values(), &expected.run_ends().values()); + assert_eq!(actual.values(), expected.values()) + } + + #[test] + fn test_filter_run_end_encoding_array_remove_all_but_one() { + let run_ends = Int16Array::from(vec![2, 3, 8, 10]); + let values = Int16Array::from(vec![7, -2, 9, -8]); + let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray"); + let b = BooleanArray::from(vec![ + false, false, false, false, false, false, true, false, false, false, + ]); + let c = filter(&a, &b).unwrap(); + let actual: &RunArray = as_run_array(&c); + assert_eq!(1, actual.len()); + + let expected = RunArray::try_new(&Int16Array::from(vec![1]), &Int16Array::from(vec![9])) + .expect("Failed to make expected RunArray test is broken"); + + assert_eq!(&actual.run_ends().values(), &expected.run_ends().values()); + assert_eq!(actual.values(), expected.values()) + } + + #[test] + fn test_filter_run_end_encoding_array_empty() { + let run_ends = Int64Array::from(vec![2, 3, 8, 10]); + let values = Int64Array::from(vec![7, -2, 9, -8]); + let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray"); + let b = BooleanArray::from(vec![ + false, false, false, false, false, false, false, false, false, false, + ]); + let c = filter(&a, &b).unwrap(); + let actual: &RunArray = as_run_array(&c); + assert_eq!(0, actual.len()); + } + #[test] fn test_filter_dictionary_array() { let values = [Some("hello"), None, Some("world"), Some("!")];