diff --git a/arrow/src/array/data.rs b/arrow/src/array/data.rs
index 9d5b0ee023d..172bdaac9eb 100644
--- a/arrow/src/array/data.rs
+++ b/arrow/src/array/data.rs
@@ -412,7 +412,7 @@ impl ArrayData {
     }
 
     /// Returns a new empty [ArrayData] valid for `data_type`.
-    pub(super) fn new_empty(data_type: &DataType) -> Self {
+    pub fn new_empty(data_type: &DataType) -> Self {
         let buffers = new_buffers(data_type, 0);
         let [buffer1, buffer2] = buffers;
         let buffers = into_buffers(data_type, buffer1, buffer2);
diff --git a/arrow/src/compute/kernels/filter.rs b/arrow/src/compute/kernels/filter.rs
index f8e68b52db0..6701c464436 100644
--- a/arrow/src/compute/kernels/filter.rs
+++ b/arrow/src/compute/kernels/filter.rs
@@ -197,15 +197,18 @@ pub fn build_filter(filter: &BooleanArray) -> Result<Filter> {
     let chunks = iter.collect::<Vec<_>>();
 
     Ok(Box::new(move |array: &ArrayData| {
-        if filter_count == array.len() {
-            return array.clone();
+        match filter_count {
+            // return all
+            len if len == array.len() => array.clone(),
+            0 => ArrayData::new_empty(array.data_type()),
+            _ => {
+                let mut mutable = MutableArrayData::new(vec![array], false, filter_count);
+                chunks
+                    .iter()
+                    .for_each(|(start, end)| mutable.extend(0, *start, *end));
+                mutable.freeze()
+            }
         }
-
-        let mut mutable = MutableArrayData::new(vec![array], false, filter_count);
-        chunks
-            .iter()
-            .for_each(|(start, end)| mutable.extend(0, *start, *end));
-        mutable.freeze()
     }))
 }
 
@@ -251,15 +254,25 @@ pub fn filter(array: &Array, filter: &BooleanArray) -> Result<ArrayRef> {
     }
 
     let iter = SlicesIterator::new(filter);
-    if iter.filter_count == array.len() {
-        let data = array.data().clone();
-        Ok(make_array(data))
-    } else {
-        let mut mutable =
-            MutableArrayData::new(vec![array.data_ref()], false, iter.filter_count);
-        iter.for_each(|(start, end)| mutable.extend(0, start, end));
-        let data = mutable.freeze();
-        Ok(make_array(data))
+    match iter.filter_count {
+        0 => {
+            // return empty
+            let data = ArrayData::new_empty(array.data_type());
+            Ok(make_array(data))
+        }
+        len if len == array.len() => {
+            // return all
+            let data = array.data().clone();
+            Ok(make_array(data))
+        }
+        _ => {
+            // actually filter
+            let mut mutable =
+                MutableArrayData::new(vec![array.data_ref()], false, iter.filter_count);
+            iter.for_each(|(start, end)| mutable.extend(0, start, end));
+            let data = mutable.freeze();
+            Ok(make_array(data))
+        }
     }
 }
 
@@ -652,6 +665,8 @@ mod tests {
     fn test_fast_path() -> Result<()> {
         let a: PrimitiveArray<Int64Type> =
             PrimitiveArray::from(vec![Some(1), Some(2), None]);
+
+        // all true
         let mask = BooleanArray::from(vec![true, true, true]);
         let out = filter(&a, &mask)?;
         let b = out
@@ -659,6 +674,12 @@ mod tests {
             .downcast_ref::<PrimitiveArray<Int64Type>>()
             .unwrap();
         assert_eq!(&a, b);
+
+        // all false
+        let mask = BooleanArray::from(vec![false, false, false]);
+        let out = filter(&a, &mask)?;
+        assert_eq!(out.len(), 0);
+        assert_eq!(out.data_type(), &DataType::Int64);
         Ok(())
     }
 }