Skip to content

Commit

Permalink
Add add_scalar kernel (#1151)
Browse files Browse the repository at this point in the history
* Add add_scalar

* move simd_float_unary_math_op to simd_unary_math_op
  • Loading branch information
viirya committed Jan 11, 2022
1 parent f085647 commit 762685b
Showing 1 changed file with 47 additions and 7 deletions.
54 changes: 47 additions & 7 deletions arrow/src/compute/kernels/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,13 @@ where
}

#[cfg(feature = "simd")]
fn simd_float_unary_math_op<T, SIMD_OP, SCALAR_OP>(
fn simd_unary_math_op<T, SIMD_OP, SCALAR_OP>(
array: &PrimitiveArray<T>,
simd_op: SIMD_OP,
scalar_op: SCALAR_OP,
) -> Result<PrimitiveArray<T>>
where
T: datatypes::ArrowFloatNumericType,
T: ArrowNumericType,
SIMD_OP: Fn(T::Simd) -> T::Simd,
SCALAR_OP: Fn(T::Native) -> T::Native,
{
Expand Down Expand Up @@ -1016,6 +1016,31 @@ where
return math_op(left, right, |a, b| a + b);
}

/// Add every value in an array by a scalar. If any value in the array is null then the
/// result is also null.
pub fn add_scalar<T>(
array: &PrimitiveArray<T>,
scalar: T::Native,
) -> Result<PrimitiveArray<T>>
where
T: datatypes::ArrowNumericType,
T::Native: Add<Output = T::Native>
+ Sub<Output = T::Native>
+ Mul<Output = T::Native>
+ Div<Output = T::Native>
+ Rem<Output = T::Native>
+ Zero
+ One,
{
#[cfg(feature = "simd")]
{
let scalar_vector = T::init(scalar);
return simd_unary_math_op(array, |x| x + scalar_vector, |x| x + scalar);
}
#[cfg(not(feature = "simd"))]
return Ok(unary(array, |value| value + scalar));
}

/// Perform `left - right` operation on two arrays. If either left or right value is null
/// then the result is also null.
pub fn subtract<T>(
Expand Down Expand Up @@ -1060,11 +1085,7 @@ where
#[cfg(feature = "simd")]
{
let raise_vector = T::init(raise);
return simd_float_unary_math_op(
array,
|x| T::pow(x, raise_vector),
|x| x.pow(raise),
);
return simd_unary_math_op(array, |x| T::pow(x, raise_vector), |x| x.pow(raise));
}
#[cfg(not(feature = "simd"))]
return Ok(unary(array, |x| x.pow(raise)));
Expand Down Expand Up @@ -1234,6 +1255,25 @@ mod tests {
);
}

#[test]
fn test_primitive_array_add_scalar() {
let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
let b = 3;
let c = add_scalar(&a, b).unwrap();
let expected = Int32Array::from(vec![18, 17, 12, 11, 4]);
assert_eq!(c, expected);
}

#[test]
fn test_primitive_array_add_scalar_sliced() {
let a = Int32Array::from(vec![Some(15), None, Some(9), Some(8), None]);
let a = a.slice(1, 4);
let a = as_primitive_array(&a);
let actual = add_scalar(a, 3).unwrap();
let expected = Int32Array::from(vec![None, Some(12), Some(11), None]);
assert_eq!(actual, expected);
}

#[test]
fn test_primitive_array_subtract() {
let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
Expand Down

0 comments on commit 762685b

Please sign in to comment.