diff --git a/datafusion/functions/src/math/abs.rs b/datafusion/functions/src/math/abs.rs
index 8049ef85ac36..35d0f3eccf57 100644
--- a/datafusion/functions/src/math/abs.rs
+++ b/datafusion/functions/src/math/abs.rs
@@ -39,6 +39,7 @@ use num_traits::sign::Signed;
 
 type MathArrayFunction = fn(&ArrayRef) -> Result<ArrayRef>;
 
+#[macro_export]
 macro_rules! make_abs_function {
     ($ARRAY_TYPE:ident) => {{
         |input: &ArrayRef| {
@@ -67,7 +68,8 @@ macro_rules! make_try_abs_function {
     }};
 }
 
-macro_rules! make_decimal_abs_function {
+#[macro_export]
+macro_rules! make_wrapping_abs_function {
     ($ARRAY_TYPE:ident) => {{
         |input: &ArrayRef| {
             let array = downcast_named_arg!(&input, "abs arg", $ARRAY_TYPE);
@@ -101,10 +103,10 @@ fn create_abs_function(input_data_type: &DataType) -> Result<MathArrayFunction>
         | DataType::UInt64 => Ok(|input: &ArrayRef| Ok(Arc::clone(input))),
 
         // Decimal types
-        DataType::Decimal32(_, _) => Ok(make_decimal_abs_function!(Decimal32Array)),
-        DataType::Decimal64(_, _) => Ok(make_decimal_abs_function!(Decimal64Array)),
-        DataType::Decimal128(_, _) => Ok(make_decimal_abs_function!(Decimal128Array)),
-        DataType::Decimal256(_, _) => Ok(make_decimal_abs_function!(Decimal256Array)),
+        DataType::Decimal32(_, _) => Ok(make_wrapping_abs_function!(Decimal32Array)),
+        DataType::Decimal64(_, _) => Ok(make_wrapping_abs_function!(Decimal64Array)),
+        DataType::Decimal128(_, _) => Ok(make_wrapping_abs_function!(Decimal128Array)),
+        DataType::Decimal256(_, _) => Ok(make_wrapping_abs_function!(Decimal256Array)),
 
         other => not_impl_err!("Unsupported data type {other:?} for function abs"),
     }
diff --git a/datafusion/spark/src/function/math/abs.rs b/datafusion/spark/src/function/math/abs.rs
new file mode 100644
index 000000000000..f48f8964c28c
--- /dev/null
+++ b/datafusion/spark/src/function/math/abs.rs
@@ -0,0 +1,378 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::*;
+use arrow::datatypes::DataType;
+use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue};
+use datafusion_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
+};
+use datafusion_functions::{
+    downcast_named_arg, make_abs_function, make_wrapping_abs_function,
+};
+use std::any::Any;
+use std::sync::Arc;
+
+/// Spark-compatible `abs` expression
+///
+/// Returns the absolute value of the input.
+/// Returns NULL if the input is NULL and NaN if the input is NaN.
+///
+/// TODOs:
+/// - When Spark's ANSI mode is off (i.e. `spark.sql.ansi.enabled=false`), taking the
+///   absolute value of the minimum value of a signed integer returns the value as is,
+///   whereas DataFusion's abs fails with "DataFusion error: Arrow error: Compute error"
+///   on arithmetic overflow.
+/// - Spark's abs also supports the ANSI interval types YearMonthIntervalType and
+///   DayTimeIntervalType; DataFusion's abs doesn't.
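+///
+/// For example, the legacy (non-ANSI) behavior corresponds to Rust's
+/// `wrapping_abs`, which returns the minimum value unchanged instead of
+/// overflowing (a minimal sketch of the semantics, not a Spark API):
+///
+/// ```
+/// assert_eq!(i32::MIN.wrapping_abs(), i32::MIN); // abs(-2147483648) = -2147483648
+/// assert_eq!((-2i64).wrapping_abs(), 2); // ordinary values are unaffected
+/// ```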
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SparkAbs {
+    signature: Signature,
+}
+
+impl Default for SparkAbs {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkAbs {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::numeric(1, Volatility::Immutable),
+        }
+    }
+}
+
+impl ScalarUDFImpl for SparkAbs {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "abs"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        Ok(arg_types[0].clone())
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        spark_abs(&args.args)
+    }
+}
+
+macro_rules! scalar_compute_op {
+    ($INPUT:ident, $SCALAR_TYPE:ident) => {{
+        let result = $INPUT.wrapping_abs();
+        Ok(ColumnarValue::Scalar(ScalarValue::$SCALAR_TYPE(Some(
+            result,
+        ))))
+    }};
+    ($INPUT:ident, $PRECISION:expr, $SCALE:expr, $SCALAR_TYPE:ident) => {{
+        let result = $INPUT.wrapping_abs();
+        Ok(ColumnarValue::Scalar(ScalarValue::$SCALAR_TYPE(
+            Some(result),
+            $PRECISION,
+            $SCALE,
+        )))
+    }};
+}
+
+pub fn spark_abs(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    if args.len() != 1 {
+        return internal_err!("abs takes exactly 1 argument, but got: {}", args.len());
+    }
+
+    match &args[0] {
+        ColumnarValue::Array(array) => match array.data_type() {
+            DataType::Null
+            | DataType::UInt8
+            | DataType::UInt16
+            | DataType::UInt32
+            | DataType::UInt64 => Ok(args[0].clone()),
+            DataType::Int8 => {
+                let abs_fun = make_wrapping_abs_function!(Int8Array);
+                abs_fun(array).map(ColumnarValue::Array)
+            }
+            DataType::Int16 => {
+                let abs_fun = make_wrapping_abs_function!(Int16Array);
+                abs_fun(array).map(ColumnarValue::Array)
+            }
+            DataType::Int32 => {
+                let abs_fun = make_wrapping_abs_function!(Int32Array);
+                abs_fun(array).map(ColumnarValue::Array)
+            }
+            DataType::Int64 => {
+                let abs_fun = make_wrapping_abs_function!(Int64Array);
+                abs_fun(array).map(ColumnarValue::Array)
+            }
+            DataType::Float32 => {
+                let abs_fun = make_abs_function!(Float32Array);
+                abs_fun(array).map(ColumnarValue::Array)
+            }
+            DataType::Float64 => {
+                let abs_fun = make_abs_function!(Float64Array);
+                abs_fun(array).map(ColumnarValue::Array)
+            }
+            DataType::Decimal128(_, _) => {
+                let abs_fun = make_wrapping_abs_function!(Decimal128Array);
+                abs_fun(array).map(ColumnarValue::Array)
+            }
+            DataType::Decimal256(_, _) => {
+                let abs_fun = make_wrapping_abs_function!(Decimal256Array);
+                abs_fun(array).map(ColumnarValue::Array)
+            }
+            dt => internal_err!("Unsupported data type for Spark abs: {dt}"),
+        },
+        ColumnarValue::Scalar(sv) => match sv {
+            ScalarValue::Null
+            | ScalarValue::UInt8(_)
+            | ScalarValue::UInt16(_)
+            | ScalarValue::UInt32(_)
+            | ScalarValue::UInt64(_) => Ok(args[0].clone()),
+            sv if sv.is_null() => Ok(args[0].clone()),
+            ScalarValue::Int8(Some(v)) => scalar_compute_op!(v, Int8),
+            ScalarValue::Int16(Some(v)) => scalar_compute_op!(v, Int16),
+            ScalarValue::Int32(Some(v)) => scalar_compute_op!(v, Int32),
+            ScalarValue::Int64(Some(v)) => scalar_compute_op!(v, Int64),
+            ScalarValue::Float32(Some(v)) => {
+                Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(v.abs()))))
+            }
+            ScalarValue::Float64(Some(v)) => {
+                Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(v.abs()))))
+            }
+            ScalarValue::Decimal128(Some(v), precision, scale) => {
+                scalar_compute_op!(v, *precision, *scale, Decimal128)
+            }
+            ScalarValue::Decimal256(Some(v), precision, scale) => {
+                scalar_compute_op!(v, *precision, *scale, Decimal256)
+            }
+            dt => internal_err!("Unsupported data type for Spark abs: {dt}"),
+        },
+    }
+}
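+
+// Illustrative usage of `spark_abs` (a sketch; assumes a fallible caller so
+// the `?` operator is available):
+//
+//     let arg = ColumnarValue::Scalar(ScalarValue::Int32(Some(i32::MIN)));
+//     let out = spark_abs(&[arg])?;
+//     // Legacy (non-ANSI) semantics: the minimum value wraps to itself, so
+//     // `out` is `ColumnarValue::Scalar(ScalarValue::Int32(Some(i32::MIN)))`.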
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::datatypes::i256;
+
+    macro_rules! eval_legacy_mode {
+        ($TYPE:ident, $VAL:expr) => {{
+            let args = ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL)));
+            match spark_abs(&[args]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::$TYPE(Some(result)))) => {
+                    assert_eq!(result, $VAL);
+                }
+                _ => unreachable!(),
+            }
+        }};
+        ($TYPE:ident, $VAL:expr, $RESULT:expr) => {{
+            let args = ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL)));
+            match spark_abs(&[args]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::$TYPE(Some(result)))) => {
+                    assert_eq!(result, $RESULT);
+                }
+                _ => unreachable!(),
+            }
+        }};
+        ($TYPE:ident, $VAL:expr, $PRECISION:expr, $SCALE:expr) => {{
+            let args =
+                ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL), $PRECISION, $SCALE));
+            match spark_abs(&[args]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::$TYPE(
+                    Some(result),
+                    precision,
+                    scale,
+                ))) => {
+                    assert_eq!(result, $VAL);
+                    assert_eq!(precision, $PRECISION);
+                    assert_eq!(scale, $SCALE);
+                }
+                _ => unreachable!(),
+            }
+        }};
+        ($TYPE:ident, $VAL:expr, $PRECISION:expr, $SCALE:expr, $RESULT:expr) => {{
+            let args =
+                ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL), $PRECISION, $SCALE));
+            match spark_abs(&[args]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::$TYPE(
+                    Some(result),
+                    precision,
+                    scale,
+                ))) => {
+                    assert_eq!(result, $RESULT);
+                    assert_eq!(precision, $PRECISION);
+                    assert_eq!(scale, $SCALE);
+                }
+                _ => unreachable!(),
+            }
+        }};
+    }
+
+    #[test]
+    fn test_abs_scalar_legacy_mode() {
+        // NumericType MIN
+        eval_legacy_mode!(UInt8, u8::MIN);
+        eval_legacy_mode!(UInt16, u16::MIN);
+        eval_legacy_mode!(UInt32, u32::MIN);
+        eval_legacy_mode!(UInt64, u64::MIN);
+        eval_legacy_mode!(Int8, i8::MIN);
+        eval_legacy_mode!(Int16, i16::MIN);
+        eval_legacy_mode!(Int32, i32::MIN);
+        eval_legacy_mode!(Int64, i64::MIN);
+        eval_legacy_mode!(Float32, f32::MIN, f32::MAX);
+        eval_legacy_mode!(Float64, f64::MIN, f64::MAX);
+        eval_legacy_mode!(Decimal128, i128::MIN, 18, 10);
+        eval_legacy_mode!(Decimal256, i256::MIN, 10, 2);
+
+        // NumericType not MIN
+        eval_legacy_mode!(Int8, -1i8, 1i8);
+        eval_legacy_mode!(Int16, -1i16, 1i16);
+        eval_legacy_mode!(Int32, -1i32, 1i32);
+        eval_legacy_mode!(Int64, -1i64, 1i64);
+        eval_legacy_mode!(Decimal128, -1i128, 18, 10, 1i128);
+        eval_legacy_mode!(Decimal256, i256::from(-1i8), 10, 2, i256::from(1i8));
+
+        // Float32, Float64
+        eval_legacy_mode!(Float32, f32::NEG_INFINITY, f32::INFINITY);
+        eval_legacy_mode!(Float32, f32::INFINITY, f32::INFINITY);
+        eval_legacy_mode!(Float32, 0.0f32, 0.0f32);
+        eval_legacy_mode!(Float32, -0.0f32, 0.0f32);
+        eval_legacy_mode!(Float64, f64::NEG_INFINITY, f64::INFINITY);
+        eval_legacy_mode!(Float64, f64::INFINITY, f64::INFINITY);
+        eval_legacy_mode!(Float64, 0.0f64, 0.0f64);
+        eval_legacy_mode!(Float64, -0.0f64, 0.0f64);
+    }
+
+    macro_rules! eval_array_legacy_mode {
+        ($INPUT:expr, $OUTPUT:expr, $FUNC:ident) => {{
+            let input = $INPUT;
+            let args = ColumnarValue::Array(Arc::new(input));
+            let expected = $OUTPUT;
+            match spark_abs(&[args]) {
+                Ok(ColumnarValue::Array(result)) => {
+                    let actual = datafusion_common::cast::$FUNC(&result).unwrap();
+                    assert_eq!(actual, &expected);
+                }
+                _ => unreachable!(),
+            }
+        }};
+    }
+
+    #[test]
+    fn test_abs_array_legacy_mode() {
+        eval_array_legacy_mode!(
+            Int8Array::from(vec![Some(-1), Some(i8::MIN), Some(i8::MAX), None]),
+            Int8Array::from(vec![Some(1), Some(i8::MIN), Some(i8::MAX), None]),
+            as_int8_array
+        );
+
+        eval_array_legacy_mode!(
+            Int16Array::from(vec![Some(-1), Some(i16::MIN), Some(i16::MAX), None]),
+            Int16Array::from(vec![Some(1), Some(i16::MIN), Some(i16::MAX), None]),
+            as_int16_array
+        );
+
+        eval_array_legacy_mode!(
+            Int32Array::from(vec![Some(-1), Some(i32::MIN), Some(i32::MAX), None]),
+            Int32Array::from(vec![Some(1), Some(i32::MIN), Some(i32::MAX), None]),
+            as_int32_array
+        );
+
+        eval_array_legacy_mode!(
+            Int64Array::from(vec![Some(-1), Some(i64::MIN), Some(i64::MAX), None]),
+            Int64Array::from(vec![Some(1), Some(i64::MIN), Some(i64::MAX), None]),
+            as_int64_array
+        );
+
+        eval_array_legacy_mode!(
+            Float32Array::from(vec![
+                Some(-1f32),
+                Some(f32::MIN),
+                Some(f32::MAX),
+                None,
+                Some(f32::NAN),
+                Some(f32::INFINITY),
+                Some(f32::NEG_INFINITY),
+                Some(0.0),
+                Some(-0.0),
+            ]),
+            Float32Array::from(vec![
+                Some(1f32),
+                Some(f32::MAX),
+                Some(f32::MAX),
+                None,
+                Some(f32::NAN),
+                Some(f32::INFINITY),
+                Some(f32::INFINITY),
+                Some(0.0),
+                Some(0.0),
+            ]),
+            as_float32_array
+        );
+
+        eval_array_legacy_mode!(
+            Float64Array::from(vec![
+                Some(-1f64),
+                Some(f64::MIN),
+                Some(f64::MAX),
+                None,
+                Some(f64::NAN),
+                Some(f64::INFINITY),
+                Some(f64::NEG_INFINITY),
+                Some(0.0),
+                Some(-0.0),
+            ]),
+            Float64Array::from(vec![
+                Some(1f64),
+                Some(f64::MAX),
+                Some(f64::MAX),
+                None,
+                Some(f64::NAN),
+                Some(f64::INFINITY),
+                Some(f64::INFINITY),
+                Some(0.0),
+                Some(0.0),
+            ]),
+            as_float64_array
+        );
+
+        eval_array_legacy_mode!(
+            Decimal128Array::from(vec![Some(i128::MIN), None])
+                .with_precision_and_scale(38, 37)
+                .unwrap(),
+            Decimal128Array::from(vec![Some(i128::MIN), None])
+                .with_precision_and_scale(38, 37)
+                .unwrap(),
+            as_decimal128_array
+        );
+
+        eval_array_legacy_mode!(
+            Decimal256Array::from(vec![Some(i256::MIN), None])
+                .with_precision_and_scale(5, 2)
+                .unwrap(),
+            Decimal256Array::from(vec![Some(i256::MIN), None])
+                .with_precision_and_scale(5, 2)
+                .unwrap(),
+            as_decimal256_array
+        );
+    }
+}
diff --git a/datafusion/spark/src/function/math/mod.rs b/datafusion/spark/src/function/math/mod.rs
index fe8c2b1da0df..74fa4cf37ca5 100644
--- a/datafusion/spark/src/function/math/mod.rs
+++ b/datafusion/spark/src/function/math/mod.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+pub mod abs;
 pub mod expm1;
 pub mod factorial;
 pub mod hex;
@@ -27,6 +28,7 @@ use datafusion_expr::ScalarUDF;
 use datafusion_functions::make_udf_function;
 use std::sync::Arc;
 
+make_udf_function!(abs::SparkAbs, abs);
 make_udf_function!(expm1::SparkExpm1, expm1);
 make_udf_function!(factorial::SparkFactorial, factorial);
 make_udf_function!(hex::SparkHex, hex);
@@ -40,6 +42,7 @@ make_udf_function!(trigonometry::SparkSec, sec);
 pub mod expr_fn {
     use datafusion_functions::export_functions;
 
+    export_functions!((abs, "Returns the absolute value of expr.", arg1));
     export_functions!((expm1, "Returns exp(expr) - 1 as a Float64.", arg1));
     export_functions!((
         factorial,
@@ -57,6 +60,7 @@ pub mod expr_fn {
 
 pub fn functions() -> Vec<Arc<ScalarUDF>> {
     vec![
+        abs(),
         expm1(),
         factorial(),
         hex(),
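With the module wired into `functions()` above, the Spark `abs` can be
registered on a `SessionContext` and invoked through `expr_fn::abs`. A minimal
sketch, assuming the `datafusion` and `datafusion-spark` crates plus tokio
(exact module paths may differ by version):

    use datafusion::error::Result;
    use datafusion::prelude::*;
    // The explicit import takes precedence over the glob `prelude` import.
    use datafusion_spark::function::math::expr_fn::abs;

    #[tokio::main]
    async fn main() -> Result<()> {
        let ctx = SessionContext::new();
        // Register every Spark math UDF exported by `functions()`, including `abs`.
        for udf in datafusion_spark::function::math::functions() {
            ctx.register_udf(udf.as_ref().clone());
        }
        // Build a one-row projection evaluating the Spark-compatible abs.
        let df = ctx.read_empty()?.select(vec![abs(lit(-1))])?;
        df.show().await?;
        Ok(())
    }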
diff --git a/datafusion/sqllogictest/test_files/spark/math/abs.slt b/datafusion/sqllogictest/test_files/spark/math/abs.slt
index 4b9edf7e29f2..19ca902ea3de 100644
--- a/datafusion/sqllogictest/test_files/spark/math/abs.slt
+++ b/datafusion/sqllogictest/test_files/spark/math/abs.slt
@@ -23,10 +23,75 @@
 
 ## Original Query: SELECT abs(-1);
 ## PySpark 3.5.5 Result: {'abs(-1)': 1, 'typeof(abs(-1))': 'int', 'typeof(-1)': 'int'}
-#query
-#SELECT abs(-1::int);
+
+# abs: signed int and NULL
+query IIIIR
+SELECT abs(-127::TINYINT), abs(-32767::SMALLINT), abs(-2147483647::INT), abs(-9223372036854775807::BIGINT), abs(NULL);
+----
+127 32767 2147483647 9223372036854775807 NULL
+
+# See https://github.com/apache/datafusion/issues/18794 for operator precedence
+# abs: signed int minimum values
+query IIII
+SELECT abs((-128)::TINYINT), abs((-32768)::SMALLINT), abs((-2147483648)::INT), abs((-9223372036854775808)::BIGINT)
+----
+-128 -32768 -2147483648 -9223372036854775808
+
+# abs: floats, NULL, NaN, -0, infinity, -infinity
+query RRRRRRRRRRRR
+SELECT abs(-1.0::FLOAT), abs(0.::FLOAT), abs(-0.::FLOAT), abs(-0::FLOAT), abs(NULL::FLOAT), abs('NaN'::FLOAT), abs('inf'::FLOAT), abs('+inf'::FLOAT), abs('-inf'::FLOAT), abs('infinity'::FLOAT), abs('+infinity'::FLOAT), abs('-infinity'::FLOAT)
+----
+1 0 0 0 NULL NaN Infinity Infinity Infinity Infinity Infinity Infinity
+
+# abs: doubles, NULL, NaN, -0, infinity, -infinity
+query RRRRRRRRRRRR
+SELECT abs(-1.0::DOUBLE), abs(0.::DOUBLE), abs(-0.::DOUBLE), abs(-0::DOUBLE), abs(NULL::DOUBLE), abs('NaN'::DOUBLE), abs('inf'::DOUBLE), abs('+inf'::DOUBLE), abs('-inf'::DOUBLE), abs('infinity'::DOUBLE), abs('+infinity'::DOUBLE), abs('-infinity'::DOUBLE)
+----
+1 0 0 0 NULL NaN Infinity Infinity Infinity Infinity Infinity Infinity
+
+# abs: decimal128 and decimal256
+statement ok
+CREATE TABLE test_nullable_decimal(
+    c1 DECIMAL(10, 2),  /* Decimal128 */
+    c2 DECIMAL(38, 10), /* Decimal128 with max precision */
+    c3 DECIMAL(40, 2),  /* Decimal256 */
+    c4 DECIMAL(76, 10)  /* Decimal256 with max precision */
+  ) AS VALUES
+    (0, 0, 0, 0),
+    (NULL, NULL, NULL, NULL);
+
+query I
+INSERT INTO test_nullable_decimal VALUES
+  (
+    -99999999.99,
+    '-9999999999999999999999999999.9999999999',
+    '-99999999999999999999999999999999999999.99',
+    '-999999999999999999999999999999999999999999999999999999999999999999.9999999999'
+  ),
+  (
+    99999999.99,
+    '9999999999999999999999999999.9999999999',
+    '99999999999999999999999999999999999999.99',
+    '999999999999999999999999999999999999999999999999999999999999999999.9999999999'
+  )
+----
+2
+
+query RRRR rowsort
+SELECT abs(c1), abs(c2), abs(c3), abs(c4) FROM test_nullable_decimal
+----
+0 0 0 0
+99999999.99 9999999999999999999999999999.9999999999 99999999999999999999999999999999999999.99 999999999999999999999999999999999999999999999999999999999999999999.9999999999
+99999999.99 9999999999999999999999999999.9999999999 99999999999999999999999999999999999999.99 999999999999999999999999999999999999999999999999999999999999999999.9999999999
+NULL NULL NULL NULL
+
+statement ok
+DROP TABLE test_nullable_decimal
 
 ## Original Query: SELECT abs(INTERVAL -'1-1' YEAR TO MONTH);
 ## PySpark 3.5.5 Result: {"abs(INTERVAL '-1-1' YEAR TO MONTH)": 13, "typeof(abs(INTERVAL '-1-1' YEAR TO MONTH))": 'interval year to month', "typeof(INTERVAL '-1-1' YEAR TO MONTH)": 'interval year to month'}
 #query
 #SELECT abs(INTERVAL '-1-1' YEAR TO MONTH::interval year to month);
+# See GitHub issue for ANSI interval support: https://github.com/apache/datafusion/issues/18793