diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index 2b11e4e3e16..e0b2a151e95 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -97,13 +97,13 @@ where } #[cfg(feature = "simd")] -fn simd_float_unary_math_op( +fn simd_unary_math_op( array: &PrimitiveArray, simd_op: SIMD_OP, scalar_op: SCALAR_OP, ) -> Result> where - T: datatypes::ArrowFloatNumericType, + T: ArrowNumericType, SIMD_OP: Fn(T::Simd) -> T::Simd, SCALAR_OP: Fn(T::Native) -> T::Native, { @@ -990,69 +990,6 @@ where Ok(PrimitiveArray::::from(data)) } -/// SIMD vectorized version of adding a scalar to an array. -#[cfg(feature = "simd")] -fn simd_add_scalar( - array: &PrimitiveArray, - scalar: T::Native, -) -> Result> -where - T: ArrowNumericType, - T::Native: Add - + Sub - + Mul - + Div - + Rem - + Zero - + One, -{ - let lanes = T::lanes(); - let buffer_size = array.len() * std::mem::size_of::(); - let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false); - - // safety: result is newly created above, always written as a T below - let mut result_chunks = unsafe { result.typed_data_mut().chunks_exact_mut(lanes) }; - let mut array_chunks = array.values().chunks_exact(lanes); - - let simd_right = T::init(scalar); - - result_chunks - .borrow_mut() - .zip(array_chunks.borrow_mut()) - .for_each(|(result_slice, array_slice)| { - let simd_left = T::load(array_slice); - - let simd_result = T::bin_op(simd_left, simd_right, |a, b| a + b); - T::write(simd_result, result_slice); - }); - - let result_remainder = result_chunks.into_remainder(); - let array_remainder = array_chunks.remainder(); - - result_remainder - .iter_mut() - .zip(array_remainder.iter()) - .for_each(|(scalar_result, scalar_array)| { - *scalar_result = *scalar_array + scalar; - }); - - let data = unsafe { - ArrayData::new_unchecked( - T::DATA_TYPE, - array.len(), - None, - array - .data_ref() - .null_buffer() - .map(|b| b.bit_slice(array.offset(), array.len())), - 0, - vec![result.into()], - vec![], - ) - }; - Ok(PrimitiveArray::::from(data)) -} - /// Perform `left + right` operation on two arrays. If either left or right value is null /// then the result is also null. pub fn add( @@ -1090,7 +1027,14 @@ where + One, { #[cfg(feature = "simd")] - return simd_add_scalar(&array, scalar); + { + let scalar_vector = T::init(scalar); + return simd_unary_math_op( + array, + |x| x + scalar_vector, + |x| x + scalar, + ); + } #[cfg(not(feature = "simd"))] return Ok(unary(array, |value| value + scalar)); } @@ -1139,11 +1083,7 @@ where #[cfg(feature = "simd")] { let raise_vector = T::init(raise); - return simd_float_unary_math_op( - array, - |x| T::pow(x, raise_vector), - |x| x.pow(raise), - ); + return simd_unary_math_op(array, |x| T::pow(x, raise_vector), |x| x.pow(raise)); } #[cfg(not(feature = "simd"))] return Ok(unary(array, |x| x.pow(raise)));