Skip to content

Commit

Permalink
Add divide dyn kernel which produces null for division by zero (#2764)
Browse files Browse the repository at this point in the history
* Add divide_dyn_opt kernel

* Add test

* Fix clippy

* Rename function
  • Loading branch information
viirya committed Sep 22, 2022
1 parent 48cc8be commit 80c0f1a
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 28 deletions.
102 changes: 102 additions & 0 deletions arrow/src/compute/kernels/arithmetic.rs
Expand Up @@ -697,6 +697,39 @@ where
)
}

#[cfg(feature = "dyn_arith_dict")]
fn math_divide_safe_op_dict<K, T, F>(
left: &DictionaryArray<K>,
right: &DictionaryArray<K>,
op: F,
) -> Result<ArrayRef>
where
K: ArrowNumericType,
T: ArrowNumericType,
T::Native: One + Zero,
F: Fn(T::Native, T::Native) -> Option<T::Native>,
{
let left = left.downcast_dict::<PrimitiveArray<T>>().unwrap();
let right = right.downcast_dict::<PrimitiveArray<T>>().unwrap();
let array: PrimitiveArray<T> = binary_opt::<_, _, _, T>(left, right, op)?;
Ok(Arc::new(array) as ArrayRef)
}

fn math_safe_divide_op<LT, RT, F>(
left: &PrimitiveArray<LT>,
right: &PrimitiveArray<RT>,
op: F,
) -> Result<ArrayRef>
where
LT: ArrowNumericType,
RT: ArrowNumericType,
RT::Native: One + Zero,
F: Fn(LT::Native, RT::Native) -> Option<LT::Native>,
{
let array: PrimitiveArray<LT> = binary_opt::<_, _, _, LT>(left, right, op)?;
Ok(Arc::new(array) as ArrayRef)
}

/// Perform `left + right` operation on two arrays. If either left or right value is null
/// then the result is also null.
///
Expand Down Expand Up @@ -1406,6 +1439,51 @@ pub fn divide_dyn_checked(left: &dyn Array, right: &dyn Array) -> Result<ArrayRe
}
}

/// Perform `left / right` operation on two arrays. If either left or right value is null
/// then the result is also null.
///
/// If any right hand value is zero, the operation value will be replaced with null in the
/// result.
///
/// Unlike `divide_dyn` or `divide_dyn_checked`, division by zero will get a null value instead
/// returning an `Err`, this also doesn't check overflowing, overflowing will just wrap
/// the result around.
pub fn divide_dyn_opt(left: &dyn Array, right: &dyn Array) -> Result<ArrayRef> {
match left.data_type() {
DataType::Dictionary(_, _) => {
typed_dict_math_op!(
left,
right,
|a, b| {
if b.is_zero() {
None
} else {
Some(a.div_wrapping(b))
}
},
math_divide_safe_op_dict
)
}
_ => {
downcast_primitive_array!(
(left, right) => {
math_safe_divide_op(left, right, |a, b| {
if b.is_zero() {
None
} else {
Some(a.div_wrapping(b))
}
})
}
_ => Err(ArrowError::CastError(format!(
"Unsupported data type {}, {}",
left.data_type(), right.data_type()
)))
)
}
}
}

/// Perform `left / right` operation on two arrays without checking for division by zero.
/// For floating point types, the result of dividing by zero follows normal floating point
/// rules. For other numeric types, dividing by zero will panic,
Expand Down Expand Up @@ -2752,4 +2830,28 @@ mod tests {
let overflow = divide_dyn_checked(&a, &b);
overflow.expect_err("overflow should be detected");
}

#[test]
#[cfg(feature = "dyn_arith_dict")]
fn test_div_dyn_opt_overflow_division_by_zero() {
let a = Int32Array::from(vec![i32::MIN]);
let b = Int32Array::from(vec![0]);

let division_by_zero = divide_dyn_opt(&a, &b);
let expected = Arc::new(Int32Array::from(vec![None])) as ArrayRef;
assert_eq!(&expected, &division_by_zero.unwrap());

let mut builder =
PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::with_capacity(1, 1);
builder.append(i32::MIN).unwrap();
let a = builder.finish();

let mut builder =
PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::with_capacity(1, 1);
builder.append(0).unwrap();
let b = builder.finish();

let division_by_zero = divide_dyn_opt(&a, &b);
assert_eq!(&expected, &division_by_zero.unwrap());
}
}
69 changes: 41 additions & 28 deletions arrow/src/compute/kernels/arity.rs
Expand Up @@ -357,6 +357,26 @@ where
Ok(unsafe { build_primitive_array(len, buffer.into(), 0, None) })
}

#[inline(never)]
fn try_binary_opt_no_nulls<A: ArrayAccessor, B: ArrayAccessor, F, O>(
len: usize,
a: A,
b: B,
op: F,
) -> Result<PrimitiveArray<O>>
where
O: ArrowPrimitiveType,
F: Fn(A::Item, B::Item) -> Option<O::Native>,
{
let mut buffer = Vec::with_capacity(10);
for idx in 0..len {
unsafe {
buffer.push(op(a.value_unchecked(idx), b.value_unchecked(idx)));
};
}
Ok(buffer.iter().collect())
}

/// Applies the provided binary operation across `a` and `b`, collecting the optional results
/// into a [`PrimitiveArray`]. If any index is null in either `a` or `b`, the corresponding
/// index in the result will also be null. The binary operation could return `None` which
Expand All @@ -367,16 +387,14 @@ where
/// # Error
///
/// This function gives error if the arrays have different lengths
pub(crate) fn binary_opt<A, B, F, O>(
a: &PrimitiveArray<A>,
b: &PrimitiveArray<B>,
pub(crate) fn binary_opt<A: ArrayAccessor + Array, B: ArrayAccessor + Array, F, O>(
a: A,
b: B,
op: F,
) -> Result<PrimitiveArray<O>>
where
A: ArrowPrimitiveType,
B: ArrowPrimitiveType,
O: ArrowPrimitiveType,
F: Fn(A::Native, B::Native) -> Option<O::Native>,
F: Fn(A::Item, B::Item) -> Option<O::Native>,
{
if a.len() != b.len() {
return Err(ArrowError::ComputeError(
Expand All @@ -389,29 +407,24 @@ where
}

if a.null_count() == 0 && b.null_count() == 0 {
Ok(a.values()
.iter()
.zip(b.values().iter())
.map(|(a, b)| op(*a, *b))
.collect())
} else {
let iter_a = ArrayIter::new(a);
let iter_b = ArrayIter::new(b);

let values =
iter_a
.into_iter()
.zip(iter_b.into_iter())
.map(|(item_a, item_b)| {
if let (Some(a), Some(b)) = (item_a, item_b) {
op(a, b)
} else {
None
}
});

Ok(values.collect())
return try_binary_opt_no_nulls(a.len(), a, b, op);
}

let iter_a = ArrayIter::new(a);
let iter_b = ArrayIter::new(b);

let values = iter_a
.into_iter()
.zip(iter_b.into_iter())
.map(|(item_a, item_b)| {
if let (Some(a), Some(b)) = (item_a, item_b) {
op(a, b)
} else {
None
}
});

Ok(values.collect())
}

#[cfg(test)]
Expand Down

0 comments on commit 80c0f1a

Please sign in to comment.