diff --git a/arrow/src/compute/kernels/aggregate.rs b/arrow/src/compute/kernels/aggregate.rs index c8d0443c470..fb2f55582d6 100644 --- a/arrow/src/compute/kernels/aggregate.rs +++ b/arrow/src/compute/kernels/aggregate.rs @@ -185,7 +185,7 @@ pub fn min_string(array: &GenericStringArray) -> Option<& } /// Returns the sum of values in the array. -pub fn sum_dyn>(array: A) -> Option +pub fn sum_array>(array: A) -> Option where T: ArrowNumericType, T::Native: Add, @@ -215,6 +215,68 @@ where } } +/// Returns the min of values in the array. +pub fn min_array>(array: A) -> Option +where + T: ArrowNumericType, + T::Native: ArrowNativeType, +{ + min_max_array_helper::( + array, + |a, b| (!is_nan(*a) & is_nan(*b)) || a < b, + min, + ) +} + +/// Returns the max of values in the array. +pub fn max_array>(array: A) -> Option +where + T: ArrowNumericType, + T::Native: ArrowNativeType, +{ + min_max_array_helper::( + array, + |a, b| (is_nan(*a) & !is_nan(*b)) || a > b, + max, + ) +} + +fn min_max_array_helper, F, M>( + array: A, + cmp: F, + m: M, +) -> Option +where + T: ArrowNumericType, + F: Fn(&T::Native, &T::Native) -> bool, + M: Fn(&PrimitiveArray) -> Option, +{ + match array.data_type() { + DataType::Dictionary(_, _) => { + let null_count = array.null_count(); + + if null_count == array.len() { + return None; + } + + let mut has_value = false; + let mut n = T::default_value(); + let iter = ArrayIter::new(array); + iter.into_iter().for_each(|value| { + if let Some(value) = value { + if !has_value || cmp(&value, &n) { + has_value = true; + n = value; + } + } + }); + + Some(n) + } + _ => m(as_primitive_array(&array)), + } +} + /// Returns the sum of values in the primitive array. /// /// Returns `None` if the array is empty or only contains null values. @@ -656,7 +718,7 @@ mod tests { use super::*; use crate::array::*; use crate::compute::add; - use crate::datatypes::{Int32Type, Int8Type}; + use crate::datatypes::{Float32Type, Int32Type, Int8Type}; #[test] fn test_primitive_array_sum() { @@ -1043,19 +1105,63 @@ mod tests { let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); let array = dict_array.downcast_dict::().unwrap(); - assert_eq!(39, sum_dyn::(array).unwrap()); + assert_eq!(39, sum_array::(array).unwrap()); let a = Int32Array::from(vec![1, 2, 3, 4, 5]); - assert_eq!(15, sum_dyn::(&a).unwrap()); + assert_eq!(15, sum_array::(&a).unwrap()); let keys = Int8Array::from(vec![Some(2_i8), None, Some(4)]); let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); let array = dict_array.downcast_dict::().unwrap(); - assert_eq!(26, sum_dyn::(array).unwrap()); + assert_eq!(26, sum_array::(array).unwrap()); + + let keys = Int8Array::from(vec![None, None, None]); + let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); + let array = dict_array.downcast_dict::().unwrap(); + assert!(sum_array::(array).is_none()); + } + + #[test] + fn test_max_min_dyn() { + let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); + let keys = Int8Array::from_iter_values([2_i8, 3, 4]); + + let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(14, max_array::(array).unwrap()); + + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(12, min_array::(array).unwrap()); + + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + assert_eq!(5, max_array::(&a).unwrap()); + assert_eq!(1, min_array::(&a).unwrap()); + + let keys = Int8Array::from(vec![Some(2_i8), None, Some(7)]); + let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(17, max_array::(array).unwrap()); + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(12, min_array::(array).unwrap()); let keys = Int8Array::from(vec![None, None, None]); let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); let array = dict_array.downcast_dict::().unwrap(); - assert!(sum_dyn::(array).is_none()); + assert!(max_array::(array).is_none()); + let array = dict_array.downcast_dict::().unwrap(); + assert!(min_array::(array).is_none()); + } + + #[test] + fn test_max_min_dyn_nan() { + let values = Float32Array::from(vec![5.0_f32, 2.0_f32, f32::NAN]); + let keys = Int8Array::from_iter_values([0_i8, 1, 2]); + + let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); + let array = dict_array.downcast_dict::().unwrap(); + assert!(max_array::(array).unwrap().is_nan()); + + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(2.0_f32, min_array::(array).unwrap()); } }