Skip to content

Commit

Permalink
Add binary_mut and try_binary_mut (#3144)
Browse files Browse the repository at this point in the history
* Add add_mut

* Add try_binary_mut

* Add test

* Change result type

* Remove _mut kernels

* Fix clippy
  • Loading branch information
viirya committed Nov 30, 2022
1 parent 989ab8d commit 961e114
Show file tree
Hide file tree
Showing 2 changed files with 246 additions and 1 deletion.
31 changes: 30 additions & 1 deletion arrow/src/compute/kernels/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1624,7 +1624,7 @@ where
mod tests {
use super::*;
use crate::array::Int32Array;
use crate::compute::{try_unary_mut, unary_mut};
use crate::compute::{binary_mut, try_binary_mut, try_unary_mut, unary_mut};
use crate::datatypes::{Date64Type, Int32Type, Int8Type};
use arrow_buffer::i256;
use chrono::NaiveDate;
Expand Down Expand Up @@ -3100,6 +3100,35 @@ mod tests {
assert_eq!(result.null_count(), 13);
}

// `binary_mut` with a wrapping add: every slot that is null in `b` must be null in the
// result, all other slots hold the element-wise sum.
#[test]
fn test_primitive_array_add_mut_by_binary_mut() {
    let lhs = Int32Array::from(vec![15, 14, 9, 8, 1]);
    let rhs = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);
    let expected = Int32Array::from(vec![Some(16), None, Some(12), None, Some(6)]);

    // Outer unwrap: `lhs` was mutable. Inner unwrap: the kernel itself succeeded.
    let sum = binary_mut(lhs, &rhs, |x, y| x.add_wrapping(y))
        .unwrap()
        .unwrap();

    assert_eq!(sum, expected);
}

// The same overflowing inputs must wrap silently under `binary_mut` + `add_wrapping`,
// and must surface an error under `try_binary_mut` + `add_checked`.
#[test]
fn test_primitive_add_mut_wrapping_overflow_by_try_binary_mut() {
    // Both kernels consume their left operand, so build a fresh pair for each run.
    let inputs = || {
        (
            Int32Array::from(vec![i32::MAX, i32::MIN]),
            Int32Array::from(vec![1, 1]),
        )
    };

    let (lhs, rhs) = inputs();
    let wrapped = binary_mut(lhs, &rhs, |x, y| x.add_wrapping(y))
        .unwrap()
        .unwrap();
    let expected = Int32Array::from(vec![-2147483648, -2147483647]);
    assert_eq!(expected, wrapped);

    let (lhs, rhs) = inputs();
    let overflow = try_binary_mut(lhs, &rhs, |x, y| x.add_checked(y));
    let _ = overflow.unwrap().expect_err("overflow should be detected");
}

#[test]
fn test_primitive_add_scalar_by_unary_mut() {
let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
Expand Down
216 changes: 216 additions & 0 deletions arrow/src/compute/kernels/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,75 @@ where
Ok(unsafe { build_primitive_array(len, buffer, null_count, null_buffer) })
}

/// Given two arrays of length `len`, calls `op(a[i], b[i])` for `i` in `0..len`, writing
/// each result back into the values buffer of the mutable [`PrimitiveArray`] `a`. If any
/// index is null in either `a` or `b`, the corresponding index in the result will also be
/// null.
///
/// A *mutable* primitive array is one whose buffer is not shared with other arrays.
/// As a result, this mutates the buffer directly without allocating a new buffer.
///
/// Like [`unary`] the provided function is evaluated for every index, ignoring validity.
/// This is beneficial when the cost of the operation is low compared to the cost of
/// branching, and especially when the operation can be vectorised; however, it requires
/// `op` to be infallible for all possible values of its inputs.
///
/// # Errors
///
/// The return value is a nested `Result`:
///
/// * `Ok(Err(_))` with an [`ArrowError::ComputeError`] if the arrays have different
///   lengths. `a` has been consumed and is not handed back in this case.
/// * `Err(a)`, returning the original [`PrimitiveArray`] `a`, if it is not a mutable
///   primitive array.
pub fn binary_mut<T, F>(
    a: PrimitiveArray<T>,
    b: &PrimitiveArray<T>,
    op: F,
) -> std::result::Result<
    std::result::Result<PrimitiveArray<T>, ArrowError>,
    PrimitiveArray<T>,
>
where
    T: ArrowPrimitiveType,
    F: Fn(T::Native, T::Native) -> T::Native,
{
    if a.len() != b.len() {
        return Ok(Err(ArrowError::ComputeError(
            "Cannot perform binary operation on arrays of different length".to_string(),
        )));
    }

    if a.is_empty() {
        return Ok(Ok(PrimitiveArray::from(ArrayData::new_empty(
            &T::DATA_TYPE,
        ))));
    }

    let len = a.len();

    // The combined validity has to be derived while `a` is still borrowable; it is
    // consumed by `into_builder` below.
    let null_buffer = combine_option_bitmap(&[a.data(), b.data()], len).unwrap();
    let null_count = null_buffer
        .as_ref()
        .map(|x| len - x.count_set_bits_offset(0, len))
        .unwrap_or_default();

    // Hands `a` back to the caller (outer `Err`) when it cannot be turned into a builder.
    let mut builder = a.into_builder()?;

    // `op` runs on every slot, including null ones; validity is applied afterwards.
    builder
        .values_slice_mut()
        .iter_mut()
        .zip(b.values())
        .for_each(|(l, r)| *l = op(*l, *r));

    // Move the finished array's data into the data builder instead of cloning it: the
    // array returned by `finish()` is a temporary that would be dropped right away.
    let array_builder = builder
        .finish()
        .into_data()
        .into_builder()
        .null_bit_buffer(null_buffer)
        .null_count(null_count);

    // SAFETY: the data originates from the array just finished by `builder`; only the
    // validity buffer and null count are replaced, and both were computed above for the
    // same `len`. NOTE(review): relies on `finish()` yielding `len` values — confirm.
    let array_data = unsafe { array_builder.build_unchecked() };
    Ok(Ok(PrimitiveArray::<T>::from(array_data)))
}

/// Applies the provided fallible binary operation across `a` and `b`, returning any error,
/// and collecting the results into a [`PrimitiveArray`]. If any index is null in either `a`
/// or `b`, the corresponding index in the result will also be null
Expand Down Expand Up @@ -289,6 +358,83 @@ where
}
}

/// Applies the provided fallible binary operation across `a` and `b` by mutating the
/// mutable [`PrimitiveArray`] `a` with the results, returning any error. If any index is
/// null in either `a` or `b`, the corresponding index in the result will also be null.
///
/// Like [`try_unary`] the function is only evaluated for non-null indices.
///
/// A *mutable* primitive array is one whose buffer is not shared with other arrays.
/// As a result, this mutates the buffer directly without allocating a new buffer.
///
/// # Errors
///
/// The return value is a nested `Result`:
///
/// * `Ok(Err(_))` if the arrays have different lengths
///   ([`ArrowError::ComputeError`]), or if `op` returns an error for some index (that
///   error is passed through). `a` has been consumed and is not handed back.
/// * `Err(a)`, returning the original [`PrimitiveArray`] `a`, if it is not a mutable
///   primitive array.
pub fn try_binary_mut<T, F>(
    a: PrimitiveArray<T>,
    b: &PrimitiveArray<T>,
    op: F,
) -> std::result::Result<
    std::result::Result<PrimitiveArray<T>, ArrowError>,
    PrimitiveArray<T>,
>
where
    T: ArrowPrimitiveType,
    F: Fn(T::Native, T::Native) -> Result<T::Native>,
{
    if a.len() != b.len() {
        return Ok(Err(ArrowError::ComputeError(
            "Cannot perform binary operation on arrays of different length".to_string(),
        )));
    }
    let len = a.len();

    if a.is_empty() {
        return Ok(Ok(PrimitiveArray::from(ArrayData::new_empty(
            &T::DATA_TYPE,
        ))));
    }

    // Fast path: with no nulls on either side there is no validity to combine.
    if a.null_count() == 0 && b.null_count() == 0 {
        return try_binary_no_nulls_mut(len, a, b, op);
    }

    // The combined validity has to be derived while `a` is still borrowable; it is
    // consumed by `into_builder` below.
    let null_buffer = combine_option_bitmap(&[a.data(), b.data()], len).unwrap();
    let null_count = null_buffer
        .as_ref()
        .map(|x| len - x.count_set_bits_offset(0, len))
        .unwrap_or_default();

    // Hands `a` back to the caller (outer `Err`) when it cannot be turned into a builder.
    let mut builder = a.into_builder()?;

    let slice = builder.values_slice_mut();

    let applied =
        try_for_each_valid_idx(len, 0, null_count, null_buffer.as_deref(), |idx| {
            // SAFETY: `a` and `b` were checked above to have the same length `len`.
            // NOTE(review): assumes `try_for_each_valid_idx` only yields `idx < len`
            // — confirm against that helper.
            unsafe {
                *slice.get_unchecked_mut(idx) =
                    op(*slice.get_unchecked(idx), b.value_unchecked(idx))?
            };
            Ok::<_, ArrowError>(())
        });
    if let Err(err) = applied {
        return Ok(Err(err));
    }

    // Move the finished array's data into the data builder instead of cloning it: the
    // array returned by `finish()` is a temporary that would be dropped right away.
    let array_builder = builder
        .finish()
        .into_data()
        .into_builder()
        .null_bit_buffer(null_buffer)
        .null_count(null_count);

    // SAFETY: the data originates from the array just finished by `builder`; only the
    // validity buffer and null count are replaced, and both were computed above for the
    // same `len`. NOTE(review): relies on `finish()` yielding `len` values — confirm.
    let array_data = unsafe { array_builder.build_unchecked() };
    Ok(Ok(PrimitiveArray::<T>::from(array_data)))
}

/// This intentional inline(never) attribute helps LLVM optimize the loop.
#[inline(never)]
fn try_binary_no_nulls<A: ArrayAccessor, B: ArrayAccessor, F, O>(
Expand All @@ -310,6 +456,35 @@ where
Ok(unsafe { build_primitive_array(len, buffer.into(), 0, None) })
}

/// Null-free variant of the in-place fallible binary kernel: applies `op` to the first
/// `len` slots of `a` and `b`, overwriting `a`'s values, and stops at the first error.
///
/// Returns `Err(a)` if `a` cannot be converted into a builder, and `Ok(Err(_))` with the
/// error produced by `op` otherwise.
///
/// This intentional inline(never) attribute helps LLVM optimize the loop.
#[inline(never)]
fn try_binary_no_nulls_mut<T, F>(
    len: usize,
    a: PrimitiveArray<T>,
    b: &PrimitiveArray<T>,
    op: F,
) -> std::result::Result<
    std::result::Result<PrimitiveArray<T>, ArrowError>,
    PrimitiveArray<T>,
>
where
    T: ArrowPrimitiveType,
    F: Fn(T::Native, T::Native) -> Result<T::Native>,
{
    let mut builder = a.into_builder()?;
    let values = builder.values_slice_mut();

    let outcome = (0..len).try_for_each(|i| {
        // SAFETY: NOTE(review) — sound only if `len` does not exceed the length of
        // either array; the visible caller passes `a.len()` after checking that
        // `a.len() == b.len()`.
        unsafe {
            let updated = op(*values.get_unchecked(i), b.value_unchecked(i))?;
            *values.get_unchecked_mut(i) = updated;
        }
        Ok::<_, ArrowError>(())
    });

    // Only finish the builder when every slot was processed successfully.
    match outcome {
        Ok(()) => Ok(Ok(builder.finish())),
        Err(err) => Ok(Err(err)),
    }
}

#[inline(never)]
fn try_binary_opt_no_nulls<A: ArrayAccessor, B: ArrayAccessor, F, O>(
len: usize,
Expand Down Expand Up @@ -385,6 +560,7 @@ mod tests {
use super::*;
use crate::array::{as_primitive_array, Float64Array, PrimitiveDictionaryBuilder};
use crate::datatypes::{Float64Type, Int32Type, Int8Type};
use arrow_array::Int32Array;

#[test]
fn test_unary_f64_slice() {
Expand Down Expand Up @@ -444,4 +620,44 @@ mod tests {
&expected
);
}

// `binary_mut` sums element-wise and nulls out every slot that is null in `b`.
#[test]
fn test_binary_mut() {
    let lhs = Int32Array::from(vec![15, 14, 9, 8, 1]);
    let rhs = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);
    let expected = Int32Array::from(vec![Some(16), None, Some(12), None, Some(6)]);

    let sum = binary_mut(lhs, &rhs, |x, y| x + y).unwrap().unwrap();

    assert_eq!(sum, expected);
}

// Covers the three paths of `try_binary_mut`: nulls present, no nulls at all, and an
// operation that fails on a valid slot.
#[test]
fn test_try_binary_mut() {
    // `try_binary_mut` consumes its left operand, so it is rebuilt for every case.
    let left = || Int32Array::from(vec![15, 14, 9, 8, 1]);
    let right_with_nulls =
        || Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);

    // Nulls in `b` propagate to the result.
    let rhs = right_with_nulls();
    let sum = try_binary_mut(left(), &rhs, |x, y| Ok(x + y))
        .unwrap()
        .unwrap();
    let expected = Int32Array::from(vec![Some(16), None, Some(12), None, Some(6)]);
    assert_eq!(sum, expected);

    // No nulls on either side.
    let rhs = Int32Array::from(vec![1, 2, 3, 4, 5]);
    let sum = try_binary_mut(left(), &rhs, |x, y| Ok(x + y))
        .unwrap()
        .unwrap();
    let expected = Int32Array::from(vec![16, 16, 12, 12, 6]);
    assert_eq!(sum, expected);

    // The last slot is valid in both inputs and holds 1 on the left, so the error fires.
    let rhs = right_with_nulls();
    let _ = try_binary_mut(left(), &rhs, |x, y| match x {
        1 => Err(ArrowError::InvalidArgumentError("got error".to_string())),
        _ => Ok(x + y),
    })
    .unwrap()
    .expect_err("should got error");
}
}

0 comments on commit 961e114

Please sign in to comment.