diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index 09d4b9fd6cd..46bd42faf4f 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -97,13 +97,13 @@ where } #[cfg(feature = "simd")] -fn simd_float_unary_math_op( +fn simd_unary_math_op( array: &PrimitiveArray, simd_op: SIMD_OP, scalar_op: SCALAR_OP, ) -> Result> where - T: datatypes::ArrowFloatNumericType, + T: ArrowNumericType, SIMD_OP: Fn(T::Simd) -> T::Simd, SCALAR_OP: Fn(T::Native) -> T::Native, { @@ -1010,6 +1010,31 @@ where return math_op(left, right, |a, b| a + b); } +/// Add every value in an array by a scalar. If any value in the array is null then the +/// result is also null. +pub fn add_scalar( + array: &PrimitiveArray, + scalar: T::Native, +) -> Result> +where + T: datatypes::ArrowNumericType, + T::Native: Add + + Sub + + Mul + + Div + + Rem + + Zero + + One, +{ + #[cfg(feature = "simd")] + { + let scalar_vector = T::init(scalar); + return simd_unary_math_op(array, |x| x + scalar_vector, |x| x + scalar); + } + #[cfg(not(feature = "simd"))] + return Ok(unary(array, |value| value + scalar)); +} + /// Perform `left - right` operation on two arrays. If either left or right value is null /// then the result is also null. pub fn subtract( @@ -1054,11 +1079,7 @@ where #[cfg(feature = "simd")] { let raise_vector = T::init(raise); - return simd_float_unary_math_op( - array, - |x| T::pow(x, raise_vector), - |x| x.pow(raise), - ); + return simd_unary_math_op(array, |x| T::pow(x, raise_vector), |x| x.pow(raise)); } #[cfg(not(feature = "simd"))] return Ok(unary(array, |x| x.pow(raise))); @@ -1228,6 +1249,25 @@ mod tests { ); } + #[test] + fn test_primitive_array_add_scalar() { + let a = Int32Array::from(vec![15, 14, 9, 8, 1]); + let b = 3; + let c = add_scalar(&a, b).unwrap(); + let expected = Int32Array::from(vec![18, 17, 12, 11, 4]); + assert_eq!(c, expected); + } + + #[test] + fn test_primitive_array_add_scalar_sliced() { + let a = Int32Array::from(vec![Some(15), None, Some(9), Some(8), None]); + let a = a.slice(1, 4); + let a = as_primitive_array(&a); + let actual = add_scalar(a, 3).unwrap(); + let expected = Int32Array::from(vec![None, Some(12), Some(11), None]); + assert_eq!(actual, expected); + } + #[test] fn test_primitive_array_subtract() { let a = Int32Array::from(vec![1, 2, 3, 4, 5]);