diff --git a/datafusion/core/tests/sql/decimal.rs b/datafusion/core/tests/sql/decimal.rs index e0c2c1773e48..f101777d7df4 100644 --- a/datafusion/core/tests/sql/decimal.rs +++ b/datafusion/core/tests/sql/decimal.rs @@ -582,20 +582,20 @@ async fn decimal_arithmetic_op() -> Result<()> { "+---------------------------------------+", "| decimal_simple.c1 / decimal_simple.c5 |", "+---------------------------------------+", - "| 0.7142857142857143296 |", + "| 0.7142857142857142857 |", "| 0.8000000000000000000 |", - "| 1.0526315789473683456 |", + "| 1.0526315789473684210 |", "| 0.9375000000000000000 |", - "| 0.8571428571428571136 |", - "| 2.7272727272727269376 |", - "| 0.9090909090909090816 |", + "| 0.8571428571428571428 |", + "| 2.7272727272727272727 |", + "| 0.9090909090909090909 |", "| 1.0000000000000000000 |", "| 1.0000000000000000000 |", - "| 0.9090909090909090816 |", - "| 0.9615384615384614912 |", - "| 0.6410256410256410624 |", - "| 1.5151515151515152384 |", - "| 0.7352941176470588416 |", + "| 0.9090909090909090909 |", + "| 0.9615384615384615384 |", + "| 0.6410256410256410256 |", + "| 1.5151515151515151515 |", + "| 0.7352941176470588235 |", "| 0.5000000000000000000 |", "+---------------------------------------+", ]; diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 83365db94287..d88fa07c5bea 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -1152,7 +1152,9 @@ mod tests { use super::*; use crate::expressions::try_cast; use crate::expressions::{col, lit}; - use arrow::datatypes::{ArrowNumericType, Field, Int32Type, SchemaRef}; + use arrow::datatypes::{ + ArrowNumericType, Decimal128Type, Field, Int32Type, SchemaRef, + }; use datafusion_common::{ColumnStatistics, Result, Statistics}; use datafusion_expr::type_coercion::binary::coerce_types; @@ -3048,6 +3050,43 @@ mod tests { Ok(()) } + #[test] + fn arithmetic_divide_zero() -> Result<()> { + // other data type + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + let a = Arc::new(Int32Array::from(vec![8, 32, 128, 512, 2048, 100])); + let b = Arc::new(Int32Array::from(vec![2, 4, 8, 16, 32, 0])); + + apply_arithmetic::( + schema, + vec![a, b], + Operator::Divide, + Int32Array::from(vec![Some(4), Some(8), Some(16), Some(32), Some(64), None]), + )?; + + // decimal + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Decimal128(25, 3), true), + Field::new("b", DataType::Decimal128(25, 3), true), + ])); + let left_decimal_array = + Arc::new(create_decimal_array(&[Some(1234567), Some(1234567)], 25, 3)); + let right_decimal_array = + Arc::new(create_decimal_array(&[Some(10), Some(0)], 25, 3)); + + apply_arithmetic::( + schema, + vec![left_decimal_array, right_decimal_array], + Operator::Divide, + create_decimal_array(&[Some(123456700), None], 25, 3), + )?; + + Ok(()) + } + #[test] fn bitwise_array_test() -> Result<()> { let left = Arc::new(Int32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef; @@ -3270,6 +3309,7 @@ mod tests { } Ok(()) } + #[test] fn test_comparison_result_estimate_different_type() -> Result<()> { // A table where the column 'a' has a min of 1.3, a max of 50.7. diff --git a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs b/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs index 2523a56dfadf..2135982b67f8 100644 --- a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs +++ b/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs @@ -18,9 +18,12 @@ //! This module contains computation kernels that are eventually //! destined for arrow-rs but are in datafusion until they are ported. -use arrow::error::ArrowError; +use arrow::compute::{ + add, add_scalar, divide_opt, divide_scalar, modulus, modulus_scalar, multiply, + multiply_scalar, subtract, subtract_scalar, +}; use arrow::{array::*, datatypes::ArrowNumericType}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::Result; // Simple (low performance) kernels until optimized kernels are added to arrow // See https://github.com/apache/arrow-rs/issues/960 @@ -171,53 +174,12 @@ pub(crate) fn is_not_distinct_from_decimal( .collect()) } -/// Creates an Decimal128Array the same size as `left`, -/// by applying `op` to all non-null elements of left and right -pub(crate) fn arith_decimal( - left: &Decimal128Array, - right: &Decimal128Array, - op: F, -) -> Result -where - F: Fn(i128, i128) -> Result, -{ - left.iter() - .zip(right.iter()) - .map(|(left, right)| { - if let (Some(left), Some(right)) = (left, right) { - Some(op(left, right)).transpose() - } else { - Ok(None) - } - }) - .collect() -} - -pub(crate) fn arith_decimal_scalar( - left: &Decimal128Array, - right: i128, - op: F, -) -> Result -where - F: Fn(i128, i128) -> Result, -{ - left.iter() - .map(|left| { - if let Some(left) = left { - Some(op(left, right)).transpose() - } else { - Ok(None) - } - }) - .collect() -} - pub(crate) fn add_decimal( left: &Decimal128Array, right: &Decimal128Array, ) -> Result { - let array = arith_decimal(left, right, |left, right| Ok(left + right))? - .with_precision_and_scale(left.precision(), left.scale())?; + let array = + add(left, right)?.with_precision_and_scale(left.precision(), left.scale())?; Ok(array) } @@ -225,7 +187,7 @@ pub(crate) fn add_decimal_scalar( left: &Decimal128Array, right: i128, ) -> Result { - let array = arith_decimal_scalar(left, right, |left, right| Ok(left + right))? + let array = add_scalar(left, right)? .with_precision_and_scale(left.precision(), left.scale())?; Ok(array) } @@ -234,7 +196,7 @@ pub(crate) fn subtract_decimal( left: &Decimal128Array, right: &Decimal128Array, ) -> Result { - let array = arith_decimal(left, right, |left, right| Ok(left - right))? + let array = subtract(left, right)? .with_precision_and_scale(left.precision(), left.scale())?; Ok(array) } @@ -243,7 +205,7 @@ pub(crate) fn subtract_decimal_scalar( left: &Decimal128Array, right: i128, ) -> Result { - let array = arith_decimal_scalar(left, right, |left, right| Ok(left - right))? + let array = subtract_scalar(left, right)? .with_precision_and_scale(left.precision(), left.scale())?; Ok(array) } @@ -253,7 +215,8 @@ pub(crate) fn multiply_decimal( right: &Decimal128Array, ) -> Result { let divide = 10_i128.pow(left.scale() as u32); - let array = arith_decimal(left, right, |left, right| Ok(left * right / divide))? + let array = multiply(left, right)?; + let array = divide_scalar(&array, divide)? .with_precision_and_scale(left.precision(), left.scale())?; Ok(array) } @@ -262,10 +225,10 @@ pub(crate) fn multiply_decimal_scalar( left: &Decimal128Array, right: i128, ) -> Result { + let array = multiply_scalar(left, right)?; let divide = 10_i128.pow(left.scale() as u32); - let array = - arith_decimal_scalar(left, right, |left, right| Ok(left * right / divide))? - .with_precision_and_scale(left.precision(), left.scale())?; + let array = divide_scalar(&array, divide)? + .with_precision_and_scale(left.precision(), left.scale())?; Ok(array) } @@ -273,17 +236,10 @@ pub(crate) fn divide_opt_decimal( left: &Decimal128Array, right: &Decimal128Array, ) -> Result { - let mul = 10_f64.powi(left.scale() as i32); - let array = arith_decimal(left, right, |left, right| { - if right == 0 { - return Err(DataFusionError::ArrowError(ArrowError::DivideByZero)); - } - let l_value = left as f64; - let r_value = right as f64; - let result = ((l_value / r_value) * mul) as i128; - Ok(result) - })? - .with_precision_and_scale(left.precision(), left.scale())?; + let mul = 10_i128.pow(left.scale() as u32); + let array = multiply_scalar(left, mul)?; + let array = divide_opt(&array, right)? + .with_precision_and_scale(left.precision(), left.scale())?; Ok(array) } @@ -291,17 +247,11 @@ pub(crate) fn divide_decimal_scalar( left: &Decimal128Array, right: i128, ) -> Result { - if right == 0 { - return Err(DataFusionError::ArrowError(ArrowError::DivideByZero)); - } - let mul = 10_f64.powi(left.scale() as i32); - let array = arith_decimal_scalar(left, right, |left, right| { - let l_value = left as f64; - let r_value = right as f64; - let result = ((l_value / r_value) * mul) as i128; - Ok(result) - })? - .with_precision_and_scale(left.precision(), left.scale())?; + let mul = 10_i128.pow(left.scale() as u32); + let array = multiply_scalar(left, mul)?; + // `0` of right will be checked in `divide_scalar` + let array = divide_scalar(&array, right)? + .with_precision_and_scale(left.precision(), left.scale())?; Ok(array) } @@ -309,14 +259,8 @@ pub(crate) fn modulus_decimal( left: &Decimal128Array, right: &Decimal128Array, ) -> Result { - let array = arith_decimal(left, right, |left, right| { - if right == 0 { - Err(DataFusionError::ArrowError(ArrowError::DivideByZero)) - } else { - Ok(left % right) - } - })? - .with_precision_and_scale(left.precision(), left.scale())?; + let array = + modulus(left, right)?.with_precision_and_scale(left.precision(), left.scale())?; Ok(array) } @@ -324,10 +268,8 @@ pub(crate) fn modulus_decimal_scalar( left: &Decimal128Array, right: i128, ) -> Result { - if right == 0 { - return Err(DataFusionError::ArrowError(ArrowError::DivideByZero)); - } - let array = arith_decimal_scalar(left, right, |left, right| Ok(left % right))? + // `0` for right will be checked in `modulus_scalar` + let array = modulus_scalar(left, right)? .with_precision_and_scale(left.precision(), left.scale())?; Ok(array) } @@ -485,7 +427,6 @@ mod tests { 3, ); assert_eq!(expect, result); - // modulus let result = modulus_decimal(&left_decimal_array, &right_decimal_array)?; let expect = create_decimal_array(&[Some(7), None, Some(37), Some(16), None], 25, 3); @@ -503,9 +444,6 @@ mod tests { let left_decimal_array = create_decimal_array(&[Some(101)], 10, 1); let right_decimal_array = create_decimal_array(&[Some(0)], 1, 1); - let err = - divide_opt_decimal(&left_decimal_array, &right_decimal_array).unwrap_err(); - assert_eq!("Arrow error: Divide by zero error", err.to_string()); let err = divide_decimal_scalar(&left_decimal_array, 0).unwrap_err(); assert_eq!("Arrow error: Divide by zero error", err.to_string()); let err = modulus_decimal(&left_decimal_array, &right_decimal_array).unwrap_err(); @@ -558,7 +496,7 @@ mod tests { Some(false), Some(true), Some(false), - Some(true) + Some(true), ]), is_distinct_from(&left_int_array, &right_int_array)? ); @@ -570,7 +508,7 @@ mod tests { Some(true), Some(false), Some(true), - Some(false) + Some(false), ]), is_not_distinct_from(&left_int_array, &right_int_array)? );