Skip to content

Commit

Permalink
move simd_float_unary_math_op to simd_unary_math_op
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Jan 11, 2022
1 parent c9a0c00 commit 9413f52
Showing 1 changed file with 11 additions and 71 deletions.
82 changes: 11 additions & 71 deletions arrow/src/compute/kernels/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,13 @@ where
}

#[cfg(feature = "simd")]
fn simd_float_unary_math_op<T, SIMD_OP, SCALAR_OP>(
fn simd_unary_math_op<T, SIMD_OP, SCALAR_OP>(
array: &PrimitiveArray<T>,
simd_op: SIMD_OP,
scalar_op: SCALAR_OP,
) -> Result<PrimitiveArray<T>>
where
T: datatypes::ArrowFloatNumericType,
T: ArrowNumericType,
SIMD_OP: Fn(T::Simd) -> T::Simd,
SCALAR_OP: Fn(T::Native) -> T::Native,
{
Expand Down Expand Up @@ -990,69 +990,6 @@ where
Ok(PrimitiveArray::<T>::from(data))
}

/// SIMD vectorized version of adding a scalar to an array.
#[cfg(feature = "simd")]
fn simd_add_scalar<T>(
array: &PrimitiveArray<T>,
scalar: T::Native,
) -> Result<PrimitiveArray<T>>
where
T: ArrowNumericType,
T::Native: Add<Output = T::Native>
+ Sub<Output = T::Native>
+ Mul<Output = T::Native>
+ Div<Output = T::Native>
+ Rem<Output = T::Native>
+ Zero
+ One,
{
let lanes = T::lanes();
let buffer_size = array.len() * std::mem::size_of::<T::Native>();
let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false);

// safety: result is newly created above, always written as a T below
let mut result_chunks = unsafe { result.typed_data_mut().chunks_exact_mut(lanes) };
let mut array_chunks = array.values().chunks_exact(lanes);

let simd_right = T::init(scalar);

result_chunks
.borrow_mut()
.zip(array_chunks.borrow_mut())
.for_each(|(result_slice, array_slice)| {
let simd_left = T::load(array_slice);

let simd_result = T::bin_op(simd_left, simd_right, |a, b| a + b);
T::write(simd_result, result_slice);
});

let result_remainder = result_chunks.into_remainder();
let array_remainder = array_chunks.remainder();

result_remainder
.iter_mut()
.zip(array_remainder.iter())
.for_each(|(scalar_result, scalar_array)| {
*scalar_result = *scalar_array + scalar;
});

let data = unsafe {
ArrayData::new_unchecked(
T::DATA_TYPE,
array.len(),
None,
array
.data_ref()
.null_buffer()
.map(|b| b.bit_slice(array.offset(), array.len())),
0,
vec![result.into()],
vec![],
)
};
Ok(PrimitiveArray::<T>::from(data))
}

/// Perform `left + right` operation on two arrays. If either left or right value is null
/// then the result is also null.
pub fn add<T>(
Expand Down Expand Up @@ -1090,7 +1027,14 @@ where
+ One,
{
#[cfg(feature = "simd")]
return simd_add_scalar(&array, scalar);
{
let scalar_vector = T::init(scalar);
return simd_unary_math_op(
array,
|x| x + scalar_vector,
|x| x + scalar,
);
}
#[cfg(not(feature = "simd"))]
return Ok(unary(array, |value| value + scalar));
}
Expand Down Expand Up @@ -1139,11 +1083,7 @@ where
#[cfg(feature = "simd")]
{
let raise_vector = T::init(raise);
return simd_float_unary_math_op(
array,
|x| T::pow(x, raise_vector),
|x| x.pow(raise),
);
return simd_unary_math_op(array, |x| T::pow(x, raise_vector), |x| x.pow(raise));
}
#[cfg(not(feature = "simd"))]
return Ok(unary(array, |x| x.pow(raise)));
Expand Down

0 comments on commit 9413f52

Please sign in to comment.