From 762685bf43b7f89cf365a8e17f44e4bae222b665 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 11 Jan 2022 11:18:52 -0800 Subject: [PATCH] Add add_scalar kernel (#1151) * Add add_scalar * move simd_float_unary_math_op to simd_unary_math_op --- arrow/src/compute/kernels/arithmetic.rs | 54 +++++++++++++++++++++---- 1 file changed, 47 insertions(+), 7 deletions(-) diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index 0ee060847db..590bb647859 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -97,13 +97,13 @@ where } #[cfg(feature = "simd")] -fn simd_float_unary_math_op( +fn simd_unary_math_op( array: &PrimitiveArray, simd_op: SIMD_OP, scalar_op: SCALAR_OP, ) -> Result> where - T: datatypes::ArrowFloatNumericType, + T: ArrowNumericType, SIMD_OP: Fn(T::Simd) -> T::Simd, SCALAR_OP: Fn(T::Native) -> T::Native, { @@ -1016,6 +1016,31 @@ where return math_op(left, right, |a, b| a + b); } +/// Add every value in an array by a scalar. If any value in the array is null then the +/// result is also null. +pub fn add_scalar( + array: &PrimitiveArray, + scalar: T::Native, +) -> Result> +where + T: datatypes::ArrowNumericType, + T::Native: Add + + Sub + + Mul + + Div + + Rem + + Zero + + One, +{ + #[cfg(feature = "simd")] + { + let scalar_vector = T::init(scalar); + return simd_unary_math_op(array, |x| x + scalar_vector, |x| x + scalar); + } + #[cfg(not(feature = "simd"))] + return Ok(unary(array, |value| value + scalar)); +} + /// Perform `left - right` operation on two arrays. If either left or right value is null /// then the result is also null. pub fn subtract( @@ -1060,11 +1085,7 @@ where #[cfg(feature = "simd")] { let raise_vector = T::init(raise); - return simd_float_unary_math_op( - array, - |x| T::pow(x, raise_vector), - |x| x.pow(raise), - ); + return simd_unary_math_op(array, |x| T::pow(x, raise_vector), |x| x.pow(raise)); } #[cfg(not(feature = "simd"))] return Ok(unary(array, |x| x.pow(raise))); @@ -1234,6 +1255,25 @@ mod tests { ); } + #[test] + fn test_primitive_array_add_scalar() { + let a = Int32Array::from(vec![15, 14, 9, 8, 1]); + let b = 3; + let c = add_scalar(&a, b).unwrap(); + let expected = Int32Array::from(vec![18, 17, 12, 11, 4]); + assert_eq!(c, expected); + } + + #[test] + fn test_primitive_array_add_scalar_sliced() { + let a = Int32Array::from(vec![Some(15), None, Some(9), Some(8), None]); + let a = a.slice(1, 4); + let a = as_primitive_array(&a); + let actual = add_scalar(a, 3).unwrap(); + let expected = Int32Array::from(vec![None, Some(12), Some(11), None]); + assert_eq!(actual, expected); + } + #[test] fn test_primitive_array_subtract() { let a = Int32Array::from(vec![1, 2, 3, 4, 5]);