diff --git a/arrow/src/array/data.rs b/arrow/src/array/data.rs index 9d5b0ee023d..172bdaac9eb 100644 --- a/arrow/src/array/data.rs +++ b/arrow/src/array/data.rs @@ -412,7 +412,7 @@ impl ArrayData { } /// Returns a new empty [ArrayData] valid for `data_type`. - pub(super) fn new_empty(data_type: &DataType) -> Self { + pub fn new_empty(data_type: &DataType) -> Self { let buffers = new_buffers(data_type, 0); let [buffer1, buffer2] = buffers; let buffers = into_buffers(data_type, buffer1, buffer2); diff --git a/arrow/src/compute/kernels/filter.rs b/arrow/src/compute/kernels/filter.rs index 4da07b89edd..b15692e90f2 100644 --- a/arrow/src/compute/kernels/filter.rs +++ b/arrow/src/compute/kernels/filter.rs @@ -197,14 +197,37 @@ pub fn build_filter(filter: &BooleanArray) -> Result { let chunks = iter.collect::>(); Ok(Box::new(move |array: &ArrayData| { - let mut mutable = MutableArrayData::new(vec![array], false, filter_count); - chunks - .iter() - .for_each(|(start, end)| mutable.extend(0, *start, *end)); - mutable.freeze() + match filter_count { + // return all + len if len == array.len() => array.clone(), + 0 => ArrayData::new_empty(array.data_type()), + _ => { + let mut mutable = MutableArrayData::new(vec![array], false, filter_count); + chunks + .iter() + .for_each(|(start, end)| mutable.extend(0, *start, *end)); + mutable.freeze() + } + } })) } +/// Remove null values by do a bitmask AND operation with null bits and the boolean bits. +fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray { + let array_data = filter.data_ref(); + let null_bitmap = array_data.null_buffer().unwrap(); + let mask = filter.values(); + let offset = filter.offset(); + + let new_mask = buffer_bin_and(mask, offset, null_bitmap, offset, filter.len()); + + let array_data = ArrayData::builder(DataType::Boolean) + .len(filter.len()) + .add_buffer(new_mask) + .build(); + BooleanArray::from(array_data) +} + /// Filters an [Array], returning elements matching the filter (i.e. where the values are true). /// /// # Example @@ -221,43 +244,49 @@ pub fn build_filter(filter: &BooleanArray) -> Result { /// # Ok(()) /// # } /// ``` -pub fn filter(array: &Array, filter: &BooleanArray) -> Result { - if filter.null_count() > 0 { +pub fn filter(array: &Array, predicate: &BooleanArray) -> Result { + if predicate.null_count() > 0 { // this greatly simplifies subsequent filtering code // now we only have a boolean mask to deal with - let array_data = filter.data_ref(); - let null_bitmap = array_data.null_buffer().unwrap(); - let mask = filter.values(); - let offset = filter.offset(); - - let new_mask = buffer_bin_and(mask, offset, null_bitmap, offset, filter.len()); - - let array_data = ArrayData::builder(DataType::Boolean) - .len(filter.len()) - .add_buffer(new_mask) - .build(); - let filter = BooleanArray::from(array_data); - // fully qualified syntax, because we have an argument with the same name - return crate::compute::kernels::filter::filter(array, &filter); + let predicate = prep_null_mask_filter(predicate); + return filter(array, &predicate); } - let iter = SlicesIterator::new(filter); - - let mut mutable = - MutableArrayData::new(vec![array.data_ref()], false, iter.filter_count); - iter.for_each(|(start, end)| mutable.extend(0, start, end)); - let data = mutable.freeze(); - Ok(make_array(data)) + let iter = SlicesIterator::new(predicate); + match iter.filter_count { + 0 => { + // return empty + Ok(new_empty_array(array.data_type())) + } + len if len == array.len() => { + // return all + let data = array.data().clone(); + Ok(make_array(data)) + } + _ => { + // actually filter + let mut mutable = + MutableArrayData::new(vec![array.data_ref()], false, iter.filter_count); + iter.for_each(|(start, end)| mutable.extend(0, start, end)); + let data = mutable.freeze(); + Ok(make_array(data)) + } + } } /// Returns a new [RecordBatch] with arrays containing only values matching the filter. -/// WARNING: the nulls of `filter` are ignored and the value on its slot is considered. -/// Therefore, it is considered undefined behavior to pass `filter` with null values. pub fn filter_record_batch( record_batch: &RecordBatch, - filter: &BooleanArray, + predicate: &BooleanArray, ) -> Result { - let filter = build_filter(filter)?; + if predicate.null_count() > 0 { + // this greatly simplifies subsequent filtering code + // now we only have a boolean mask to deal with + let predicate = prep_null_mask_filter(predicate); + return filter_record_batch(record_batch, &predicate); + } + + let filter = build_filter(predicate)?; let filtered_arrays = record_batch .columns() .iter() @@ -625,4 +654,26 @@ mod tests { assert_eq!(out_arr0, out_arr1); Ok(()) } + + #[test] + fn test_fast_path() -> Result<()> { + let a: PrimitiveArray = + PrimitiveArray::from(vec![Some(1), Some(2), None]); + + // all true + let mask = BooleanArray::from(vec![true, true, true]); + let out = filter(&a, &mask)?; + let b = out + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(&a, b); + + // all false + let mask = BooleanArray::from(vec![false, false, false]); + let out = filter(&a, &mask)?; + assert_eq!(out.len(), 0); + assert_eq!(out.data_type(), &DataType::Int64); + Ok(()) + } }