From b4eea32a0228897104d9fa6de5890f9db0de1a93 Mon Sep 17 00:00:00 2001 From: huanghsiang_cheng Date: Mon, 20 Oct 2025 16:14:28 -0700 Subject: [PATCH 01/25] feat: support Spark abs math function --- datafusion/spark/src/function/math/abs.rs | 1141 +++++++++++++++++++++ datafusion/spark/src/function/math/mod.rs | 4 + 2 files changed, 1145 insertions(+) create mode 100644 datafusion/spark/src/function/math/abs.rs diff --git a/datafusion/spark/src/function/math/abs.rs b/datafusion/spark/src/function/math/abs.rs new file mode 100644 index 000000000000..1c7118663f86 --- /dev/null +++ b/datafusion/spark/src/function/math/abs.rs @@ -0,0 +1,1141 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::function::error_utils::{ + invalid_arg_count_exec_err, unsupported_data_type_exec_err, +}; +use arrow::array::*; +use arrow::datatypes::DataType; +use arrow::datatypes::*; +use datafusion_common::DataFusionError::ArrowError; +use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use std::any::Any; +use std::sync::Arc; + +/// Spark-compatible `abs` expression +/// +/// +/// Returns the absolute value of input +/// Returns NULL if input is NULL, returns NaN if input is NaN. +/// +/// Differences with DataFusion abs: +/// - Spark's ANSI-compliant dialect, when off (i.e. `spark.sql.ansi.enabled=false`), taking absolute values on minimal values of signed integers returns the value as is. DataFusion's abs throws "DataFusion error: Arrow error: Compute error" on arithmetic overflow +/// - Spark's abs also supports ANSI interval types: YearMonthIntervalType and DayTimeIntervalType. DataFusion's abs doesn't. 
+/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkAbs { + signature: Signature, +} + +impl Default for SparkAbs { + fn default() -> Self { + Self::new() + } +} + +impl SparkAbs { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkAbs { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "abs" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + spark_abs(&args.args) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() > 2 { + return Err(invalid_arg_count_exec_err("abs", (1, 2), arg_types.len())); + } + match &arg_types[0] { + DataType::Null + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) + | DataType::Interval(IntervalUnit::YearMonth) + | DataType::Interval(IntervalUnit::DayTime) => Ok(vec![arg_types[0].clone()]), + other => { + if other.is_numeric() { + Ok(vec![DataType::Float64]) + } else { + Err(unsupported_data_type_exec_err( + "abs", + "Numeric Type or Interval Type", + &arg_types[0], + )) + } + } + } + } +} + +macro_rules! legacy_compute_op { + ($ARRAY:expr, $FUNC:ident, $TYPE:ident, $RESULT:ident) => {{ + let n = $ARRAY.as_any().downcast_ref::<$TYPE>(); + match n { + Some(array) => { + let res: $RESULT = + arrow::compute::kernels::arity::unary(array, |x| x.$FUNC()); + Ok(res) + } + _ => Err(DataFusionError::Internal(format!( + "Invalid data type for abs" + ))), + } + }}; +} + +macro_rules! 
ansi_compute_op { + ($ARRAY:expr, $FUNC:ident, $TYPE:ident, $RESULT:ident, $NATIVE:ident, $FROM_TYPE:expr) => {{ + let n = $ARRAY.as_any().downcast_ref::<$TYPE>(); + match n { + Some(array) => { + match arrow::compute::kernels::arity::try_unary(array, |x| { + if x == $NATIVE::MIN { + Err(arrow::error::ArrowError::ArithmeticOverflow( + $FROM_TYPE.to_string(), + )) + } else { + Ok(x.$FUNC()) + } + }) { + Ok(res) => Ok(ColumnarValue::Array( + Arc::>::new(res), + )), + Err(_) => Err(arithmetic_overflow_error($FROM_TYPE).into()), + } + } + _ => Err(DataFusionError::Internal("Invalid data type".to_string())), + } + }}; +} + +fn arithmetic_overflow_error(from_type: &str) -> DataFusionError { + ArrowError( + Box::from(arrow::error::ArrowError::ComputeError(format!( + "arithmetic overflow from {}", + from_type + ))), + None, + ) +} + +pub fn spark_abs(args: &[ColumnarValue]) -> Result { + if args.len() > 2 { + return internal_err!("abs takes at most 2 arguments, but got: {}", args.len()); + } + + let fail_on_error = if args.len() == 2 { + match &args[1] { + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))) => { + *fail_on_error + } + _ => { + return internal_err!( + "The second argument must be boolean scalar, but got: {:?}", + args[1] + ); + } + } + } else { + false + }; + + match &args[0] { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Null + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => Ok(args[0].clone()), + DataType::Int8 => { + if !fail_on_error { + let result = + legacy_compute_op!(array, wrapping_abs, Int8Array, Int8Array); + Ok(ColumnarValue::Array(Arc::new(result?))) + } else { + ansi_compute_op!(array, abs, Int8Array, Int8Type, i8, "Int8") + } + } + DataType::Int16 => { + if !fail_on_error { + let result = + legacy_compute_op!(array, wrapping_abs, Int16Array, Int16Array); + Ok(ColumnarValue::Array(Arc::new(result?))) + } else { + ansi_compute_op!(array, abs, Int16Array, Int16Type, i16, 
"Int16") + } + } + DataType::Int32 => { + if !fail_on_error { + let result = + legacy_compute_op!(array, wrapping_abs, Int32Array, Int32Array); + Ok(ColumnarValue::Array(Arc::new(result?))) + } else { + ansi_compute_op!(array, abs, Int32Array, Int32Type, i32, "Int32") + } + } + DataType::Int64 => { + if !fail_on_error { + let result = + legacy_compute_op!(array, wrapping_abs, Int64Array, Int64Array); + Ok(ColumnarValue::Array(Arc::new(result?))) + } else { + ansi_compute_op!(array, abs, Int64Array, Int64Type, i64, "Int64") + } + } + DataType::Float32 => { + let result = legacy_compute_op!(array, abs, Float32Array, Float32Array); + Ok(ColumnarValue::Array(Arc::new(result?))) + } + DataType::Float64 => { + let result = legacy_compute_op!(array, abs, Float64Array, Float64Array); + Ok(ColumnarValue::Array(Arc::new(result?))) + } + DataType::Decimal128(precision, scale) => { + if !fail_on_error { + let result = legacy_compute_op!( + array, + wrapping_abs, + Decimal128Array, + Decimal128Array + )?; + let result = + result.with_data_type(DataType::Decimal128(*precision, *scale)); + Ok(ColumnarValue::Array(Arc::new(result))) + } else { + // Need to pass precision and scale from input, so not using ansi_compute_op + let input = array.as_any().downcast_ref::(); + match input { + Some(i) => { + match arrow::compute::kernels::arity::try_unary(i, |x| { + if x == i128::MIN { + Err(arrow::error::ArrowError::ArithmeticOverflow( + "Decimal128".to_string(), + )) + } else { + Ok(x.abs()) + } + }) { + Ok(res) => Ok(ColumnarValue::Array(Arc::< + PrimitiveArray, + >::new( + res.with_data_type(DataType::Decimal128( + *precision, *scale, + )), + ))), + Err(_) => { + Err(arithmetic_overflow_error("Decimal128").into()) + } + } + } + _ => Err(DataFusionError::Internal( + "Invalid data type".to_string(), + )), + } + } + } + DataType::Decimal256(precision, scale) => { + if !fail_on_error { + let result = legacy_compute_op!( + array, + wrapping_abs, + Decimal256Array, + Decimal256Array + )?; + 
let result = + result.with_data_type(DataType::Decimal256(*precision, *scale)); + Ok(ColumnarValue::Array(Arc::new(result))) + } else { + // Need to pass precision and scale from input, so not using ansi_compute_op + let input = array.as_any().downcast_ref::(); + match input { + Some(i) => { + match arrow::compute::kernels::arity::try_unary(i, |x| { + if x == i256::MIN { + Err(arrow::error::ArrowError::ArithmeticOverflow( + "Decimal256".to_string(), + )) + } else { + Ok(x.wrapping_abs()) // i256 doesn't define abs() method + } + }) { + Ok(res) => Ok(ColumnarValue::Array(Arc::< + PrimitiveArray, + >::new( + res.with_data_type(DataType::Decimal256( + *precision, *scale, + )), + ))), + Err(_) => { + Err(arithmetic_overflow_error("Decimal256").into()) + } + } + } + _ => Err(DataFusionError::Internal( + "Invalid data type".to_string(), + )), + } + } + } + DataType::Interval(unit) => match unit { + IntervalUnit::YearMonth => { + let result = legacy_compute_op!( + array, + wrapping_abs, + IntervalYearMonthArray, + IntervalYearMonthArray + )?; + let result = result.with_data_type(DataType::Interval(*unit)); + Ok(ColumnarValue::Array(Arc::new(result))) + } + IntervalUnit::DayTime => { + let result = legacy_compute_op!( + array, + wrapping_abs, + IntervalDayTimeArray, + IntervalDayTimeArray + )?; + let result = result.with_data_type(DataType::Interval(*unit)); + Ok(ColumnarValue::Array(Arc::new(result))) + } + IntervalUnit::MonthDayNano => internal_err!( + "MonthDayNano is not a supported Interval unit for Spark ABS" + ), + }, + dt => internal_err!("Not supported datatype for Spark ABS: {dt}"), + }, + ColumnarValue::Scalar(sv) => { + match sv { + ScalarValue::Null + | ScalarValue::UInt8(_) + | ScalarValue::UInt16(_) + | ScalarValue::UInt32(_) + | ScalarValue::UInt64(_) => Ok(args[0].clone()), + ScalarValue::Int8(a) => a + .map(|v| match v.checked_abs() { + Some(abs_val) => { + Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(abs_val)))) + } + None => { + if !fail_on_error { + 
// return the original value + Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(v)))) + } else { + Err(arithmetic_overflow_error("Int8").into()) + } + } + }) + .unwrap(), + ScalarValue::Int16(a) => a + .map(|v| match v.checked_abs() { + Some(abs_val) => { + Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(abs_val)))) + } + None => { + if !fail_on_error { + // return the original value + Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(v)))) + } else { + Err(arithmetic_overflow_error("Int16").into()) + } + } + }) + .unwrap(), + ScalarValue::Int32(a) => a + .map(|v| match v.checked_abs() { + Some(abs_val) => { + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(abs_val)))) + } + None => { + if !fail_on_error { + // return the original value + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(v)))) + } else { + Err(arithmetic_overflow_error("Int32").into()) + } + } + }) + .unwrap(), + ScalarValue::Int64(a) => a + .map(|v| match v.checked_abs() { + Some(abs_val) => { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(abs_val)))) + } + None => { + if !fail_on_error { + // return the original value + Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(v)))) + } else { + Err(arithmetic_overflow_error("Int64").into()) + } + } + }) + .unwrap(), + ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar( + ScalarValue::Float32(a.map(|x| x.abs())), + )), + ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar( + ScalarValue::Float64(a.map(|x| x.abs())), + )), + ScalarValue::Decimal128(a, precision, scale) => a + .map(|v| match v.checked_abs() { + Some(abs_val) => Ok(ColumnarValue::Scalar( + ScalarValue::Decimal128(Some(abs_val), *precision, *scale), + )), + None => { + if !fail_on_error { + // return the original value + Ok(ColumnarValue::Scalar(ScalarValue::Decimal128( + Some(v), + *precision, + *scale, + ))) + } else { + Err(arithmetic_overflow_error("Decimal128").into()) + } + } + }) + .unwrap(), + ScalarValue::Decimal256(a, precision, scale) => a + .map(|v| match v.checked_abs() { + 
Some(abs_val) => Ok(ColumnarValue::Scalar( + ScalarValue::Decimal256(Some(abs_val), *precision, *scale), + )), + None => { + if !fail_on_error { + // return the original value + Ok(ColumnarValue::Scalar(ScalarValue::Decimal256( + Some(v), + *precision, + *scale, + ))) + } else { + Err(arithmetic_overflow_error("Decimal256").into()) + } + } + }) + .unwrap(), + ScalarValue::IntervalYearMonth(a) => { + let result = a.map(|v| v.wrapping_abs()); + Ok(ColumnarValue::Scalar(ScalarValue::IntervalYearMonth( + result, + ))) + } + ScalarValue::IntervalDayTime(a) => { + let result = a.map(|v| v.wrapping_abs()); + Ok(ColumnarValue::Scalar(ScalarValue::IntervalDayTime(result))) + } + + dt => internal_err!("Not supported datatype for Spark ABS: {dt}"), + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion_common::cast::{ + as_decimal128_array, as_decimal256_array, as_float32_array, as_float64_array, + as_int16_array, as_int32_array, as_int64_array, as_int8_array, + as_interval_dt_array, as_interval_ym_array, as_uint64_array, + }; + + fn with_fail_on_error Result<()>>(test_fn: F) { + for fail_on_error in [true, false] { + let _ = test_fn(fail_on_error); + } + } + + // Unsigned types, return as is + #[test] + fn test_abs_u8_scalar() { + with_fail_on_error(|fail_on_error| { + let args = ColumnarValue::Scalar(ScalarValue::UInt8(Some(u8::MAX))); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + match spark_abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Scalar(ScalarValue::UInt8(Some(result)))) => { + assert_eq!(result, u8::MAX); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("ARITHMETIC_OVERFLOW"), + "Error message did not match. 
Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_i8_scalar() { + with_fail_on_error(|fail_on_error| { + let args = ColumnarValue::Scalar(ScalarValue::Int8(Some(i8::MIN))); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + match spark_abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(result)))) => { + assert_eq!(result, i8::MIN); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("arithmetic overflow"), + "Error message did not match. Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_i16_scalar() { + with_fail_on_error(|fail_on_error| { + let args = ColumnarValue::Scalar(ScalarValue::Int16(Some(i16::MIN))); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + match spark_abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(result)))) => { + assert_eq!(result, i16::MIN); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("arithmetic overflow"), + "Error message did not match. 
Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_i32_scalar() { + with_fail_on_error(|fail_on_error| { + let args = ColumnarValue::Scalar(ScalarValue::Int32(Some(i32::MIN))); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + match spark_abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(result)))) => { + assert_eq!(result, i32::MIN); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("arithmetic overflow"), + "Error message did not match. Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_i64_scalar() { + with_fail_on_error(|fail_on_error| { + let args = ColumnarValue::Scalar(ScalarValue::Int64(Some(i64::MIN))); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + match spark_abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(result)))) => { + assert_eq!(result, i64::MIN); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("arithmetic overflow"), + "Error message did not match. 
Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_decimal128_scalar() { + with_fail_on_error(|fail_on_error| { + let args = + ColumnarValue::Scalar(ScalarValue::Decimal128(Some(i128::MIN), 18, 10)); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + match spark_abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Scalar(ScalarValue::Decimal128( + Some(result), + precision, + scale, + ))) => { + assert_eq!(result, i128::MIN); + assert_eq!(precision, 18); + assert_eq!(scale, 10); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("arithmetic overflow"), + "Error message did not match. Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_decimal256_scalar() { + with_fail_on_error(|fail_on_error| { + let args = + ColumnarValue::Scalar(ScalarValue::Decimal256(Some(i256::MIN), 10, 2)); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + match spark_abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Scalar(ScalarValue::Decimal256( + Some(result), + precision, + scale, + ))) => { + assert_eq!(result, i256::MIN); + assert_eq!(precision, 10); + assert_eq!(scale, 2); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("arithmetic overflow"), + "Error message did not match. 
Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_interval_year_month_scalar() { + with_fail_on_error(|fail_on_error| { + let args = + ColumnarValue::Scalar(ScalarValue::IntervalYearMonth(Some(i32::MIN))); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + match spark_abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Scalar(ScalarValue::IntervalYearMonth(Some( + result, + )))) => { + assert_eq!(result, i32::MIN); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("arithmetic overflow"), + "Error message did not match. Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_interval_day_time_scalar() { + with_fail_on_error(|fail_on_error| { + let args = ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( + IntervalDayTime::MIN, + ))); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + match spark_abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(result)))) => { + assert_eq!(result, IntervalDayTime::MIN); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("arithmetic overflow"), + "Error message did not match. 
Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_i8_array() { + with_fail_on_error(|fail_on_error| { + let input = + Int8Array::from(vec![Some(-1), Some(i8::MIN), Some(i8::MAX), None]); + let args = ColumnarValue::Array(Arc::new(input)); + let expected = + Int8Array::from(vec![Some(1), Some(i8::MIN), Some(i8::MAX), None]); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + + match spark_abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Array(result)) => { + let actual = as_int8_array(&result)?; + assert_eq!(actual, &expected); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("arithmetic overflow"), + "Error message did not match. Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_i16_array() { + with_fail_on_error(|fail_on_error| { + let input = + Int16Array::from(vec![Some(-1), Some(i16::MIN), Some(i16::MAX), None]); + let args = ColumnarValue::Array(Arc::new(input)); + let expected = + Int16Array::from(vec![Some(1), Some(i16::MIN), Some(i16::MAX), None]); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + + match spark_abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Array(result)) => { + let actual = as_int16_array(&result)?; + assert_eq!(actual, &expected); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("arithmetic overflow"), + "Error message did not match. 
Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_i32_array() { + with_fail_on_error(|fail_on_error| { + let input = + Int32Array::from(vec![Some(-1), Some(i32::MIN), Some(i32::MAX), None]); + let args = ColumnarValue::Array(Arc::new(input)); + let expected = + Int32Array::from(vec![Some(1), Some(i32::MIN), Some(i32::MAX), None]); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + + match spark_abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Array(result)) => { + let actual = as_int32_array(&result)?; + assert_eq!(actual, &expected); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("arithmetic overflow"), + "Error message did not match. Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_i64_array() { + with_fail_on_error(|fail_on_error| { + let input = + Int64Array::from(vec![Some(-1), Some(i64::MIN), Some(i64::MAX), None]); + let args = ColumnarValue::Array(Arc::new(input)); + let expected = + Int64Array::from(vec![Some(1), Some(i64::MIN), Some(i64::MAX), None]); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + + match spark_abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Array(result)) => { + let actual = as_int64_array(&result)?; + assert_eq!(actual, &expected); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("arithmetic overflow"), + "Error message did not match. 
Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_f32_array() { + with_fail_on_error(|fail_on_error| { + let input = Float32Array::from(vec![ + Some(-1f32), + Some(f32::MIN), + Some(f32::MAX), + None, + Some(f32::NAN), + ]); + let args = ColumnarValue::Array(Arc::new(input)); + let expected = Float32Array::from(vec![ + Some(1f32), + Some(f32::MAX), + Some(f32::MAX), + None, + Some(f32::NAN), + ]); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + + match spark_abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Array(result)) => { + let actual = as_float32_array(&result)?; + assert_eq!(actual, &expected); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("arithmetic overflow"), + "Error message did not match. Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_f64_array() { + with_fail_on_error(|fail_on_error| { + let input = Float64Array::from(vec![ + Some(-1f64), + Some(f64::MIN), + Some(f64::MAX), + None, + Some(f64::NAN), + ]); + let args = ColumnarValue::Array(Arc::new(input)); + let expected = Float64Array::from(vec![ + Some(1f64), + Some(f64::MAX), + Some(f64::MAX), + None, + Some(f64::NAN), + ]); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + + match spark_abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Array(result)) => { + let actual = as_float64_array(&result)?; + assert_eq!(actual, &expected); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("arithmetic overflow"), + "Error message did not match. 
Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_decimal128_array() { + with_fail_on_error(|fail_on_error| { + let input = Decimal128Array::from(vec![Some(i128::MIN), None]) + .with_precision_and_scale(38, 37)?; + let args = ColumnarValue::Array(Arc::new(input)); + let expected = Decimal128Array::from(vec![Some(i128::MIN), None]) + .with_precision_and_scale(38, 37)?; + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + + match spark_abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Array(result)) => { + let actual = as_decimal128_array(&result)?; + assert_eq!(actual, &expected); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("arithmetic overflow"), + "Error message did not match. Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_decimal256_array() { + with_fail_on_error(|fail_on_error| { + let input = Decimal256Array::from(vec![Some(i256::MIN), None]) + .with_precision_and_scale(5, 2)?; + let args = ColumnarValue::Array(Arc::new(input)); + let expected = Decimal256Array::from(vec![Some(i256::MIN), None]) + .with_precision_and_scale(5, 2)?; + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + + match spark_abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Array(result)) => { + let actual = as_decimal256_array(&result)?; + assert_eq!(actual, &expected); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("arithmetic overflow"), + "Error message did not match. 
Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_interval_year_month_array() { + with_fail_on_error(|fail_on_error| { + let input = IntervalYearMonthArray::from(vec![i32::MIN, -1]); + let args = ColumnarValue::Array(Arc::new(input)); + let expected = IntervalYearMonthArray::from(vec![i32::MIN, 1]); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + + match spark_abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Array(result)) => { + let actual = as_interval_ym_array(&result)?; + assert_eq!(actual, &expected); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("arithmetic overflow"), + "Error message did not match. Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_interval_day_time_array() { + with_fail_on_error(|fail_on_error| { + let input = IntervalDayTimeArray::from(vec![IntervalDayTime::new( + i32::MIN, + i32::MIN, + )]); + let args = ColumnarValue::Array(Arc::new(input)); + let expected = IntervalDayTimeArray::from(vec![IntervalDayTime::new( + i32::MIN, + i32::MIN, + )]); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + + match spark_abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Array(result)) => { + let actual = as_interval_dt_array(&result)?; + assert_eq!(actual, &expected); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("arithmetic overflow"), + "Error message did not match. 
Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } + + #[test] + fn test_abs_u64_array() { + with_fail_on_error(|fail_on_error| { + let input = UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]); + let args = ColumnarValue::Array(Arc::new(input)); + let expected = UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]); + let fail_on_error_arg = + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + + match spark_abs(&[args, fail_on_error_arg]) { + Ok(ColumnarValue::Array(result)) => { + let actual = as_uint64_array(&result)?; + assert_eq!(actual, &expected); + Ok(()) + } + Err(e) => { + if fail_on_error { + assert!( + e.to_string().contains("arithmetic overflow"), + "Error message did not match. Actual message: {e}" + ); + Ok(()) + } else { + panic!("Didn't expect error, but got: {e:?}") + } + } + _ => unreachable!(), + } + }); + } +} diff --git a/datafusion/spark/src/function/math/mod.rs b/datafusion/spark/src/function/math/mod.rs index fe8c2b1da0df..74fa4cf37ca5 100644 --- a/datafusion/spark/src/function/math/mod.rs +++ b/datafusion/spark/src/function/math/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
+pub mod abs; pub mod expm1; pub mod factorial; pub mod hex; @@ -27,6 +28,7 @@ use datafusion_expr::ScalarUDF; use datafusion_functions::make_udf_function; use std::sync::Arc; +make_udf_function!(abs::SparkAbs, abs); make_udf_function!(expm1::SparkExpm1, expm1); make_udf_function!(factorial::SparkFactorial, factorial); make_udf_function!(hex::SparkHex, hex); @@ -40,6 +42,7 @@ make_udf_function!(trigonometry::SparkSec, sec); pub mod expr_fn { use datafusion_functions::export_functions; + export_functions!((abs, "Returns abs(expr)", arg1)); export_functions!((expm1, "Returns exp(expr) - 1 as a Float64.", arg1)); export_functions!(( factorial, @@ -57,6 +60,7 @@ pub mod expr_fn { pub fn functions() -> Vec> { vec![ + abs(), expm1(), factorial(), hex(), From 386d07ffaacdd170d17b97293ec86fa9960eaec2 Mon Sep 17 00:00:00 2001 From: huanghsiang_cheng Date: Mon, 20 Oct 2025 16:11:59 -0700 Subject: [PATCH 02/25] test: Add Spark abs sqllogictest --- .../test_files/spark/math/abs.slt | 101 +++++++++++++++++- 1 file changed, 97 insertions(+), 4 deletions(-) diff --git a/datafusion/sqllogictest/test_files/spark/math/abs.slt b/datafusion/sqllogictest/test_files/spark/math/abs.slt index 4b9edf7e29f2..eb60eebc35f0 100644 --- a/datafusion/sqllogictest/test_files/spark/math/abs.slt +++ b/datafusion/sqllogictest/test_files/spark/math/abs.slt @@ -23,10 +23,103 @@ ## Original Query: SELECT abs(-1); ## PySpark 3.5.5 Result: {'abs(-1)': 1, 'typeof(abs(-1))': 'int', 'typeof(-1)': 'int'} -#query -#SELECT abs(-1::int); +query I +SELECT abs(-1::int); +---- +1 + +statement ok +CREATE TABLE test_nullable_integer( + c1 TINYINT, + c2 SMALLINT, + c3 INT, + c4 BIGINT, + dataset TEXT + ) + AS VALUES + (NULL, NULL, NULL, NULL, 'nulls'), + (0, 0, 0, 0, 'zeros'), + (1, 1, 1, 1, 'ones'); + +query I +INSERT into test_nullable_integer values(-128, -32768, -2147483648, -9223372036854775808, 'mins'); +---- +1 + +# abs: signed int minimal values +query IIII +select abs(c1), abs(c2), abs(c3), abs(c4) from 
test_nullable_integer where dataset = 'mins' +---- +-128 -32768 -2147483648 -9223372036854775808 + +statement ok +drop table test_nullable_integer + +statement ok +CREATE TABLE test_nullable_float( + c1 float, + c2 double + ) AS VALUES + (-1.0, -1.0), + (1.0, 1.0), + (NULL, NULL), + (0., 0.), + ('NaN'::double, 'NaN'::double); + +# abs: floats +query RR rowsort +SELECT abs(c1), abs(c2) from test_nullable_float +---- +0 0 +1 1 +1 1 +NULL NULL +NaN NaN + +statement ok +drop table test_nullable_float + +statement ok +CREATE TABLE test_nullable_decimal( + c1 DECIMAL(10, 2), /* Decimal128 */ + c2 DECIMAL(38, 10), /* Decimal128 with max precision */ + c3 DECIMAL(40, 2), /* Decimal256 */ + c4 DECIMAL(76, 10) /* Decimal256 with max precision */ + ) AS VALUES + (0, 0, 0, 0), + (NULL, NULL, NULL, NULL); + +query I +INSERT into test_nullable_decimal values + ( + -99999999.99, + '-9999999999999999999999999999.9999999999', + '-99999999999999999999999999999999999999.99', + '-999999999999999999999999999999999999999999999999999999999999999999.9999999999' + ), + ( + 99999999.99, + '9999999999999999999999999999.9999999999', + '99999999999999999999999999999999999999.99', + '999999999999999999999999999999999999999999999999999999999999999999.9999999999' + ) +---- +2 + +# abs: decimals +query RRRR rowsort +SELECT abs(c1), abs(c2), abs(c3), abs(c4) FROM test_nullable_decimal +---- +0 0 0 0 +99999999.99 9999999999999999999999999999.9999999999 99999999999999999999999999999999999999.99 999999999999999999999999999999999999999999999999999999999999999999.9999999999 +99999999.99 9999999999999999999999999999.9999999999 99999999999999999999999999999999999999.99 999999999999999999999999999999999999999999999999999999999999999999.9999999999 +NULL NULL NULL NULL + + +statement ok +drop table test_nullable_decimal ## Original Query: SELECT abs(INTERVAL -'1-1' YEAR TO MONTH); ## PySpark 3.5.5 Result: {"abs(INTERVAL '-1-1' YEAR TO MONTH)": 13, "typeof(abs(INTERVAL '-1-1' YEAR TO MONTH))": 'interval year 
to month', "typeof(INTERVAL '-1-1' YEAR TO MONTH)": 'interval year to month'} -#query -#SELECT abs(INTERVAL '-1-1' YEAR TO MONTH::interval year to month); +query error DataFusion error: This feature is not implemented: Unsupported SQL type INTERVAL YEAR TO MONTH +SELECT abs(INTERVAL '-1-1' YEAR TO MONTH::interval year to month); From a4cbf2a9157d632aac6efeff86db1786262df9c8 Mon Sep 17 00:00:00 2001 From: huanghsiang_cheng Date: Tue, 21 Oct 2025 14:28:04 -0700 Subject: [PATCH 03/25] Fix clippy error --- datafusion/spark/src/function/math/abs.rs | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/datafusion/spark/src/function/math/abs.rs b/datafusion/spark/src/function/math/abs.rs index 1c7118663f86..9d2ca93b133d 100644 --- a/datafusion/spark/src/function/math/abs.rs +++ b/datafusion/spark/src/function/math/abs.rs @@ -147,7 +147,7 @@ macro_rules! ansi_compute_op { Ok(res) => Ok(ColumnarValue::Array( Arc::>::new(res), )), - Err(_) => Err(arithmetic_overflow_error($FROM_TYPE).into()), + Err(_) => Err(arithmetic_overflow_error($FROM_TYPE)), } } _ => Err(DataFusionError::Internal("Invalid data type".to_string())), @@ -158,8 +158,7 @@ macro_rules! 
ansi_compute_op { fn arithmetic_overflow_error(from_type: &str) -> DataFusionError { ArrowError( Box::from(arrow::error::ArrowError::ComputeError(format!( - "arithmetic overflow from {}", - from_type + "arithmetic overflow from {from_type}", ))), None, ) @@ -270,7 +269,7 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result { - Err(arithmetic_overflow_error("Decimal128").into()) + Err(arithmetic_overflow_error("Decimal128")) } } } @@ -313,7 +312,7 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result { - Err(arithmetic_overflow_error("Decimal256").into()) + Err(arithmetic_overflow_error("Decimal256")) } } } @@ -367,7 +366,7 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result Result Result Result Result Result Date: Tue, 21 Oct 2025 16:29:16 -0700 Subject: [PATCH 04/25] formatting --- datafusion/spark/src/function/math/abs.rs | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/datafusion/spark/src/function/math/abs.rs b/datafusion/spark/src/function/math/abs.rs index 9d2ca93b133d..daeb5129cc4d 100644 --- a/datafusion/spark/src/function/math/abs.rs +++ b/datafusion/spark/src/function/math/abs.rs @@ -268,9 +268,7 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result { - Err(arithmetic_overflow_error("Decimal128")) - } + Err(_) => Err(arithmetic_overflow_error("Decimal128")), } } _ => Err(DataFusionError::Internal( @@ -311,9 +309,7 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result { - Err(arithmetic_overflow_error("Decimal256")) - } + Err(_) => Err(arithmetic_overflow_error("Decimal256")), } } _ => Err(DataFusionError::Internal( From 3736010309efac68866b958ee75f5a1c4ba537ea Mon Sep 17 00:00:00 2001 From: hsiang-c <137842490+hsiang-c@users.noreply.github.com> Date: Fri, 24 Oct 2025 22:57:30 -0700 Subject: [PATCH 05/25] Update datafusion/spark/src/function/math/abs.rs Co-authored-by: Oleks V --- datafusion/spark/src/function/math/abs.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/spark/src/function/math/abs.rs 
b/datafusion/spark/src/function/math/abs.rs index daeb5129cc4d..8823043f62b9 100644 --- a/datafusion/spark/src/function/math/abs.rs +++ b/datafusion/spark/src/function/math/abs.rs @@ -53,7 +53,7 @@ impl Default for SparkAbs { impl SparkAbs { pub fn new() -> Self { Self { - signature: Signature::user_defined(Volatility::Immutable), + signature: Signature::numeric(1, Volatility::Immutable), } } } From 7e6d15f1ffe2877f74785f7cc8595cbc582ff638 Mon Sep 17 00:00:00 2001 From: huanghsiang_cheng Date: Fri, 24 Oct 2025 23:57:59 -0700 Subject: [PATCH 06/25] Inline abs tests --- .../test_files/spark/math/abs.slt | 62 +++++-------------- 1 file changed, 16 insertions(+), 46 deletions(-) diff --git a/datafusion/sqllogictest/test_files/spark/math/abs.slt b/datafusion/sqllogictest/test_files/spark/math/abs.slt index eb60eebc35f0..d279d00c5d0c 100644 --- a/datafusion/sqllogictest/test_files/spark/math/abs.slt +++ b/datafusion/sqllogictest/test_files/spark/math/abs.slt @@ -23,62 +23,33 @@ ## Original Query: SELECT abs(-1); ## PySpark 3.5.5 Result: {'abs(-1)': 1, 'typeof(abs(-1))': 'int', 'typeof(-1)': 'int'} -query I -SELECT abs(-1::int); ----- -1 -statement ok -CREATE TABLE test_nullable_integer( - c1 TINYINT, - c2 SMALLINT, - c3 INT, - c4 BIGINT, - dataset TEXT - ) - AS VALUES - (NULL, NULL, NULL, NULL, 'nulls'), - (0, 0, 0, 0, 'zeros'), - (1, 1, 1, 1, 'ones'); - -query I -INSERT into test_nullable_integer values(-128, -32768, -2147483648, -9223372036854775808, 'mins'); +# abs: signed int and NULL +query IIIIR +SELECT abs(-127::TINYINT), abs(-32767::SMALLINT), abs(-2147483647::INT), abs(-9223372036854775807::BIGINT), abs(NULL); ---- -1 +127 32767 2147483647 9223372036854775807 NULL + # abs: signed int minimal values query IIII -select abs(c1), abs(c2), abs(c3), abs(c4) from test_nullable_integer where dataset = 'mins' +select abs((-128)::TINYINT), abs((-32768)::SMALLINT), abs((-2147483648)::INT), abs((-9223372036854775808)::BIGINT) ---- -128 -32768 -2147483648 
-9223372036854775808 -statement ok -drop table test_nullable_integer - -statement ok -CREATE TABLE test_nullable_float( - c1 float, - c2 double - ) AS VALUES - (-1.0, -1.0), - (1.0, 1.0), - (NULL, NULL), - (0., 0.), - ('NaN'::double, 'NaN'::double); - -# abs: floats -query RR rowsort -SELECT abs(c1), abs(c2) from test_nullable_float +# abs: floats, NULL and NaN +query RRRR +SELECT abs(-1.0::FLOAT), abs(0.::FLOAT), abs(NULL::FLOAT), abs('NaN'::FLOAT) ---- -0 0 -1 1 -1 1 -NULL NULL -NaN NaN +1 0 NULL NaN -statement ok -drop table test_nullable_float +# abs: doubles, NULL and NaN +query RRRR +SELECT abs(-1.0::DOUBLE), abs(0.::DOUBLE), abs(NULL::DOUBLE), abs('NaN'::DOUBLE) +---- +1 0 NULL NaN +# abs: decimal128 and decimal256 statement ok CREATE TABLE test_nullable_decimal( c1 DECIMAL(10, 2), /* Decimal128 */ @@ -106,7 +77,6 @@ INSERT into test_nullable_decimal values ---- 2 -# abs: decimals query RRRR rowsort SELECT abs(c1), abs(c2), abs(c3), abs(c4) FROM test_nullable_decimal ---- From 179332acbbaaa3e2374316f7104ed771fe976840 Mon Sep 17 00:00:00 2001 From: huanghsiang_cheng Date: Mon, 17 Nov 2025 22:18:05 -0800 Subject: [PATCH 07/25] Upwrap directly --- datafusion/spark/src/function/math/abs.rs | 65 +++++++++-------------- 1 file changed, 26 insertions(+), 39 deletions(-) diff --git a/datafusion/spark/src/function/math/abs.rs b/datafusion/spark/src/function/math/abs.rs index 8823043f62b9..8a7dd2b592ba 100644 --- a/datafusion/spark/src/function/math/abs.rs +++ b/datafusion/spark/src/function/math/abs.rs @@ -116,41 +116,28 @@ impl ScalarUDFImpl for SparkAbs { macro_rules! 
legacy_compute_op { ($ARRAY:expr, $FUNC:ident, $TYPE:ident, $RESULT:ident) => {{ - let n = $ARRAY.as_any().downcast_ref::<$TYPE>(); - match n { - Some(array) => { - let res: $RESULT = - arrow::compute::kernels::arity::unary(array, |x| x.$FUNC()); - Ok(res) - } - _ => Err(DataFusionError::Internal(format!( - "Invalid data type for abs" - ))), - } + let array = $ARRAY.as_any().downcast_ref::<$TYPE>().unwrap(); + let res: $RESULT = arrow::compute::kernels::arity::unary(array, |x| x.$FUNC()); + res }}; } macro_rules! ansi_compute_op { ($ARRAY:expr, $FUNC:ident, $TYPE:ident, $RESULT:ident, $NATIVE:ident, $FROM_TYPE:expr) => {{ - let n = $ARRAY.as_any().downcast_ref::<$TYPE>(); - match n { - Some(array) => { - match arrow::compute::kernels::arity::try_unary(array, |x| { - if x == $NATIVE::MIN { - Err(arrow::error::ArrowError::ArithmeticOverflow( - $FROM_TYPE.to_string(), - )) - } else { - Ok(x.$FUNC()) - } - }) { - Ok(res) => Ok(ColumnarValue::Array( - Arc::>::new(res), - )), - Err(_) => Err(arithmetic_overflow_error($FROM_TYPE)), - } + let array = $ARRAY.as_any().downcast_ref::<$TYPE>().unwrap(); + match arrow::compute::kernels::arity::try_unary(array, |x| { + if x == $NATIVE::MIN { + Err(arrow::error::ArrowError::ArithmeticOverflow( + $FROM_TYPE.to_string(), + )) + } else { + Ok(x.$FUNC()) } - _ => Err(DataFusionError::Internal("Invalid data type".to_string())), + }) { + Ok(res) => Ok(ColumnarValue::Array(Arc::>::new( + res, + ))), + Err(_) => Err(arithmetic_overflow_error($FROM_TYPE)), } }}; } @@ -196,7 +183,7 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result Result Result Result { let result = legacy_compute_op!(array, abs, Float32Array, Float32Array); - Ok(ColumnarValue::Array(Arc::new(result?))) + Ok(ColumnarValue::Array(Arc::new(result))) } DataType::Float64 => { let result = legacy_compute_op!(array, abs, Float64Array, Float64Array); - Ok(ColumnarValue::Array(Arc::new(result?))) + Ok(ColumnarValue::Array(Arc::new(result))) } DataType::Decimal128(precision, 
scale) => { if !fail_on_error { @@ -243,7 +230,7 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result Result Result Result Date: Wed, 29 Oct 2025 17:11:07 +0800 Subject: [PATCH 08/25] Simplify error construction --- datafusion/spark/src/function/math/abs.rs | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/datafusion/spark/src/function/math/abs.rs b/datafusion/spark/src/function/math/abs.rs index 8a7dd2b592ba..adbc65c7fcf2 100644 --- a/datafusion/spark/src/function/math/abs.rs +++ b/datafusion/spark/src/function/math/abs.rs @@ -143,12 +143,7 @@ macro_rules! ansi_compute_op { } fn arithmetic_overflow_error(from_type: &str) -> DataFusionError { - ArrowError( - Box::from(arrow::error::ArrowError::ComputeError(format!( - "arithmetic overflow from {from_type}", - ))), - None, - ) + DataFusionError::Execution(format!("arithmetic overflow from {from_type}")) } pub fn spark_abs(args: &[ColumnarValue]) -> Result { From 88c2ae93d156f1b74b045f97dfae78f2b13cbf0f Mon Sep 17 00:00:00 2001 From: huanghsiang_cheng Date: Fri, 7 Nov 2025 14:35:52 -0800 Subject: [PATCH 09/25] Handle None instead of unwrap --- datafusion/spark/src/function/math/abs.rs | 195 ++++++++++++---------- 1 file changed, 103 insertions(+), 92 deletions(-) diff --git a/datafusion/spark/src/function/math/abs.rs b/datafusion/spark/src/function/math/abs.rs index adbc65c7fcf2..16dfb74c0c10 100644 --- a/datafusion/spark/src/function/math/abs.rs +++ b/datafusion/spark/src/function/math/abs.rs @@ -21,7 +21,6 @@ use crate::function::error_utils::{ use arrow::array::*; use arrow::datatypes::DataType; use arrow::datatypes::*; -use datafusion_common::DataFusionError::ArrowError; use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, @@ -327,81 +326,88 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result internal_err!("Not supported datatype for Spark ABS: {dt}"), }, - 
ColumnarValue::Scalar(sv) => { - match sv { - ScalarValue::Null - | ScalarValue::UInt8(_) - | ScalarValue::UInt16(_) - | ScalarValue::UInt32(_) - | ScalarValue::UInt64(_) => Ok(args[0].clone()), - ScalarValue::Int8(a) => a - .map(|v| match v.checked_abs() { - Some(abs_val) => { - Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(abs_val)))) - } - None => { - if !fail_on_error { - // return the original value - Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(v)))) - } else { - Err(arithmetic_overflow_error("Int8")) - } - } - }) - .unwrap(), - ScalarValue::Int16(a) => a - .map(|v| match v.checked_abs() { - Some(abs_val) => { - Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(abs_val)))) - } - None => { - if !fail_on_error { - // return the original value - Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(v)))) - } else { - Err(arithmetic_overflow_error("Int16")) - } - } - }) - .unwrap(), - ScalarValue::Int32(a) => a - .map(|v| match v.checked_abs() { - Some(abs_val) => { - Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(abs_val)))) + ColumnarValue::Scalar(sv) => match sv { + ScalarValue::Null + | ScalarValue::UInt8(_) + | ScalarValue::UInt16(_) + | ScalarValue::UInt32(_) + | ScalarValue::UInt64(_) => Ok(args[0].clone()), + ScalarValue::Int8(a) => match a { + None => Ok(args[0].clone()), + Some(v) => match v.checked_abs() { + Some(abs_val) => { + Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(abs_val)))) + } + None => { + if !fail_on_error { + // return the original value + Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(*v)))) + } else { + Err(arithmetic_overflow_error("Int8")) } - None => { - if !fail_on_error { - // return the original value - Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(v)))) - } else { - Err(arithmetic_overflow_error("Int32")) - } + } + }, + }, + ScalarValue::Int16(a) => match a { + None => Ok(args[0].clone()), + Some(v) => match v.checked_abs() { + Some(abs_val) => { + Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(abs_val)))) + } + None 
=> { + if !fail_on_error { + // return the original value + Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(*v)))) + } else { + Err(arithmetic_overflow_error("Int16")) } - }) - .unwrap(), - ScalarValue::Int64(a) => a - .map(|v| match v.checked_abs() { - Some(abs_val) => { - Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(abs_val)))) + } + }, + }, + ScalarValue::Int32(a) => match a { + None => Ok(args[0].clone()), + Some(v) => match v.checked_abs() { + Some(abs_val) => { + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(abs_val)))) + } + None => { + if !fail_on_error { + // return the original value + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(*v)))) + } else { + Err(arithmetic_overflow_error("Int32")) } - None => { - if !fail_on_error { - // return the original value - Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(v)))) - } else { - Err(arithmetic_overflow_error("Int64")) - } + } + }, + }, + ScalarValue::Int64(a) => match a { + None => Ok(args[0].clone()), + Some(v) => match v.checked_abs() { + Some(abs_val) => { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(abs_val)))) + } + None => { + if !fail_on_error { + // return the original value + Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(*v)))) + } else { + Err(arithmetic_overflow_error("Int64")) } - }) - .unwrap(), - ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar( - ScalarValue::Float32(a.map(|x| x.abs())), - )), - ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar( - ScalarValue::Float64(a.map(|x| x.abs())), - )), - ScalarValue::Decimal128(a, precision, scale) => a - .map(|v| match v.checked_abs() { + } + }, + }, + ScalarValue::Float32(a) => match a { + None => Ok(args[0].clone()), + Some(v) => Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(v.abs())))), + }, + ScalarValue::Float64(a) => match a { + None => Ok(args[0].clone()), + Some(v) => Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(v.abs())))), + }, + ScalarValue::Decimal128(a, precision, scale) => { + match a { + None => 
Ok(args[0].clone()), + Some(v) => match v.checked_abs() { Some(abs_val) => Ok(ColumnarValue::Scalar( ScalarValue::Decimal128(Some(abs_val), *precision, *scale), )), @@ -409,7 +415,7 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result Result a - .map(|v| match v.checked_abs() { + }, + } + } + ScalarValue::Decimal256(a, precision, scale) => { + match a { + None => Ok(args[0].clone()), + Some(v) => match v.checked_abs() { Some(abs_val) => Ok(ColumnarValue::Scalar( ScalarValue::Decimal256(Some(abs_val), *precision, *scale), )), @@ -428,7 +437,7 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result Result { - let result = a.map(|v| v.wrapping_abs()); - Ok(ColumnarValue::Scalar(ScalarValue::IntervalYearMonth( - result, - ))) - } - ScalarValue::IntervalDayTime(a) => { - let result = a.map(|v| v.wrapping_abs()); - Ok(ColumnarValue::Scalar(ScalarValue::IntervalDayTime(result))) + }, } - - dt => internal_err!("Not supported datatype for Spark ABS: {dt}"), } - } + ScalarValue::IntervalYearMonth(a) => match a { + None => Ok(args[0].clone()), + Some(v) => Ok(ColumnarValue::Scalar(ScalarValue::IntervalYearMonth( + Some(v.wrapping_abs()), + ))), + }, + ScalarValue::IntervalDayTime(a) => match a { + None => Ok(args[0].clone()), + Some(v) => Ok(ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( + v.wrapping_abs(), + )))), + }, + + dt => internal_err!("Not supported datatype for Spark ABS: {dt}"), + }, } } From d3df5a8addadb0679920683ce04d3b34584529ce Mon Sep 17 00:00:00 2001 From: huanghsiang_cheng Date: Fri, 7 Nov 2025 16:45:31 -0800 Subject: [PATCH 10/25] Refactor scalar test --- datafusion/spark/src/function/math/abs.rs | 486 +++++++++++----------- 1 file changed, 235 insertions(+), 251 deletions(-) diff --git a/datafusion/spark/src/function/math/abs.rs b/datafusion/spark/src/function/math/abs.rs index 16dfb74c0c10..50c183c6190b 100644 --- a/datafusion/spark/src/function/math/abs.rs +++ b/datafusion/spark/src/function/math/abs.rs @@ -122,10 +122,10 @@ macro_rules! 
legacy_compute_op { } macro_rules! ansi_compute_op { - ($ARRAY:expr, $FUNC:ident, $TYPE:ident, $RESULT:ident, $NATIVE:ident, $FROM_TYPE:expr) => {{ + ($ARRAY:expr, $FUNC:ident, $TYPE:ident, $RESULT:ident, $MIN:expr, $FROM_TYPE:expr) => {{ let array = $ARRAY.as_any().downcast_ref::<$TYPE>().unwrap(); match arrow::compute::kernels::arity::try_unary(array, |x| { - if x == $NATIVE::MIN { + if x == $MIN { Err(arrow::error::ArrowError::ArithmeticOverflow( $FROM_TYPE.to_string(), )) @@ -179,7 +179,7 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result { @@ -188,7 +188,7 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result { @@ -197,7 +197,7 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result { @@ -206,7 +206,7 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result { @@ -301,24 +301,46 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result match unit { IntervalUnit::YearMonth => { - let result = legacy_compute_op!( - array, - wrapping_abs, - IntervalYearMonthArray, - IntervalYearMonthArray - ); - let result = result.with_data_type(DataType::Interval(*unit)); - Ok(ColumnarValue::Array(Arc::new(result))) + if !fail_on_error { + let result = legacy_compute_op!( + array, + wrapping_abs, + IntervalYearMonthArray, + IntervalYearMonthArray + ); + let result = result.with_data_type(DataType::Interval(*unit)); + Ok(ColumnarValue::Array(Arc::new(result))) + } else { + ansi_compute_op!( + array, + abs, + IntervalYearMonthArray, + IntervalYearMonthType, + i32::MIN, + "IntervalYearMonth" + ) + } } IntervalUnit::DayTime => { - let result = legacy_compute_op!( - array, - wrapping_abs, - IntervalDayTimeArray, - IntervalDayTimeArray - ); - let result = result.with_data_type(DataType::Interval(*unit)); - Ok(ColumnarValue::Array(Arc::new(result))) + if !fail_on_error { + let result = legacy_compute_op!( + array, + wrapping_abs, + IntervalDayTimeArray, + IntervalDayTimeArray + ); + let result = result.with_data_type(DataType::Interval(*unit)); + Ok(ColumnarValue::Array(Arc::new(result))) + 
} else { + ansi_compute_op!( + array, + wrapping_abs, + IntervalDayTimeArray, + IntervalDayTimeType, + IntervalDayTime::MIN, + "IntervalDayTime" + ) + } } IntervalUnit::MonthDayNano => internal_err!( "MonthDayNano is not a supported Interval unit for Spark ABS" @@ -450,15 +472,39 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result match a { None => Ok(args[0].clone()), - Some(v) => Ok(ColumnarValue::Scalar(ScalarValue::IntervalYearMonth( - Some(v.wrapping_abs()), - ))), + Some(v) => match v.checked_abs() { + Some(abs_val) => Ok(ColumnarValue::Scalar( + ScalarValue::IntervalYearMonth(Some(abs_val)), + )), + None => { + if !fail_on_error { + // return the original value + Ok(ColumnarValue::Scalar(ScalarValue::IntervalYearMonth( + Some(*v), + ))) + } else { + Err(arithmetic_overflow_error("IntervalYearMonth")) + } + } + }, }, ScalarValue::IntervalDayTime(a) => match a { None => Ok(args[0].clone()), - Some(v) => Ok(ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - v.wrapping_abs(), - )))), + Some(v) => match v.checked_abs() { + Some(abs_val) => Ok(ColumnarValue::Scalar( + ScalarValue::IntervalDayTime(Some(abs_val)), + )), + None => { + if !fail_on_error { + // return the original value + Ok(ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( + *v, + )))) + } else { + Err(arithmetic_overflow_error("IntervalYearMonth")) + } + } + }, }, dt => internal_err!("Not supported datatype for Spark ABS: {dt}"), @@ -481,267 +527,205 @@ mod tests { } } - // Unsigned types, return as is #[test] - fn test_abs_u8_scalar() { - with_fail_on_error(|fail_on_error| { - let args = ColumnarValue::Scalar(ScalarValue::UInt8(Some(u8::MAX))); - let fail_on_error_arg = - ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); - match spark_abs(&[args, fail_on_error_arg]) { - Ok(ColumnarValue::Scalar(ScalarValue::UInt8(Some(result)))) => { - assert_eq!(result, u8::MAX); - Ok(()) - } - Err(e) => { - if fail_on_error { - assert!( - 
e.to_string().contains("ARITHMETIC_OVERFLOW"), - "Error message did not match. Actual message: {e}" - ); - Ok(()) - } else { - panic!("Didn't expect error, but got: {e:?}") - } - } - _ => unreachable!(), - } - }); + fn test_abs_zero_arg() { + assert!(spark_abs(&[]).is_err()); } - #[test] - fn test_abs_i8_scalar() { - with_fail_on_error(|fail_on_error| { - let args = ColumnarValue::Scalar(ScalarValue::Int8(Some(i8::MIN))); - let fail_on_error_arg = - ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); - match spark_abs(&[args, fail_on_error_arg]) { - Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(result)))) => { - assert_eq!(result, i8::MIN); - Ok(()) + macro_rules! eval_legacy_mode { + ($TYPE:ident, $VAL:expr) => {{ + let args = ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL))); + match spark_abs(&[args]) { + Ok(ColumnarValue::Scalar(ScalarValue::$TYPE(Some(result)))) => { + assert_eq!(result, $VAL); } - Err(e) => { - if fail_on_error { - assert!( - e.to_string().contains("arithmetic overflow"), - "Error message did not match. 
Actual message: {e}" - ); - Ok(()) - } else { - panic!("Didn't expect error, but got: {e:?}") - } + _ => unreachable!(), + } + }}; + ($TYPE:ident, $VAL:expr, $RESULT:expr) => {{ + let args = ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL))); + match spark_abs(&[args]) { + Ok(ColumnarValue::Scalar(ScalarValue::$TYPE(Some(result)))) => { + assert_eq!(result, $RESULT); } _ => unreachable!(), } - }); - } - - #[test] - fn test_abs_i16_scalar() { - with_fail_on_error(|fail_on_error| { - let args = ColumnarValue::Scalar(ScalarValue::Int16(Some(i16::MIN))); - let fail_on_error_arg = - ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); - match spark_abs(&[args, fail_on_error_arg]) { - Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(result)))) => { - assert_eq!(result, i16::MIN); - Ok(()) + }}; + ($TYPE:ident, $VAL:expr, $PRECISION:expr, $SCALE:expr) => {{ + let args = + ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL), $PRECISION, $SCALE)); + match spark_abs(&[args]) { + Ok(ColumnarValue::Scalar(ScalarValue::$TYPE( + Some(result), + precision, + scale, + ))) => { + assert_eq!(result, $VAL); + assert_eq!(precision, $PRECISION); + assert_eq!(scale, $SCALE); } - Err(e) => { - if fail_on_error { - assert!( - e.to_string().contains("arithmetic overflow"), - "Error message did not match. 
Actual message: {e}" - ); - Ok(()) - } else { - panic!("Didn't expect error, but got: {e:?}") - } + _ => unreachable!(), + } + }}; + ($TYPE:ident, $VAL:expr, $PRECISION:expr, $SCALE:expr, $RESULT:expr) => {{ + let args = + ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL), $PRECISION, $SCALE)); + match spark_abs(&[args]) { + Ok(ColumnarValue::Scalar(ScalarValue::$TYPE( + Some(result), + precision, + scale, + ))) => { + assert_eq!(result, $RESULT); + assert_eq!(precision, $PRECISION); + assert_eq!(scale, $SCALE); } _ => unreachable!(), } - }); + }}; } #[test] - fn test_abs_i32_scalar() { - with_fail_on_error(|fail_on_error| { - let args = ColumnarValue::Scalar(ScalarValue::Int32(Some(i32::MIN))); - let fail_on_error_arg = - ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); - match spark_abs(&[args, fail_on_error_arg]) { - Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(result)))) => { - assert_eq!(result, i32::MIN); - Ok(()) - } + fn test_abs_scalar_legacy_mode() { + // NumericType, DayTimeIntervalType, and YearMonthIntervalType MIN + eval_legacy_mode!(UInt8, u8::MIN); + eval_legacy_mode!(UInt16, u16::MIN); + eval_legacy_mode!(UInt32, u32::MIN); + eval_legacy_mode!(UInt64, u64::MIN); + eval_legacy_mode!(Int8, i8::MIN); + eval_legacy_mode!(Int16, i16::MIN); + eval_legacy_mode!(Int32, i32::MIN); + eval_legacy_mode!(Int64, i64::MIN); + eval_legacy_mode!(Float32, f32::MIN, f32::MAX); + eval_legacy_mode!(Float64, f64::MIN, f64::MAX); + eval_legacy_mode!(Decimal128, i128::MIN, 18, 10); + eval_legacy_mode!(Decimal256, i256::MIN, 10, 2); + eval_legacy_mode!(IntervalYearMonth, i32::MIN); + eval_legacy_mode!(IntervalDayTime, IntervalDayTime::MIN); + + // NumericType, DayTimeIntervalType, and YearMonthIntervalType not MIN + eval_legacy_mode!(Int8, -1i8, 1i8); + eval_legacy_mode!(Int16, -1i16, 1i16); + eval_legacy_mode!(Int32, -1i32, 1i32); + eval_legacy_mode!(Int64, -1i64, 1i64); + eval_legacy_mode!(Decimal128, -1i128, 18, 10, 1i128); + 
eval_legacy_mode!(Decimal256, i256::from(-1i8), 10, 2, i256::from(1i8)); + eval_legacy_mode!(IntervalYearMonth, -1i32, 1i32); + eval_legacy_mode!( + IntervalDayTime, + IntervalDayTime::new(-1i32, -1i32), + IntervalDayTime::new(1i32, 1i32) + ); + + // Float32, Float64 + eval_legacy_mode!(Float32, f32::NEG_INFINITY, f32::INFINITY); + eval_legacy_mode!(Float32, f32::INFINITY, f32::INFINITY); + eval_legacy_mode!(Float32, 0.0f32, 0.0f32); + eval_legacy_mode!(Float32, -0.0f32, 0.0f32); + eval_legacy_mode!(Float64, f64::NEG_INFINITY, f64::INFINITY); + eval_legacy_mode!(Float64, f64::INFINITY, f64::INFINITY); + eval_legacy_mode!(Float64, 0.0f64, 0.0f64); + eval_legacy_mode!(Float64, -0.0f64, 0.0f64); + } + + macro_rules! eval_ansi_mode { + ($TYPE:ident, $VAL:expr) => {{ + let args = ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL))); + let fail_on_error = ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))); + match spark_abs(&[args, fail_on_error]) { Err(e) => { - if fail_on_error { - assert!( - e.to_string().contains("arithmetic overflow"), - "Error message did not match. Actual message: {e}" - ); - Ok(()) - } else { - panic!("Didn't expect error, but got: {e:?}") - } + assert!( + e.to_string().contains("arithmetic overflow"), + "Error message did not match. Actual message: {e}" + ); } _ => unreachable!(), } - }); - } - - #[test] - fn test_abs_i64_scalar() { - with_fail_on_error(|fail_on_error| { - let args = ColumnarValue::Scalar(ScalarValue::Int64(Some(i64::MIN))); - let fail_on_error_arg = - ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); - match spark_abs(&[args, fail_on_error_arg]) { - Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(result)))) => { - assert_eq!(result, i64::MIN); - Ok(()) - } - Err(e) => { - if fail_on_error { - assert!( - e.to_string().contains("arithmetic overflow"), - "Error message did not match. 
Actual message: {e}" - ); - Ok(()) - } else { - panic!("Didn't expect error, but got: {e:?}") - } + }}; + ($TYPE:ident, $VAL:expr, $RESULT:expr) => {{ + let args = ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL))); + let fail_on_error = ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))); + match spark_abs(&[args, fail_on_error]) { + Ok(ColumnarValue::Scalar(ScalarValue::$TYPE(Some(result)))) => { + assert_eq!(result, $RESULT); } _ => unreachable!(), } - }); - } - - #[test] - fn test_abs_decimal128_scalar() { - with_fail_on_error(|fail_on_error| { + }}; + ($TYPE:ident, $VAL:expr, $PRECISION:expr, $SCALE:expr) => {{ let args = - ColumnarValue::Scalar(ScalarValue::Decimal128(Some(i128::MIN), 18, 10)); - let fail_on_error_arg = - ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); - match spark_abs(&[args, fail_on_error_arg]) { - Ok(ColumnarValue::Scalar(ScalarValue::Decimal128( - Some(result), - precision, - scale, - ))) => { - assert_eq!(result, i128::MIN); - assert_eq!(precision, 18); - assert_eq!(scale, 10); - Ok(()) - } + ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL), $PRECISION, $SCALE)); + let fail_on_error = ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))); + match spark_abs(&[args, fail_on_error]) { Err(e) => { - if fail_on_error { - assert!( - e.to_string().contains("arithmetic overflow"), - "Error message did not match. Actual message: {e}" - ); - Ok(()) - } else { - panic!("Didn't expect error, but got: {e:?}") - } + assert!( + e.to_string().contains("arithmetic overflow"), + "Error message did not match. 
Actual message: {e}" + ); } _ => unreachable!(), } - }); - } - - #[test] - fn test_abs_decimal256_scalar() { - with_fail_on_error(|fail_on_error| { + }}; + ($TYPE:ident, $VAL:expr, $PRECISION:expr, $SCALE:expr, $RESULT:expr) => {{ let args = - ColumnarValue::Scalar(ScalarValue::Decimal256(Some(i256::MIN), 10, 2)); - let fail_on_error_arg = - ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); - match spark_abs(&[args, fail_on_error_arg]) { - Ok(ColumnarValue::Scalar(ScalarValue::Decimal256( + ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL), $PRECISION, $SCALE)); + let fail_on_error = ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))); + match spark_abs(&[args, fail_on_error]) { + Ok(ColumnarValue::Scalar(ScalarValue::$TYPE( Some(result), precision, scale, ))) => { - assert_eq!(result, i256::MIN); - assert_eq!(precision, 10); - assert_eq!(scale, 2); - Ok(()) - } - Err(e) => { - if fail_on_error { - assert!( - e.to_string().contains("arithmetic overflow"), - "Error message did not match. Actual message: {e}" - ); - Ok(()) - } else { - panic!("Didn't expect error, but got: {e:?}") - } + assert_eq!(result, $RESULT); + assert_eq!(precision, $PRECISION); + assert_eq!(scale, $SCALE); } _ => unreachable!(), } - }); + }}; } #[test] - fn test_abs_interval_year_month_scalar() { - with_fail_on_error(|fail_on_error| { - let args = - ColumnarValue::Scalar(ScalarValue::IntervalYearMonth(Some(i32::MIN))); - let fail_on_error_arg = - ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); - match spark_abs(&[args, fail_on_error_arg]) { - Ok(ColumnarValue::Scalar(ScalarValue::IntervalYearMonth(Some( - result, - )))) => { - assert_eq!(result, i32::MIN); - Ok(()) - } - Err(e) => { - if fail_on_error { - assert!( - e.to_string().contains("arithmetic overflow"), - "Error message did not match. 
Actual message: {e}" - ); - Ok(()) - } else { - panic!("Didn't expect error, but got: {e:?}") - } - } - _ => unreachable!(), - } - }); - } + fn test_abs_scalar_ansi_mode() { + eval_ansi_mode!(Int8, i8::MIN); + eval_ansi_mode!(Int16, i16::MIN); + eval_ansi_mode!(Int32, i32::MIN); + eval_ansi_mode!(Int64, i64::MIN); + eval_ansi_mode!(Decimal128, i128::MIN, 18, 10); + eval_ansi_mode!(Decimal256, i256::MIN, 10, 2); + eval_ansi_mode!(IntervalYearMonth, i32::MIN); + eval_ansi_mode!(IntervalDayTime, IntervalDayTime::MIN); - #[test] - fn test_abs_interval_day_time_scalar() { - with_fail_on_error(|fail_on_error| { - let args = ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime::MIN, - ))); - let fail_on_error_arg = - ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); - match spark_abs(&[args, fail_on_error_arg]) { - Ok(ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(result)))) => { - assert_eq!(result, IntervalDayTime::MIN); - Ok(()) - } - Err(e) => { - if fail_on_error { - assert!( - e.to_string().contains("arithmetic overflow"), - "Error message did not match. 
Actual message: {e}" - ); - Ok(()) - } else { - panic!("Didn't expect error, but got: {e:?}") - } - } - _ => unreachable!(), - } - }); + eval_ansi_mode!(UInt8, u8::MIN, u8::MIN); + eval_ansi_mode!(UInt16, u16::MIN, u16::MIN); + eval_ansi_mode!(UInt32, u32::MIN, u32::MIN); + eval_ansi_mode!(UInt64, u64::MIN, u64::MIN); + eval_ansi_mode!(Float32, f32::MIN, f32::MAX); + eval_ansi_mode!(Float64, f64::MIN, f64::MAX); + + // NumericType, DayTimeIntervalType, and YearMonthIntervalType not MIN + eval_ansi_mode!(Int8, -1i8, 1i8); + eval_ansi_mode!(Int16, -1i16, 1i16); + eval_ansi_mode!(Int32, -1i32, 1i32); + eval_ansi_mode!(Int64, -1i64, 1i64); + eval_ansi_mode!(Decimal128, -1i128, 18, 10, 1i128); + eval_ansi_mode!(Decimal256, i256::from(-1i8), 10, 2, i256::from(1i8)); + eval_ansi_mode!(IntervalYearMonth, -1i32, 1i32); + eval_ansi_mode!( + IntervalDayTime, + IntervalDayTime::new(-1i32, -1i32), + IntervalDayTime::new(1i32, 1i32) + ); + + // Float32, Float64 + eval_ansi_mode!(Float32, f32::NEG_INFINITY, f32::INFINITY); + eval_ansi_mode!(Float32, f32::INFINITY, f32::INFINITY); + eval_ansi_mode!(Float32, 0.0f32, 0.0f32); + eval_ansi_mode!(Float32, -0.0f32, 0.0f32); + eval_ansi_mode!(Float64, f64::NEG_INFINITY, f64::INFINITY); + eval_ansi_mode!(Float64, f64::INFINITY, f64::INFINITY); + eval_ansi_mode!(Float64, 0.0f64, 0.0f64); + eval_ansi_mode!(Float64, -0.0f64, 0.0f64); } #[test] From f51c284662d6ccb16e20858e4c2ac7b8b1996c1b Mon Sep 17 00:00:00 2001 From: huanghsiang_cheng Date: Sat, 8 Nov 2025 14:25:06 -0800 Subject: [PATCH 11/25] Refactor array test --- datafusion/spark/src/function/math/abs.rs | 522 ++++++++-------------- 1 file changed, 196 insertions(+), 326 deletions(-) diff --git a/datafusion/spark/src/function/math/abs.rs b/datafusion/spark/src/function/math/abs.rs index 50c183c6190b..1ea061aa8cf4 100644 --- a/datafusion/spark/src/function/math/abs.rs +++ b/datafusion/spark/src/function/math/abs.rs @@ -728,386 +728,256 @@ mod tests { eval_ansi_mode!(Float64, -0.0f64, 
0.0f64); } - #[test] - fn test_abs_i8_array() { - with_fail_on_error(|fail_on_error| { - let input = - Int8Array::from(vec![Some(-1), Some(i8::MIN), Some(i8::MAX), None]); - let args = ColumnarValue::Array(Arc::new(input)); - let expected = - Int8Array::from(vec![Some(1), Some(i8::MIN), Some(i8::MAX), None]); - let fail_on_error_arg = - ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); - - match spark_abs(&[args, fail_on_error_arg]) { - Ok(ColumnarValue::Array(result)) => { - let actual = as_int8_array(&result)?; - assert_eq!(actual, &expected); - Ok(()) - } - Err(e) => { - if fail_on_error { - assert!( - e.to_string().contains("arithmetic overflow"), - "Error message did not match. Actual message: {e}" - ); - Ok(()) - } else { - panic!("Didn't expect error, but got: {e:?}") - } - } - _ => unreachable!(), - } - }); - } - - #[test] - fn test_abs_i16_array() { - with_fail_on_error(|fail_on_error| { - let input = - Int16Array::from(vec![Some(-1), Some(i16::MIN), Some(i16::MAX), None]); + macro_rules! eval_array_legacy_mode { + ($INPUT:expr, $OUTPUT:expr, $FUNC:ident) => {{ + let input = $INPUT; let args = ColumnarValue::Array(Arc::new(input)); - let expected = - Int16Array::from(vec![Some(1), Some(i16::MIN), Some(i16::MAX), None]); - let fail_on_error_arg = - ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); - - match spark_abs(&[args, fail_on_error_arg]) { + let fail_on_error = ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))); + let expected = $OUTPUT; + match spark_abs(&[args, fail_on_error]) { Ok(ColumnarValue::Array(result)) => { - let actual = as_int16_array(&result)?; + let actual = datafusion_common::cast::$FUNC(&result).unwrap(); assert_eq!(actual, &expected); - Ok(()) - } - Err(e) => { - if fail_on_error { - assert!( - e.to_string().contains("arithmetic overflow"), - "Error message did not match. 
Actual message: {e}" - ); - Ok(()) - } else { - panic!("Didn't expect error, but got: {e:?}") - } } _ => unreachable!(), } - }); + }}; } #[test] - fn test_abs_i32_array() { - with_fail_on_error(|fail_on_error| { - let input = - Int32Array::from(vec![Some(-1), Some(i32::MIN), Some(i32::MAX), None]); - let args = ColumnarValue::Array(Arc::new(input)); - let expected = - Int32Array::from(vec![Some(1), Some(i32::MIN), Some(i32::MAX), None]); - let fail_on_error_arg = - ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + fn test_abs_array_legacy_mode() { + eval_array_legacy_mode!( + Int8Array::from(vec![Some(-1), Some(i8::MIN), Some(i8::MAX), None]), + Int8Array::from(vec![Some(1), Some(i8::MIN), Some(i8::MAX), None]), + as_int8_array + ); - match spark_abs(&[args, fail_on_error_arg]) { - Ok(ColumnarValue::Array(result)) => { - let actual = as_int32_array(&result)?; - assert_eq!(actual, &expected); - Ok(()) - } - Err(e) => { - if fail_on_error { - assert!( - e.to_string().contains("arithmetic overflow"), - "Error message did not match. 
Actual message: {e}" - ); - Ok(()) - } else { - panic!("Didn't expect error, but got: {e:?}") - } - } - _ => unreachable!(), - } - }); - } + eval_array_legacy_mode!( + Int16Array::from(vec![Some(-1), Some(i16::MIN), Some(i16::MAX), None]), + Int16Array::from(vec![Some(1), Some(i16::MIN), Some(i16::MAX), None]), + as_int16_array + ); - #[test] - fn test_abs_i64_array() { - with_fail_on_error(|fail_on_error| { - let input = - Int64Array::from(vec![Some(-1), Some(i64::MIN), Some(i64::MAX), None]); - let args = ColumnarValue::Array(Arc::new(input)); - let expected = - Int64Array::from(vec![Some(1), Some(i64::MIN), Some(i64::MAX), None]); - let fail_on_error_arg = - ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + eval_array_legacy_mode!( + Int32Array::from(vec![Some(-1), Some(i32::MIN), Some(i32::MAX), None]), + Int32Array::from(vec![Some(1), Some(i32::MIN), Some(i32::MAX), None]), + as_int32_array + ); - match spark_abs(&[args, fail_on_error_arg]) { - Ok(ColumnarValue::Array(result)) => { - let actual = as_int64_array(&result)?; - assert_eq!(actual, &expected); - Ok(()) - } - Err(e) => { - if fail_on_error { - assert!( - e.to_string().contains("arithmetic overflow"), - "Error message did not match. 
Actual message: {e}" - ); - Ok(()) - } else { - panic!("Didn't expect error, but got: {e:?}") - } - } - _ => unreachable!(), - } - }); - } + eval_array_legacy_mode!( + Int64Array::from(vec![Some(-1), Some(i64::MIN), Some(i64::MAX), None]), + Int64Array::from(vec![Some(1), Some(i64::MIN), Some(i64::MAX), None]), + as_int64_array + ); - #[test] - fn test_abs_f32_array() { - with_fail_on_error(|fail_on_error| { - let input = Float32Array::from(vec![ + eval_array_legacy_mode!( + Float32Array::from(vec![ Some(-1f32), Some(f32::MIN), Some(f32::MAX), None, Some(f32::NAN), - ]); - let args = ColumnarValue::Array(Arc::new(input)); - let expected = Float32Array::from(vec![ + Some(f32::INFINITY), + Some(f32::NEG_INFINITY), + Some(0.0), + Some(-0.0), + ]), + Float32Array::from(vec![ Some(1f32), Some(f32::MAX), Some(f32::MAX), None, Some(f32::NAN), - ]); - let fail_on_error_arg = - ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); - - match spark_abs(&[args, fail_on_error_arg]) { - Ok(ColumnarValue::Array(result)) => { - let actual = as_float32_array(&result)?; - assert_eq!(actual, &expected); - Ok(()) - } - Err(e) => { - if fail_on_error { - assert!( - e.to_string().contains("arithmetic overflow"), - "Error message did not match. 
Actual message: {e}" - ); - Ok(()) - } else { - panic!("Didn't expect error, but got: {e:?}") - } - } - _ => unreachable!(), - } - }); - } + Some(f32::INFINITY), + Some(f32::INFINITY), + Some(0.0), + Some(0.0), + ]), + as_float32_array + ); - #[test] - fn test_abs_f64_array() { - with_fail_on_error(|fail_on_error| { - let input = Float64Array::from(vec![ + eval_array_legacy_mode!( + Float64Array::from(vec![ Some(-1f64), Some(f64::MIN), Some(f64::MAX), None, Some(f64::NAN), - ]); - let args = ColumnarValue::Array(Arc::new(input)); - let expected = Float64Array::from(vec![ + Some(f64::INFINITY), + Some(f64::NEG_INFINITY), + Some(0.0), + Some(-0.0), + ]), + Float64Array::from(vec![ Some(1f64), Some(f64::MAX), Some(f64::MAX), None, Some(f64::NAN), - ]); - let fail_on_error_arg = - ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + Some(f64::INFINITY), + Some(f64::INFINITY), + Some(0.0), + Some(0.0), + ]), + as_float64_array + ); - match spark_abs(&[args, fail_on_error_arg]) { - Ok(ColumnarValue::Array(result)) => { - let actual = as_float64_array(&result)?; - assert_eq!(actual, &expected); - Ok(()) - } - Err(e) => { - if fail_on_error { - assert!( - e.to_string().contains("arithmetic overflow"), - "Error message did not match. 
Actual message: {e}" - ); - Ok(()) - } else { - panic!("Didn't expect error, but got: {e:?}") - } - } - _ => unreachable!(), - } - }); - } + eval_array_legacy_mode!( + Decimal128Array::from(vec![Some(i128::MIN), None]) + .with_precision_and_scale(38, 37) + .unwrap(), + Decimal128Array::from(vec![Some(i128::MIN), None]) + .with_precision_and_scale(38, 37) + .unwrap(), + as_decimal128_array + ); - #[test] - fn test_abs_decimal128_array() { - with_fail_on_error(|fail_on_error| { - let input = Decimal128Array::from(vec![Some(i128::MIN), None]) - .with_precision_and_scale(38, 37)?; - let args = ColumnarValue::Array(Arc::new(input)); - let expected = Decimal128Array::from(vec![Some(i128::MIN), None]) - .with_precision_and_scale(38, 37)?; - let fail_on_error_arg = - ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + eval_array_legacy_mode!( + Decimal256Array::from(vec![Some(i256::MIN), None]) + .with_precision_and_scale(5, 2) + .unwrap(), + Decimal256Array::from(vec![Some(i256::MIN), None]) + .with_precision_and_scale(5, 2) + .unwrap(), + as_decimal256_array + ); - match spark_abs(&[args, fail_on_error_arg]) { - Ok(ColumnarValue::Array(result)) => { - let actual = as_decimal128_array(&result)?; - assert_eq!(actual, &expected); - Ok(()) - } - Err(e) => { - if fail_on_error { - assert!( - e.to_string().contains("arithmetic overflow"), - "Error message did not match. 
Actual message: {e}" - ); - Ok(()) - } else { - panic!("Didn't expect error, but got: {e:?}") - } - } - _ => unreachable!(), - } - }); + eval_array_legacy_mode!( + IntervalYearMonthArray::from(vec![i32::MIN, -1]), + IntervalYearMonthArray::from(vec![i32::MIN, 1]), + as_interval_ym_array + ); + + eval_array_legacy_mode!( + IntervalDayTimeArray::from(vec![IntervalDayTime::new(i32::MIN, i32::MIN,)]), + IntervalDayTimeArray::from(vec![IntervalDayTime::new(i32::MIN, i32::MIN,)]), + as_interval_dt_array + ); } - #[test] - fn test_abs_decimal256_array() { - with_fail_on_error(|fail_on_error| { - let input = Decimal256Array::from(vec![Some(i256::MIN), None]) - .with_precision_and_scale(5, 2)?; + macro_rules! eval_array_ansi_mode { + ($INPUT:expr) => {{ + let input = $INPUT; let args = ColumnarValue::Array(Arc::new(input)); - let expected = Decimal256Array::from(vec![Some(i256::MIN), None]) - .with_precision_and_scale(5, 2)?; - let fail_on_error_arg = - ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); - - match spark_abs(&[args, fail_on_error_arg]) { - Ok(ColumnarValue::Array(result)) => { - let actual = as_decimal256_array(&result)?; - assert_eq!(actual, &expected); - Ok(()) - } + let fail_on_error = ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))); + match spark_abs(&[args, fail_on_error]) { Err(e) => { - if fail_on_error { - assert!( - e.to_string().contains("arithmetic overflow"), - "Error message did not match. Actual message: {e}" - ); - Ok(()) - } else { - panic!("Didn't expect error, but got: {e:?}") - } + assert!( + e.to_string().contains("arithmetic overflow"), + "Error message did not match. 
Actual message: {e}" + ); } _ => unreachable!(), } - }); - } - - #[test] - fn test_abs_interval_year_month_array() { - with_fail_on_error(|fail_on_error| { - let input = IntervalYearMonthArray::from(vec![i32::MIN, -1]); + }}; + ($INPUT:expr, $OUTPUT:expr, $FUNC:ident) => {{ + let input = $INPUT; let args = ColumnarValue::Array(Arc::new(input)); - let expected = IntervalYearMonthArray::from(vec![i32::MIN, 1]); - let fail_on_error_arg = - ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); - - match spark_abs(&[args, fail_on_error_arg]) { + let fail_on_error = ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))); + let expected = $OUTPUT; + match spark_abs(&[args, fail_on_error]) { Ok(ColumnarValue::Array(result)) => { - let actual = as_interval_ym_array(&result)?; + let actual = datafusion_common::cast::$FUNC(&result).unwrap(); assert_eq!(actual, &expected); - Ok(()) - } - Err(e) => { - if fail_on_error { - assert!( - e.to_string().contains("arithmetic overflow"), - "Error message did not match. 
Actual message: {e}" - ); - Ok(()) - } else { - panic!("Didn't expect error, but got: {e:?}") - } } _ => unreachable!(), } - }); + }}; } - #[test] - fn test_abs_interval_day_time_array() { - with_fail_on_error(|fail_on_error| { - let input = IntervalDayTimeArray::from(vec![IntervalDayTime::new( - i32::MIN, - i32::MIN, - )]); - let args = ColumnarValue::Array(Arc::new(input)); - let expected = IntervalDayTimeArray::from(vec![IntervalDayTime::new( - i32::MIN, - i32::MIN, - )]); - let fail_on_error_arg = - ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + fn test_abs_array_ansi_mode() { + eval_array_ansi_mode!( + UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]), + UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]), + as_uint64_array + ); - match spark_abs(&[args, fail_on_error_arg]) { - Ok(ColumnarValue::Array(result)) => { - let actual = as_interval_dt_array(&result)?; - assert_eq!(actual, &expected); - Ok(()) - } - Err(e) => { - if fail_on_error { - assert!( - e.to_string().contains("arithmetic overflow"), - "Error message did not match. 
Actual message: {e}" - ); - Ok(()) - } else { - panic!("Didn't expect error, but got: {e:?}") - } - } - _ => unreachable!(), - } - }); - } + eval_array_ansi_mode!(Int8Array::from(vec![ + Some(-1), + Some(i8::MIN), + Some(i8::MAX), + None + ])); + eval_array_ansi_mode!(Int16Array::from(vec![ + Some(-1), + Some(i16::MIN), + Some(i16::MAX), + None + ])); + eval_array_ansi_mode!(Int32Array::from(vec![ + Some(-1), + Some(i32::MIN), + Some(i32::MAX), + None + ])); + eval_array_ansi_mode!(Int64Array::from(vec![ + Some(-1), + Some(i64::MIN), + Some(i64::MAX), + None + ])); + eval_array_ansi_mode!( + Float32Array::from(vec![ + Some(-1f32), + Some(f32::MIN), + Some(f32::MAX), + None, + Some(f32::NAN), + Some(f32::INFINITY), + Some(f32::NEG_INFINITY), + Some(0.0), + Some(-0.0), + ]), + Float32Array::from(vec![ + Some(1f32), + Some(f32::MAX), + Some(f32::MAX), + None, + Some(f32::NAN), + Some(f32::INFINITY), + Some(f32::INFINITY), + Some(0.0), + Some(0.0), + ]), + as_float32_array + ); - #[test] - fn test_abs_u64_array() { - with_fail_on_error(|fail_on_error| { - let input = UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]); - let args = ColumnarValue::Array(Arc::new(input)); - let expected = UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]); - let fail_on_error_arg = - ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))); + eval_array_ansi_mode!( + Float64Array::from(vec![ + Some(-1f64), + Some(f64::MIN), + Some(f64::MAX), + None, + Some(f64::NAN), + Some(f64::INFINITY), + Some(f64::NEG_INFINITY), + Some(0.0), + Some(-0.0), + ]), + Float64Array::from(vec![ + Some(1f64), + Some(f64::MAX), + Some(f64::MAX), + None, + Some(f64::NAN), + Some(f64::INFINITY), + Some(f64::INFINITY), + Some(0.0), + Some(0.0), + ]), + as_float64_array + ); - match spark_abs(&[args, fail_on_error_arg]) { - Ok(ColumnarValue::Array(result)) => { - let actual = as_uint64_array(&result)?; - assert_eq!(actual, &expected); - Ok(()) - } - Err(e) => { - if fail_on_error { - 
assert!( - e.to_string().contains("arithmetic overflow"), - "Error message did not match. Actual message: {e}" - ); - Ok(()) - } else { - panic!("Didn't expect error, but got: {e:?}") - } - } - _ => unreachable!(), - } - }); + eval_array_ansi_mode!(Decimal128Array::from(vec![Some(i128::MIN), None]) + .with_precision_and_scale(38, 37) + .unwrap()); + eval_array_ansi_mode!(Decimal256Array::from(vec![Some(i256::MIN), None]) + .with_precision_and_scale(5, 2) + .unwrap()); + eval_array_ansi_mode!(IntervalYearMonthArray::from(vec![i32::MIN, -1])); + eval_array_ansi_mode!(IntervalDayTimeArray::from(vec![IntervalDayTime::new( + i32::MIN, + i32::MIN, + )])); } } From 2e1101579afbcbf75c92159e77ffc20eb9197b11 Mon Sep 17 00:00:00 2001 From: huanghsiang_cheng Date: Fri, 7 Nov 2025 17:38:35 -0800 Subject: [PATCH 12/25] Remove unused imports/methods --- datafusion/spark/src/function/math/abs.rs | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/datafusion/spark/src/function/math/abs.rs b/datafusion/spark/src/function/math/abs.rs index 1ea061aa8cf4..be3a7d5be093 100644 --- a/datafusion/spark/src/function/math/abs.rs +++ b/datafusion/spark/src/function/math/abs.rs @@ -515,17 +515,6 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result Result<()>>(test_fn: F) { - for fail_on_error in [true, false] { - let _ = test_fn(fail_on_error); - } - } #[test] fn test_abs_zero_arg() { From caa17e7ea36d0467e339f31b78c1927edfd939c1 Mon Sep 17 00:00:00 2001 From: huanghsiang_cheng Date: Fri, 7 Nov 2025 14:37:07 -0800 Subject: [PATCH 13/25] Test invalid arguments --- datafusion/spark/src/function/math/abs.rs | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/datafusion/spark/src/function/math/abs.rs b/datafusion/spark/src/function/math/abs.rs index be3a7d5be093..9218a197d584 100644 --- a/datafusion/spark/src/function/math/abs.rs +++ b/datafusion/spark/src/function/math/abs.rs @@ -79,7 +79,7 @@ impl ScalarUDFImpl for SparkAbs { } fn coerce_types(&self, 
arg_types: &[DataType]) -> Result> { - if arg_types.len() > 2 { + if arg_types.is_empty() || arg_types.len() > 2 { return Err(invalid_arg_count_exec_err("abs", (1, 2), arg_types.len())); } match &arg_types[0] { @@ -146,7 +146,7 @@ fn arithmetic_overflow_error(from_type: &str) -> DataFusionError { } pub fn spark_abs(args: &[ColumnarValue]) -> Result { - if args.len() > 2 { + if args.is_empty() || args.len() > 2 { return internal_err!("abs takes at most 2 arguments, but got: {}", args.len()); } @@ -517,8 +517,14 @@ mod tests { use super::*; #[test] - fn test_abs_zero_arg() { + fn test_abs_incorrect_arg() { + let arg = ColumnarValue::Scalar(ScalarValue::UInt8(Some(u8::MAX))); + // zero arg assert!(spark_abs(&[]).is_err()); + // more than 2 args + assert!(spark_abs(&[arg.clone(), arg.clone(), arg.clone()]).is_err()); + // incorrect 2nd arg type + assert!(spark_abs(&[arg.clone(), arg.clone()]).is_err()); } macro_rules! eval_legacy_mode { From 2494609c109b95e693460dc50ee0ed5c82ec084e Mon Sep 17 00:00:00 2001 From: huanghsiang_cheng Date: Sat, 8 Nov 2025 15:56:23 -0800 Subject: [PATCH 14/25] Reuse macros from DF's abs function --- datafusion/functions/src/math/abs.rs | 5 +- datafusion/spark/src/function/math/abs.rs | 205 +++++----------------- 2 files changed, 51 insertions(+), 159 deletions(-) diff --git a/datafusion/functions/src/math/abs.rs b/datafusion/functions/src/math/abs.rs index 8049ef85ac36..70155ccac787 100644 --- a/datafusion/functions/src/math/abs.rs +++ b/datafusion/functions/src/math/abs.rs @@ -39,6 +39,7 @@ use num_traits::sign::Signed; type MathArrayFunction = fn(&ArrayRef) -> Result; +#[macro_export] macro_rules! make_abs_function { ($ARRAY_TYPE:ident) => {{ |input: &ArrayRef| { @@ -49,6 +50,7 @@ macro_rules! make_abs_function { }}; } +#[macro_export] macro_rules! make_try_abs_function { ($ARRAY_TYPE:ident) => {{ |input: &ArrayRef| { @@ -56,7 +58,7 @@ macro_rules! 
make_try_abs_function { let res: $ARRAY_TYPE = array.try_unary(|x| { x.checked_abs().ok_or_else(|| { ArrowError::ComputeError(format!( - "{} overflow on abs({})", + "{} overflow on abs({:?})", stringify!($ARRAY_TYPE), x )) @@ -67,6 +69,7 @@ macro_rules! make_try_abs_function { }}; } +#[macro_export] macro_rules! make_decimal_abs_function { ($ARRAY_TYPE:ident) => {{ |input: &ArrayRef| { diff --git a/datafusion/spark/src/function/math/abs.rs b/datafusion/spark/src/function/math/abs.rs index 9218a197d584..2d7f6fa4acdd 100644 --- a/datafusion/spark/src/function/math/abs.rs +++ b/datafusion/spark/src/function/math/abs.rs @@ -21,10 +21,15 @@ use crate::function::error_utils::{ use arrow::array::*; use arrow::datatypes::DataType; use arrow::datatypes::*; +use arrow::error::ArrowError; use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; +use datafusion_functions::{ + downcast_named_arg, make_abs_function, make_decimal_abs_function, + make_try_abs_function, +}; use std::any::Any; use std::sync::Arc; @@ -113,36 +118,8 @@ impl ScalarUDFImpl for SparkAbs { } } -macro_rules! legacy_compute_op { - ($ARRAY:expr, $FUNC:ident, $TYPE:ident, $RESULT:ident) => {{ - let array = $ARRAY.as_any().downcast_ref::<$TYPE>().unwrap(); - let res: $RESULT = arrow::compute::kernels::arity::unary(array, |x| x.$FUNC()); - res - }}; -} - -macro_rules! 
ansi_compute_op { - ($ARRAY:expr, $FUNC:ident, $TYPE:ident, $RESULT:ident, $MIN:expr, $FROM_TYPE:expr) => {{ - let array = $ARRAY.as_any().downcast_ref::<$TYPE>().unwrap(); - match arrow::compute::kernels::arity::try_unary(array, |x| { - if x == $MIN { - Err(arrow::error::ArrowError::ArithmeticOverflow( - $FROM_TYPE.to_string(), - )) - } else { - Ok(x.$FUNC()) - } - }) { - Ok(res) => Ok(ColumnarValue::Array(Arc::>::new( - res, - ))), - Err(_) => Err(arithmetic_overflow_error($FROM_TYPE)), - } - }}; -} - fn arithmetic_overflow_error(from_type: &str) -> DataFusionError { - DataFusionError::Execution(format!("arithmetic overflow from {from_type}")) + DataFusionError::Execution(format!("overflow on abs {from_type}")) } pub fn spark_abs(args: &[ColumnarValue]) -> Result { @@ -175,171 +152,83 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result Ok(args[0].clone()), DataType::Int8 => { if !fail_on_error { - let result = - legacy_compute_op!(array, wrapping_abs, Int8Array, Int8Array); - Ok(ColumnarValue::Array(Arc::new(result))) + let abs_fun = make_decimal_abs_function!(Int8Array); + abs_fun(array).map(ColumnarValue::Array) } else { - ansi_compute_op!(array, abs, Int8Array, Int8Type, i8::MIN, "Int8") + let abs_fun = make_try_abs_function!(Int8Array); + abs_fun(array).map(ColumnarValue::Array) } } DataType::Int16 => { if !fail_on_error { - let result = - legacy_compute_op!(array, wrapping_abs, Int16Array, Int16Array); - Ok(ColumnarValue::Array(Arc::new(result))) + let abs_fun = make_decimal_abs_function!(Int16Array); + abs_fun(array).map(ColumnarValue::Array) } else { - ansi_compute_op!(array, abs, Int16Array, Int16Type, i16::MIN, "Int16") + let abs_fun = make_try_abs_function!(Int16Array); + abs_fun(array).map(ColumnarValue::Array) } } DataType::Int32 => { if !fail_on_error { - let result = - legacy_compute_op!(array, wrapping_abs, Int32Array, Int32Array); - Ok(ColumnarValue::Array(Arc::new(result))) + let abs_fun = make_decimal_abs_function!(Int32Array); + 
abs_fun(array).map(ColumnarValue::Array) } else { - ansi_compute_op!(array, abs, Int32Array, Int32Type, i32::MIN, "Int32") + let abs_fun = make_try_abs_function!(Int32Array); + abs_fun(array).map(ColumnarValue::Array) } } DataType::Int64 => { if !fail_on_error { - let result = - legacy_compute_op!(array, wrapping_abs, Int64Array, Int64Array); - Ok(ColumnarValue::Array(Arc::new(result))) + let abs_fun = make_decimal_abs_function!(Int64Array); + abs_fun(array).map(ColumnarValue::Array) } else { - ansi_compute_op!(array, abs, Int64Array, Int64Type, i64::MIN, "Int64") + let abs_fun = make_try_abs_function!(Int64Array); + abs_fun(array).map(ColumnarValue::Array) } } DataType::Float32 => { - let result = legacy_compute_op!(array, abs, Float32Array, Float32Array); - Ok(ColumnarValue::Array(Arc::new(result))) + let abs_fun = make_abs_function!(Float32Array); + abs_fun(array).map(ColumnarValue::Array) } DataType::Float64 => { - let result = legacy_compute_op!(array, abs, Float64Array, Float64Array); - Ok(ColumnarValue::Array(Arc::new(result))) + let abs_fun = make_abs_function!(Float64Array); + abs_fun(array).map(ColumnarValue::Array) } - DataType::Decimal128(precision, scale) => { + DataType::Decimal128(_, _) => { if !fail_on_error { - let result = legacy_compute_op!( - array, - wrapping_abs, - Decimal128Array, - Decimal128Array - ); - let result = - result.with_data_type(DataType::Decimal128(*precision, *scale)); - Ok(ColumnarValue::Array(Arc::new(result))) + let abs_fun = make_decimal_abs_function!(Decimal128Array); + abs_fun(array).map(ColumnarValue::Array) } else { - // Need to pass precision and scale from input, so not using ansi_compute_op - let input = array.as_any().downcast_ref::(); - match input { - Some(i) => { - match arrow::compute::kernels::arity::try_unary(i, |x| { - if x == i128::MIN { - Err(arrow::error::ArrowError::ArithmeticOverflow( - "Decimal128".to_string(), - )) - } else { - Ok(x.abs()) - } - }) { - Ok(res) => Ok(ColumnarValue::Array(Arc::< - 
PrimitiveArray, - >::new( - res.with_data_type(DataType::Decimal128( - *precision, *scale, - )), - ))), - Err(_) => Err(arithmetic_overflow_error("Decimal128")), - } - } - _ => Err(DataFusionError::Internal( - "Invalid data type".to_string(), - )), - } + let abs_fun = make_try_abs_function!(Decimal128Array); + abs_fun(array).map(ColumnarValue::Array) } } - DataType::Decimal256(precision, scale) => { + DataType::Decimal256(_, _) => { if !fail_on_error { - let result = legacy_compute_op!( - array, - wrapping_abs, - Decimal256Array, - Decimal256Array - ); - let result = - result.with_data_type(DataType::Decimal256(*precision, *scale)); - Ok(ColumnarValue::Array(Arc::new(result))) + let abs_fun = make_decimal_abs_function!(Decimal256Array); + abs_fun(array).map(ColumnarValue::Array) } else { - // Need to pass precision and scale from input, so not using ansi_compute_op - let input = array.as_any().downcast_ref::(); - match input { - Some(i) => { - match arrow::compute::kernels::arity::try_unary(i, |x| { - if x == i256::MIN { - Err(arrow::error::ArrowError::ArithmeticOverflow( - "Decimal256".to_string(), - )) - } else { - Ok(x.wrapping_abs()) // i256 doesn't define abs() method - } - }) { - Ok(res) => Ok(ColumnarValue::Array(Arc::< - PrimitiveArray, - >::new( - res.with_data_type(DataType::Decimal256( - *precision, *scale, - )), - ))), - Err(_) => Err(arithmetic_overflow_error("Decimal256")), - } - } - _ => Err(DataFusionError::Internal( - "Invalid data type".to_string(), - )), - } + let abs_fun = make_try_abs_function!(Decimal256Array); + abs_fun(array).map(ColumnarValue::Array) } } DataType::Interval(unit) => match unit { IntervalUnit::YearMonth => { if !fail_on_error { - let result = legacy_compute_op!( - array, - wrapping_abs, - IntervalYearMonthArray, - IntervalYearMonthArray - ); - let result = result.with_data_type(DataType::Interval(*unit)); - Ok(ColumnarValue::Array(Arc::new(result))) + let abs_fun = make_decimal_abs_function!(IntervalYearMonthArray); + 
abs_fun(array).map(ColumnarValue::Array) } else { - ansi_compute_op!( - array, - abs, - IntervalYearMonthArray, - IntervalYearMonthType, - i32::MIN, - "IntervalYearMonth" - ) + let abs_fun = make_try_abs_function!(IntervalYearMonthArray); + abs_fun(array).map(ColumnarValue::Array) } } IntervalUnit::DayTime => { if !fail_on_error { - let result = legacy_compute_op!( - array, - wrapping_abs, - IntervalDayTimeArray, - IntervalDayTimeArray - ); - let result = result.with_data_type(DataType::Interval(*unit)); - Ok(ColumnarValue::Array(Arc::new(result))) + let abs_fun = make_decimal_abs_function!(IntervalDayTimeArray); + abs_fun(array).map(ColumnarValue::Array) } else { - ansi_compute_op!( - array, - wrapping_abs, - IntervalDayTimeArray, - IntervalDayTimeType, - IntervalDayTime::MIN, - "IntervalDayTime" - ) + let abs_fun = make_try_abs_function!(IntervalDayTimeArray); + abs_fun(array).map(ColumnarValue::Array) } } IntervalUnit::MonthDayNano => internal_err!( @@ -630,7 +519,7 @@ mod tests { match spark_abs(&[args, fail_on_error]) { Err(e) => { assert!( - e.to_string().contains("arithmetic overflow"), + e.to_string().contains("overflow on abs"), "Error message did not match. Actual message: {e}" ); } @@ -654,7 +543,7 @@ mod tests { match spark_abs(&[args, fail_on_error]) { Err(e) => { assert!( - e.to_string().contains("arithmetic overflow"), + e.to_string().contains("overflow on abs"), "Error message did not match. Actual message: {e}" ); } @@ -858,7 +747,7 @@ mod tests { match spark_abs(&[args, fail_on_error]) { Err(e) => { assert!( - e.to_string().contains("arithmetic overflow"), + e.to_string().contains("overflow on abs"), "Error message did not match. 
Actual message: {e}" ); } From d11869f0dcfafd641f7e6e0e53c644de21c33427 Mon Sep 17 00:00:00 2001 From: huanghsiang_cheng Date: Sat, 8 Nov 2025 22:15:08 -0800 Subject: [PATCH 15/25] Rename wrapping_abs macro --- datafusion/functions/src/math/abs.rs | 10 +++++----- datafusion/spark/src/function/math/abs.rs | 20 ++++++++++---------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/datafusion/functions/src/math/abs.rs b/datafusion/functions/src/math/abs.rs index 70155ccac787..e16fc804bc62 100644 --- a/datafusion/functions/src/math/abs.rs +++ b/datafusion/functions/src/math/abs.rs @@ -70,7 +70,7 @@ macro_rules! make_try_abs_function { } #[macro_export] -macro_rules! make_decimal_abs_function { +macro_rules! make_wrapping_abs_function { ($ARRAY_TYPE:ident) => {{ |input: &ArrayRef| { let array = downcast_named_arg!(&input, "abs arg", $ARRAY_TYPE); @@ -104,10 +104,10 @@ fn create_abs_function(input_data_type: &DataType) -> Result | DataType::UInt64 => Ok(|input: &ArrayRef| Ok(Arc::clone(input))), // Decimal types - DataType::Decimal32(_, _) => Ok(make_decimal_abs_function!(Decimal32Array)), - DataType::Decimal64(_, _) => Ok(make_decimal_abs_function!(Decimal64Array)), - DataType::Decimal128(_, _) => Ok(make_decimal_abs_function!(Decimal128Array)), - DataType::Decimal256(_, _) => Ok(make_decimal_abs_function!(Decimal256Array)), + DataType::Decimal32(_, _) => Ok(make_wrapping_abs_function!(Decimal32Array)), + DataType::Decimal64(_, _) => Ok(make_wrapping_abs_function!(Decimal64Array)), + DataType::Decimal128(_, _) => Ok(make_wrapping_abs_function!(Decimal128Array)), + DataType::Decimal256(_, _) => Ok(make_wrapping_abs_function!(Decimal256Array)), other => not_impl_err!("Unsupported data type {other:?} for function abs"), } diff --git a/datafusion/spark/src/function/math/abs.rs b/datafusion/spark/src/function/math/abs.rs index 2d7f6fa4acdd..879372231876 100644 --- a/datafusion/spark/src/function/math/abs.rs +++ b/datafusion/spark/src/function/math/abs.rs @@ 
-27,8 +27,8 @@ use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; use datafusion_functions::{ - downcast_named_arg, make_abs_function, make_decimal_abs_function, - make_try_abs_function, + downcast_named_arg, make_abs_function, make_try_abs_function, + make_wrapping_abs_function, }; use std::any::Any; use std::sync::Arc; @@ -152,7 +152,7 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result Ok(args[0].clone()), DataType::Int8 => { if !fail_on_error { - let abs_fun = make_decimal_abs_function!(Int8Array); + let abs_fun = make_wrapping_abs_function!(Int8Array); abs_fun(array).map(ColumnarValue::Array) } else { let abs_fun = make_try_abs_function!(Int8Array); @@ -161,7 +161,7 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result { if !fail_on_error { - let abs_fun = make_decimal_abs_function!(Int16Array); + let abs_fun = make_wrapping_abs_function!(Int16Array); abs_fun(array).map(ColumnarValue::Array) } else { let abs_fun = make_try_abs_function!(Int16Array); @@ -170,7 +170,7 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result { if !fail_on_error { - let abs_fun = make_decimal_abs_function!(Int32Array); + let abs_fun = make_wrapping_abs_function!(Int32Array); abs_fun(array).map(ColumnarValue::Array) } else { let abs_fun = make_try_abs_function!(Int32Array); @@ -179,7 +179,7 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result { if !fail_on_error { - let abs_fun = make_decimal_abs_function!(Int64Array); + let abs_fun = make_wrapping_abs_function!(Int64Array); abs_fun(array).map(ColumnarValue::Array) } else { let abs_fun = make_try_abs_function!(Int64Array); @@ -196,7 +196,7 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result { if !fail_on_error { - let abs_fun = make_decimal_abs_function!(Decimal128Array); + let abs_fun = make_wrapping_abs_function!(Decimal128Array); abs_fun(array).map(ColumnarValue::Array) } else { let abs_fun = make_try_abs_function!(Decimal128Array); @@ -205,7 +205,7 @@ pub fn spark_abs(args: 
&[ColumnarValue]) -> Result { if !fail_on_error { - let abs_fun = make_decimal_abs_function!(Decimal256Array); + let abs_fun = make_wrapping_abs_function!(Decimal256Array); abs_fun(array).map(ColumnarValue::Array) } else { let abs_fun = make_try_abs_function!(Decimal256Array); @@ -215,7 +215,7 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result match unit { IntervalUnit::YearMonth => { if !fail_on_error { - let abs_fun = make_decimal_abs_function!(IntervalYearMonthArray); + let abs_fun = make_wrapping_abs_function!(IntervalYearMonthArray); abs_fun(array).map(ColumnarValue::Array) } else { let abs_fun = make_try_abs_function!(IntervalYearMonthArray); @@ -224,7 +224,7 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result { if !fail_on_error { - let abs_fun = make_decimal_abs_function!(IntervalDayTimeArray); + let abs_fun = make_wrapping_abs_function!(IntervalDayTimeArray); abs_fun(array).map(ColumnarValue::Array) } else { let abs_fun = make_try_abs_function!(IntervalDayTimeArray); From 94065b2899b6c3905f7949b2b881df0da3fa0bd7 Mon Sep 17 00:00:00 2001 From: huanghsiang_cheng Date: Sat, 8 Nov 2025 22:16:16 -0800 Subject: [PATCH 16/25] Fix comment --- datafusion/spark/src/function/math/abs.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/spark/src/function/math/abs.rs b/datafusion/spark/src/function/math/abs.rs index 879372231876..99c8b9d33129 100644 --- a/datafusion/spark/src/function/math/abs.rs +++ b/datafusion/spark/src/function/math/abs.rs @@ -40,7 +40,7 @@ use std::sync::Arc; /// Returns NULL if input is NULL, returns NaN if input is NaN. /// /// Differences with DataFusion abs: -/// - Spark's ANSI-compliant dialect, when off (i.e. `spark.sql.ansi.enabled=false`), taking absolute values on minimal values of signed integers returns the value as is. DataFusion's abs throws "DataFusion error: Arrow error: Compute error" on arithmetic overflow +/// - Spark's ANSI-compliant dialect, when off (i.e. 
`spark.sql.ansi.enabled=false`), taking absolute value on the minimal value of a signed integer returns the value as is. DataFusion's abs throws "DataFusion error: Arrow error: Compute error" on arithmetic overflow /// - Spark's abs also supports ANSI interval types: YearMonthIntervalType and DayTimeIntervalType. DataFusion's abs doesn't. /// #[derive(Debug, PartialEq, Eq, Hash)] @@ -109,7 +109,7 @@ impl ScalarUDFImpl for SparkAbs { } else { Err(unsupported_data_type_exec_err( "abs", - "Numeric Type or Interval Type", + "Numeric Type or ANSI Interval Type", &arg_types[0], )) } From 1839ab5319bc2055e997520ee3965aa0a74070dd Mon Sep 17 00:00:00 2001 From: huanghsiang_cheng Date: Sat, 8 Nov 2025 22:51:13 -0800 Subject: [PATCH 17/25] Refactor scalar logic to macro --- datafusion/spark/src/function/math/abs.rs | 177 +++++++--------------- 1 file changed, 51 insertions(+), 126 deletions(-) diff --git a/datafusion/spark/src/function/math/abs.rs b/datafusion/spark/src/function/math/abs.rs index 99c8b9d33129..d75a02cbeb91 100644 --- a/datafusion/spark/src/function/math/abs.rs +++ b/datafusion/spark/src/function/math/abs.rs @@ -118,8 +118,41 @@ impl ScalarUDFImpl for SparkAbs { } } -fn arithmetic_overflow_error(from_type: &str) -> DataFusionError { - DataFusionError::Execution(format!("overflow on abs {from_type}")) +macro_rules! scalar_compute_op { + ($FLAG:expr, $INPUT:ident, $SCALAR_TYPE:ident) => {{ + let result = if !$FLAG { + $INPUT.wrapping_abs() + } else { + $INPUT.checked_abs().ok_or_else(|| { + ArrowError::ComputeError(format!( + "{} overflow on abs({:?})", + stringify!($SCALAR_TYPE), + $INPUT + )) + })? 
+ }; + Ok(ColumnarValue::Scalar(ScalarValue::$SCALAR_TYPE(Some( + result, + )))) + }}; + ($FLAG:expr, $INPUT:ident, $PRECISION:expr, $SCALE:expr, $SCALAR_TYPE:ident) => {{ + let result = if !$FLAG { + $INPUT.wrapping_abs() + } else { + $INPUT.checked_abs().ok_or_else(|| { + ArrowError::ComputeError(format!( + "{} overflow on abs({:?})", + stringify!($SCALAR_TYPE), + $INPUT + )) + })? + }; + Ok(ColumnarValue::Scalar(ScalarValue::$SCALAR_TYPE( + Some(result), + $PRECISION, + $SCALE, + ))) + }}; } pub fn spark_abs(args: &[ColumnarValue]) -> Result { @@ -245,67 +278,19 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result Ok(args[0].clone()), ScalarValue::Int8(a) => match a { None => Ok(args[0].clone()), - Some(v) => match v.checked_abs() { - Some(abs_val) => { - Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(abs_val)))) - } - None => { - if !fail_on_error { - // return the original value - Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(*v)))) - } else { - Err(arithmetic_overflow_error("Int8")) - } - } - }, + Some(v) => scalar_compute_op!(fail_on_error, v, Int8), }, ScalarValue::Int16(a) => match a { None => Ok(args[0].clone()), - Some(v) => match v.checked_abs() { - Some(abs_val) => { - Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(abs_val)))) - } - None => { - if !fail_on_error { - // return the original value - Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(*v)))) - } else { - Err(arithmetic_overflow_error("Int16")) - } - } - }, + Some(v) => scalar_compute_op!(fail_on_error, v, Int16), }, ScalarValue::Int32(a) => match a { None => Ok(args[0].clone()), - Some(v) => match v.checked_abs() { - Some(abs_val) => { - Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(abs_val)))) - } - None => { - if !fail_on_error { - // return the original value - Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(*v)))) - } else { - Err(arithmetic_overflow_error("Int32")) - } - } - }, + Some(v) => scalar_compute_op!(fail_on_error, v, Int32), }, ScalarValue::Int64(a) => match a { None 
=> Ok(args[0].clone()), - Some(v) => match v.checked_abs() { - Some(abs_val) => { - Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(abs_val)))) - } - None => { - if !fail_on_error { - // return the original value - Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(*v)))) - } else { - Err(arithmetic_overflow_error("Int64")) - } - } - }, + Some(v) => scalar_compute_op!(fail_on_error, v, Int64), }, ScalarValue::Float32(a) => match a { None => Ok(args[0].clone()), @@ -315,85 +300,25 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result Ok(args[0].clone()), Some(v) => Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(v.abs())))), }, - ScalarValue::Decimal128(a, precision, scale) => { - match a { - None => Ok(args[0].clone()), - Some(v) => match v.checked_abs() { - Some(abs_val) => Ok(ColumnarValue::Scalar( - ScalarValue::Decimal128(Some(abs_val), *precision, *scale), - )), - None => { - if !fail_on_error { - // return the original value - Ok(ColumnarValue::Scalar(ScalarValue::Decimal128( - Some(*v), - *precision, - *scale, - ))) - } else { - Err(arithmetic_overflow_error("Decimal128")) - } - } - }, + ScalarValue::Decimal128(a, precision, scale) => match a { + None => Ok(args[0].clone()), + Some(v) => { + scalar_compute_op!(fail_on_error, v, *precision, *scale, Decimal128) } - } - ScalarValue::Decimal256(a, precision, scale) => { - match a { - None => Ok(args[0].clone()), - Some(v) => match v.checked_abs() { - Some(abs_val) => Ok(ColumnarValue::Scalar( - ScalarValue::Decimal256(Some(abs_val), *precision, *scale), - )), - None => { - if !fail_on_error { - // return the original value - Ok(ColumnarValue::Scalar(ScalarValue::Decimal256( - Some(*v), - *precision, - *scale, - ))) - } else { - Err(arithmetic_overflow_error("Decimal256")) - } - } - }, + }, + ScalarValue::Decimal256(a, precision, scale) => match a { + None => Ok(args[0].clone()), + Some(v) => { + scalar_compute_op!(fail_on_error, v, *precision, *scale, Decimal256) } - } + }, ScalarValue::IntervalYearMonth(a) => 
match a { None => Ok(args[0].clone()), - Some(v) => match v.checked_abs() { - Some(abs_val) => Ok(ColumnarValue::Scalar( - ScalarValue::IntervalYearMonth(Some(abs_val)), - )), - None => { - if !fail_on_error { - // return the original value - Ok(ColumnarValue::Scalar(ScalarValue::IntervalYearMonth( - Some(*v), - ))) - } else { - Err(arithmetic_overflow_error("IntervalYearMonth")) - } - } - }, + Some(v) => scalar_compute_op!(fail_on_error, v, IntervalYearMonth), }, ScalarValue::IntervalDayTime(a) => match a { None => Ok(args[0].clone()), - Some(v) => match v.checked_abs() { - Some(abs_val) => Ok(ColumnarValue::Scalar( - ScalarValue::IntervalDayTime(Some(abs_val)), - )), - None => { - if !fail_on_error { - // return the original value - Ok(ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - *v, - )))) - } else { - Err(arithmetic_overflow_error("IntervalYearMonth")) - } - } - }, + Some(v) => scalar_compute_op!(fail_on_error, v, IntervalDayTime), }, dt => internal_err!("Not supported datatype for Spark ABS: {dt}"), From 5cb0215773ba08cbfd77beb15dda3151ad8494e4 Mon Sep 17 00:00:00 2001 From: huanghsiang_cheng Date: Tue, 11 Nov 2025 15:12:53 -0800 Subject: [PATCH 18/25] Remove ANSI support logic --- datafusion/functions/src/math/abs.rs | 1 - datafusion/spark/src/function/math/abs.rs | 383 ++-------------------- 2 files changed, 34 insertions(+), 350 deletions(-) diff --git a/datafusion/functions/src/math/abs.rs b/datafusion/functions/src/math/abs.rs index e16fc804bc62..179cb1391775 100644 --- a/datafusion/functions/src/math/abs.rs +++ b/datafusion/functions/src/math/abs.rs @@ -50,7 +50,6 @@ macro_rules! make_abs_function { }}; } -#[macro_export] macro_rules! 
make_try_abs_function { ($ARRAY_TYPE:ident) => {{ |input: &ArrayRef| { diff --git a/datafusion/spark/src/function/math/abs.rs b/datafusion/spark/src/function/math/abs.rs index d75a02cbeb91..e622f232c8a4 100644 --- a/datafusion/spark/src/function/math/abs.rs +++ b/datafusion/spark/src/function/math/abs.rs @@ -20,14 +20,12 @@ use crate::function::error_utils::{ }; use arrow::array::*; use arrow::datatypes::DataType; -use arrow::datatypes::*; -use arrow::error::ArrowError; use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; use datafusion_functions::{ - downcast_named_arg, make_abs_function, make_try_abs_function, + downcast_named_arg, make_abs_function, make_wrapping_abs_function, }; use std::any::Any; @@ -39,7 +37,7 @@ use std::sync::Arc; /// Returns the absolute value of input /// Returns NULL if input is NULL, returns NaN if input is NaN. /// -/// Differences with DataFusion abs: +/// TODOs: /// - Spark's ANSI-compliant dialect, when off (i.e. `spark.sql.ansi.enabled=false`), taking absolute value on the minimal value of a signed integer returns the value as is. DataFusion's abs throws "DataFusion error: Arrow error: Compute error" on arithmetic overflow /// - Spark's abs also supports ANSI interval types: YearMonthIntervalType and DayTimeIntervalType. DataFusion's abs doesn't. /// @@ -84,8 +82,8 @@ impl ScalarUDFImpl for SparkAbs { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.is_empty() || arg_types.len() > 2 { - return Err(invalid_arg_count_exec_err("abs", (1, 2), arg_types.len())); + if arg_types.len() != 1 { + return Err(invalid_arg_count_exec_err("abs", (1, 1), arg_types.len())); } match &arg_types[0] { DataType::Null @@ -119,34 +117,14 @@ impl ScalarUDFImpl for SparkAbs { } macro_rules! 
scalar_compute_op { - ($FLAG:expr, $INPUT:ident, $SCALAR_TYPE:ident) => {{ - let result = if !$FLAG { - $INPUT.wrapping_abs() - } else { - $INPUT.checked_abs().ok_or_else(|| { - ArrowError::ComputeError(format!( - "{} overflow on abs({:?})", - stringify!($SCALAR_TYPE), - $INPUT - )) - })? - }; + ($INPUT:ident, $SCALAR_TYPE:ident) => {{ + let result = $INPUT.wrapping_abs(); Ok(ColumnarValue::Scalar(ScalarValue::$SCALAR_TYPE(Some( result, )))) }}; - ($FLAG:expr, $INPUT:ident, $PRECISION:expr, $SCALE:expr, $SCALAR_TYPE:ident) => {{ - let result = if !$FLAG { - $INPUT.wrapping_abs() - } else { - $INPUT.checked_abs().ok_or_else(|| { - ArrowError::ComputeError(format!( - "{} overflow on abs({:?})", - stringify!($SCALAR_TYPE), - $INPUT - )) - })? - }; + ($INPUT:ident, $PRECISION:expr, $SCALE:expr, $SCALAR_TYPE:ident) => {{ + let result = $INPUT.wrapping_abs(); Ok(ColumnarValue::Scalar(ScalarValue::$SCALAR_TYPE( Some(result), $PRECISION, @@ -156,26 +134,10 @@ macro_rules! scalar_compute_op { } pub fn spark_abs(args: &[ColumnarValue]) -> Result { - if args.is_empty() || args.len() > 2 { - return internal_err!("abs takes at most 2 arguments, but got: {}", args.len()); + if args.len() != 1 { + return internal_err!("abs takes exactly 1 argument, but got: {}", args.len()); } - let fail_on_error = if args.len() == 2 { - match &args[1] { - ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))) => { - *fail_on_error - } - _ => { - return internal_err!( - "The second argument must be boolean scalar, but got: {:?}", - args[1] - ); - } - } - } else { - false - }; - match &args[0] { ColumnarValue::Array(array) => match array.data_type() { DataType::Null @@ -184,40 +146,20 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result Ok(args[0].clone()), DataType::Int8 => { - if !fail_on_error { - let abs_fun = make_wrapping_abs_function!(Int8Array); - abs_fun(array).map(ColumnarValue::Array) - } else { - let abs_fun = make_try_abs_function!(Int8Array); - 
abs_fun(array).map(ColumnarValue::Array) - } + let abs_fun = make_wrapping_abs_function!(Int8Array); + abs_fun(array).map(ColumnarValue::Array) } DataType::Int16 => { - if !fail_on_error { - let abs_fun = make_wrapping_abs_function!(Int16Array); - abs_fun(array).map(ColumnarValue::Array) - } else { - let abs_fun = make_try_abs_function!(Int16Array); - abs_fun(array).map(ColumnarValue::Array) - } + let abs_fun = make_wrapping_abs_function!(Int16Array); + abs_fun(array).map(ColumnarValue::Array) } DataType::Int32 => { - if !fail_on_error { let abs_fun = make_wrapping_abs_function!(Int32Array); abs_fun(array).map(ColumnarValue::Array) - } else { - let abs_fun = make_try_abs_function!(Int32Array); - abs_fun(array).map(ColumnarValue::Array) - } } DataType::Int64 => { - if !fail_on_error { - let abs_fun = make_wrapping_abs_function!(Int64Array); - abs_fun(array).map(ColumnarValue::Array) - } else { - let abs_fun = make_try_abs_function!(Int64Array); - abs_fun(array).map(ColumnarValue::Array) - } + let abs_fun = make_wrapping_abs_function!(Int64Array); + abs_fun(array).map(ColumnarValue::Array) } DataType::Float32 => { let abs_fun = make_abs_function!(Float32Array); @@ -228,41 +170,21 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result { - if !fail_on_error { - let abs_fun = make_wrapping_abs_function!(Decimal128Array); - abs_fun(array).map(ColumnarValue::Array) - } else { - let abs_fun = make_try_abs_function!(Decimal128Array); - abs_fun(array).map(ColumnarValue::Array) - } + let abs_fun = make_wrapping_abs_function!(Decimal128Array); + abs_fun(array).map(ColumnarValue::Array) } DataType::Decimal256(_, _) => { - if !fail_on_error { - let abs_fun = make_wrapping_abs_function!(Decimal256Array); - abs_fun(array).map(ColumnarValue::Array) - } else { - let abs_fun = make_try_abs_function!(Decimal256Array); - abs_fun(array).map(ColumnarValue::Array) - } + let abs_fun = make_wrapping_abs_function!(Decimal256Array); + abs_fun(array).map(ColumnarValue::Array) } 
DataType::Interval(unit) => match unit { IntervalUnit::YearMonth => { - if !fail_on_error { - let abs_fun = make_wrapping_abs_function!(IntervalYearMonthArray); - abs_fun(array).map(ColumnarValue::Array) - } else { - let abs_fun = make_try_abs_function!(IntervalYearMonthArray); - abs_fun(array).map(ColumnarValue::Array) - } + let abs_fun = make_wrapping_abs_function!(IntervalYearMonthArray); + abs_fun(array).map(ColumnarValue::Array) } IntervalUnit::DayTime => { - if !fail_on_error { - let abs_fun = make_wrapping_abs_function!(IntervalDayTimeArray); - abs_fun(array).map(ColumnarValue::Array) - } else { - let abs_fun = make_try_abs_function!(IntervalDayTimeArray); - abs_fun(array).map(ColumnarValue::Array) - } + let abs_fun = make_wrapping_abs_function!(IntervalDayTimeArray); + abs_fun(array).map(ColumnarValue::Array) } IntervalUnit::MonthDayNano => internal_err!( "MonthDayNano is not a supported Interval unit for Spark ABS" @@ -278,19 +200,19 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result Ok(args[0].clone()), ScalarValue::Int8(a) => match a { None => Ok(args[0].clone()), - Some(v) => scalar_compute_op!(fail_on_error, v, Int8), + Some(v) => scalar_compute_op!(v, Int8), }, ScalarValue::Int16(a) => match a { None => Ok(args[0].clone()), - Some(v) => scalar_compute_op!(fail_on_error, v, Int16), + Some(v) => scalar_compute_op!(v, Int16), }, ScalarValue::Int32(a) => match a { None => Ok(args[0].clone()), - Some(v) => scalar_compute_op!(fail_on_error, v, Int32), + Some(v) => scalar_compute_op!(v, Int32), }, ScalarValue::Int64(a) => match a { None => Ok(args[0].clone()), - Some(v) => scalar_compute_op!(fail_on_error, v, Int64), + Some(v) => scalar_compute_op!(v, Int64), }, ScalarValue::Float32(a) => match a { None => Ok(args[0].clone()), @@ -303,22 +225,22 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result match a { None => Ok(args[0].clone()), Some(v) => { - scalar_compute_op!(fail_on_error, v, *precision, *scale, Decimal128) + scalar_compute_op!(v, *precision, 
*scale, Decimal128) } }, ScalarValue::Decimal256(a, precision, scale) => match a { None => Ok(args[0].clone()), Some(v) => { - scalar_compute_op!(fail_on_error, v, *precision, *scale, Decimal256) + scalar_compute_op!(v, *precision, *scale, Decimal256) } }, ScalarValue::IntervalYearMonth(a) => match a { None => Ok(args[0].clone()), - Some(v) => scalar_compute_op!(fail_on_error, v, IntervalYearMonth), + Some(v) => scalar_compute_op!(v, IntervalYearMonth), }, ScalarValue::IntervalDayTime(a) => match a { None => Ok(args[0].clone()), - Some(v) => scalar_compute_op!(fail_on_error, v, IntervalDayTime), + Some(v) => scalar_compute_op!(v, IntervalDayTime), }, dt => internal_err!("Not supported datatype for Spark ABS: {dt}"), @@ -329,17 +251,7 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result {{ @@ -437,113 +349,12 @@ mod tests { eval_legacy_mode!(Float64, -0.0f64, 0.0f64); } - macro_rules! eval_ansi_mode { - ($TYPE:ident, $VAL:expr) => {{ - let args = ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL))); - let fail_on_error = ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))); - match spark_abs(&[args, fail_on_error]) { - Err(e) => { - assert!( - e.to_string().contains("overflow on abs"), - "Error message did not match. 
Actual message: {e}" - ); - } - _ => unreachable!(), - } - }}; - ($TYPE:ident, $VAL:expr, $RESULT:expr) => {{ - let args = ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL))); - let fail_on_error = ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))); - match spark_abs(&[args, fail_on_error]) { - Ok(ColumnarValue::Scalar(ScalarValue::$TYPE(Some(result)))) => { - assert_eq!(result, $RESULT); - } - _ => unreachable!(), - } - }}; - ($TYPE:ident, $VAL:expr, $PRECISION:expr, $SCALE:expr) => {{ - let args = - ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL), $PRECISION, $SCALE)); - let fail_on_error = ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))); - match spark_abs(&[args, fail_on_error]) { - Err(e) => { - assert!( - e.to_string().contains("overflow on abs"), - "Error message did not match. Actual message: {e}" - ); - } - _ => unreachable!(), - } - }}; - ($TYPE:ident, $VAL:expr, $PRECISION:expr, $SCALE:expr, $RESULT:expr) => {{ - let args = - ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL), $PRECISION, $SCALE)); - let fail_on_error = ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))); - match spark_abs(&[args, fail_on_error]) { - Ok(ColumnarValue::Scalar(ScalarValue::$TYPE( - Some(result), - precision, - scale, - ))) => { - assert_eq!(result, $RESULT); - assert_eq!(precision, $PRECISION); - assert_eq!(scale, $SCALE); - } - _ => unreachable!(), - } - }}; - } - - #[test] - fn test_abs_scalar_ansi_mode() { - eval_ansi_mode!(Int8, i8::MIN); - eval_ansi_mode!(Int16, i16::MIN); - eval_ansi_mode!(Int32, i32::MIN); - eval_ansi_mode!(Int64, i64::MIN); - eval_ansi_mode!(Decimal128, i128::MIN, 18, 10); - eval_ansi_mode!(Decimal256, i256::MIN, 10, 2); - eval_ansi_mode!(IntervalYearMonth, i32::MIN); - eval_ansi_mode!(IntervalDayTime, IntervalDayTime::MIN); - - eval_ansi_mode!(UInt8, u8::MIN, u8::MIN); - eval_ansi_mode!(UInt16, u16::MIN, u16::MIN); - eval_ansi_mode!(UInt32, u32::MIN, u32::MIN); - eval_ansi_mode!(UInt64, u64::MIN, u64::MIN); - 
eval_ansi_mode!(Float32, f32::MIN, f32::MAX); - eval_ansi_mode!(Float64, f64::MIN, f64::MAX); - - // NumericType, DayTimeIntervalType, and YearMonthIntervalType not MIN - eval_ansi_mode!(Int8, -1i8, 1i8); - eval_ansi_mode!(Int16, -1i16, 1i16); - eval_ansi_mode!(Int32, -1i32, 1i32); - eval_ansi_mode!(Int64, -1i64, 1i64); - eval_ansi_mode!(Decimal128, -1i128, 18, 10, 1i128); - eval_ansi_mode!(Decimal256, i256::from(-1i8), 10, 2, i256::from(1i8)); - eval_ansi_mode!(IntervalYearMonth, -1i32, 1i32); - eval_ansi_mode!( - IntervalDayTime, - IntervalDayTime::new(-1i32, -1i32), - IntervalDayTime::new(1i32, 1i32) - ); - - // Float32, Float64 - eval_ansi_mode!(Float32, f32::NEG_INFINITY, f32::INFINITY); - eval_ansi_mode!(Float32, f32::INFINITY, f32::INFINITY); - eval_ansi_mode!(Float32, 0.0f32, 0.0f32); - eval_ansi_mode!(Float32, -0.0f32, 0.0f32); - eval_ansi_mode!(Float64, f64::NEG_INFINITY, f64::INFINITY); - eval_ansi_mode!(Float64, f64::INFINITY, f64::INFINITY); - eval_ansi_mode!(Float64, 0.0f64, 0.0f64); - eval_ansi_mode!(Float64, -0.0f64, 0.0f64); - } - macro_rules! eval_array_legacy_mode { ($INPUT:expr, $OUTPUT:expr, $FUNC:ident) => {{ let input = $INPUT; let args = ColumnarValue::Array(Arc::new(input)); - let fail_on_error = ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))); let expected = $OUTPUT; - match spark_abs(&[args, fail_on_error]) { + match spark_abs(&[args]) { Ok(ColumnarValue::Array(result)) => { let actual = datafusion_common::cast::$FUNC(&result).unwrap(); assert_eq!(actual, &expected); @@ -663,130 +474,4 @@ mod tests { as_interval_dt_array ); } - - macro_rules! eval_array_ansi_mode { - ($INPUT:expr) => {{ - let input = $INPUT; - let args = ColumnarValue::Array(Arc::new(input)); - let fail_on_error = ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))); - match spark_abs(&[args, fail_on_error]) { - Err(e) => { - assert!( - e.to_string().contains("overflow on abs"), - "Error message did not match. 
Actual message: {e}" - ); - } - _ => unreachable!(), - } - }}; - ($INPUT:expr, $OUTPUT:expr, $FUNC:ident) => {{ - let input = $INPUT; - let args = ColumnarValue::Array(Arc::new(input)); - let fail_on_error = ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))); - let expected = $OUTPUT; - match spark_abs(&[args, fail_on_error]) { - Ok(ColumnarValue::Array(result)) => { - let actual = datafusion_common::cast::$FUNC(&result).unwrap(); - assert_eq!(actual, &expected); - } - _ => unreachable!(), - } - }}; - } - #[test] - fn test_abs_array_ansi_mode() { - eval_array_ansi_mode!( - UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]), - UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]), - as_uint64_array - ); - - eval_array_ansi_mode!(Int8Array::from(vec![ - Some(-1), - Some(i8::MIN), - Some(i8::MAX), - None - ])); - eval_array_ansi_mode!(Int16Array::from(vec![ - Some(-1), - Some(i16::MIN), - Some(i16::MAX), - None - ])); - eval_array_ansi_mode!(Int32Array::from(vec![ - Some(-1), - Some(i32::MIN), - Some(i32::MAX), - None - ])); - eval_array_ansi_mode!(Int64Array::from(vec![ - Some(-1), - Some(i64::MIN), - Some(i64::MAX), - None - ])); - eval_array_ansi_mode!( - Float32Array::from(vec![ - Some(-1f32), - Some(f32::MIN), - Some(f32::MAX), - None, - Some(f32::NAN), - Some(f32::INFINITY), - Some(f32::NEG_INFINITY), - Some(0.0), - Some(-0.0), - ]), - Float32Array::from(vec![ - Some(1f32), - Some(f32::MAX), - Some(f32::MAX), - None, - Some(f32::NAN), - Some(f32::INFINITY), - Some(f32::INFINITY), - Some(0.0), - Some(0.0), - ]), - as_float32_array - ); - - eval_array_ansi_mode!( - Float64Array::from(vec![ - Some(-1f64), - Some(f64::MIN), - Some(f64::MAX), - None, - Some(f64::NAN), - Some(f64::INFINITY), - Some(f64::NEG_INFINITY), - Some(0.0), - Some(-0.0), - ]), - Float64Array::from(vec![ - Some(1f64), - Some(f64::MAX), - Some(f64::MAX), - None, - Some(f64::NAN), - Some(f64::INFINITY), - Some(f64::INFINITY), - Some(0.0), - Some(0.0), - ]), - 
as_float64_array - ); - - eval_array_ansi_mode!(Decimal128Array::from(vec![Some(i128::MIN), None]) - .with_precision_and_scale(38, 37) - .unwrap()); - eval_array_ansi_mode!(Decimal256Array::from(vec![Some(i256::MIN), None]) - .with_precision_and_scale(5, 2) - .unwrap()); - eval_array_ansi_mode!(IntervalYearMonthArray::from(vec![i32::MIN, -1])); - eval_array_ansi_mode!(IntervalDayTimeArray::from(vec![IntervalDayTime::new( - i32::MIN, - i32::MIN, - )])); - } } From d330e536c78cac3190bdb1e165fa230e00f8233b Mon Sep 17 00:00:00 2001 From: huanghsiang_cheng Date: Tue, 11 Nov 2025 15:18:51 -0800 Subject: [PATCH 19/25] Remove Interval types --- datafusion/spark/src/function/math/abs.rs | 59 +++---------------- .../test_files/spark/math/abs.slt | 4 +- 2 files changed, 9 insertions(+), 54 deletions(-) diff --git a/datafusion/spark/src/function/math/abs.rs b/datafusion/spark/src/function/math/abs.rs index e622f232c8a4..818c6da0dc11 100644 --- a/datafusion/spark/src/function/math/abs.rs +++ b/datafusion/spark/src/function/math/abs.rs @@ -25,8 +25,7 @@ use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; use datafusion_functions::{ - downcast_named_arg, make_abs_function, - make_wrapping_abs_function, + downcast_named_arg, make_abs_function, make_wrapping_abs_function, }; use std::any::Any; use std::sync::Arc; @@ -98,16 +97,14 @@ impl ScalarUDFImpl for SparkAbs { | DataType::Float32 | DataType::Float64 | DataType::Decimal128(_, _) - | DataType::Decimal256(_, _) - | DataType::Interval(IntervalUnit::YearMonth) - | DataType::Interval(IntervalUnit::DayTime) => Ok(vec![arg_types[0].clone()]), + | DataType::Decimal256(_, _) => Ok(vec![arg_types[0].clone()]), other => { if other.is_numeric() { Ok(vec![DataType::Float64]) } else { Err(unsupported_data_type_exec_err( "abs", - "Numeric Type or ANSI Interval Type", + "Numeric Type", &arg_types[0], )) } @@ -154,8 +151,8 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result { - let abs_fun = 
make_wrapping_abs_function!(Int32Array); - abs_fun(array).map(ColumnarValue::Array) + let abs_fun = make_wrapping_abs_function!(Int32Array); + abs_fun(array).map(ColumnarValue::Array) } DataType::Int64 => { let abs_fun = make_wrapping_abs_function!(Int64Array); @@ -177,19 +174,6 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result match unit { - IntervalUnit::YearMonth => { - let abs_fun = make_wrapping_abs_function!(IntervalYearMonthArray); - abs_fun(array).map(ColumnarValue::Array) - } - IntervalUnit::DayTime => { - let abs_fun = make_wrapping_abs_function!(IntervalDayTimeArray); - abs_fun(array).map(ColumnarValue::Array) - } - IntervalUnit::MonthDayNano => internal_err!( - "MonthDayNano is not a supported Interval unit for Spark ABS" - ), - }, dt => internal_err!("Not supported datatype for Spark ABS: {dt}"), }, ColumnarValue::Scalar(sv) => match sv { @@ -234,15 +218,6 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result match a { - None => Ok(args[0].clone()), - Some(v) => scalar_compute_op!(v, IntervalYearMonth), - }, - ScalarValue::IntervalDayTime(a) => match a { - None => Ok(args[0].clone()), - Some(v) => scalar_compute_op!(v, IntervalDayTime), - }, - dt => internal_err!("Not supported datatype for Spark ABS: {dt}"), }, } @@ -308,7 +283,7 @@ mod tests { #[test] fn test_abs_scalar_legacy_mode() { - // NumericType, DayTimeIntervalType, and YearMonthIntervalType MIN + // NumericType MIN eval_legacy_mode!(UInt8, u8::MIN); eval_legacy_mode!(UInt16, u16::MIN); eval_legacy_mode!(UInt32, u32::MIN); @@ -321,22 +296,14 @@ mod tests { eval_legacy_mode!(Float64, f64::MIN, f64::MAX); eval_legacy_mode!(Decimal128, i128::MIN, 18, 10); eval_legacy_mode!(Decimal256, i256::MIN, 10, 2); - eval_legacy_mode!(IntervalYearMonth, i32::MIN); - eval_legacy_mode!(IntervalDayTime, IntervalDayTime::MIN); - // NumericType, DayTimeIntervalType, and YearMonthIntervalType not MIN + // NumericType not MIN eval_legacy_mode!(Int8, -1i8, 1i8); eval_legacy_mode!(Int16, -1i16, 1i16); 
eval_legacy_mode!(Int32, -1i32, 1i32); eval_legacy_mode!(Int64, -1i64, 1i64); eval_legacy_mode!(Decimal128, -1i128, 18, 10, 1i128); eval_legacy_mode!(Decimal256, i256::from(-1i8), 10, 2, i256::from(1i8)); - eval_legacy_mode!(IntervalYearMonth, -1i32, 1i32); - eval_legacy_mode!( - IntervalDayTime, - IntervalDayTime::new(-1i32, -1i32), - IntervalDayTime::new(1i32, 1i32) - ); // Float32, Float64 eval_legacy_mode!(Float32, f32::NEG_INFINITY, f32::INFINITY); @@ -461,17 +428,5 @@ mod tests { .unwrap(), as_decimal256_array ); - - eval_array_legacy_mode!( - IntervalYearMonthArray::from(vec![i32::MIN, -1]), - IntervalYearMonthArray::from(vec![i32::MIN, 1]), - as_interval_ym_array - ); - - eval_array_legacy_mode!( - IntervalDayTimeArray::from(vec![IntervalDayTime::new(i32::MIN, i32::MIN,)]), - IntervalDayTimeArray::from(vec![IntervalDayTime::new(i32::MIN, i32::MIN,)]), - as_interval_dt_array - ); } } diff --git a/datafusion/sqllogictest/test_files/spark/math/abs.slt b/datafusion/sqllogictest/test_files/spark/math/abs.slt index d279d00c5d0c..5e70fe6c3eae 100644 --- a/datafusion/sqllogictest/test_files/spark/math/abs.slt +++ b/datafusion/sqllogictest/test_files/spark/math/abs.slt @@ -91,5 +91,5 @@ drop table test_nullable_decimal ## Original Query: SELECT abs(INTERVAL -'1-1' YEAR TO MONTH); ## PySpark 3.5.5 Result: {"abs(INTERVAL '-1-1' YEAR TO MONTH)": 13, "typeof(abs(INTERVAL '-1-1' YEAR TO MONTH))": 'interval year to month', "typeof(INTERVAL '-1-1' YEAR TO MONTH)": 'interval year to month'} -query error DataFusion error: This feature is not implemented: Unsupported SQL type INTERVAL YEAR TO MONTH -SELECT abs(INTERVAL '-1-1' YEAR TO MONTH::interval year to month); +#query +#SELECT abs(INTERVAL '-1-1' YEAR TO MONTH::interval year to month); From ebb8ba3e59fc167bffef161ff6e251f0c9013520 Mon Sep 17 00:00:00 2001 From: huanghsiang_cheng Date: Tue, 11 Nov 2025 15:40:04 -0800 Subject: [PATCH 20/25] Revert debug formatter --- datafusion/functions/src/math/abs.rs | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/functions/src/math/abs.rs b/datafusion/functions/src/math/abs.rs index 179cb1391775..35d0f3eccf57 100644 --- a/datafusion/functions/src/math/abs.rs +++ b/datafusion/functions/src/math/abs.rs @@ -57,7 +57,7 @@ macro_rules! make_try_abs_function { let res: $ARRAY_TYPE = array.try_unary(|x| { x.checked_abs().ok_or_else(|| { ArrowError::ComputeError(format!( - "{} overflow on abs({:?})", + "{} overflow on abs({})", stringify!($ARRAY_TYPE), x )) From db0cbc10901b49ca26ffd1145e1a21ad152d2010 Mon Sep 17 00:00:00 2001 From: huanghsiang_cheng Date: Mon, 17 Nov 2025 22:09:13 -0800 Subject: [PATCH 21/25] remove coerce_types --- datafusion/spark/src/function/math/abs.rs | 32 ----------------------- 1 file changed, 32 deletions(-) diff --git a/datafusion/spark/src/function/math/abs.rs b/datafusion/spark/src/function/math/abs.rs index 818c6da0dc11..2c44a979b328 100644 --- a/datafusion/spark/src/function/math/abs.rs +++ b/datafusion/spark/src/function/math/abs.rs @@ -79,38 +79,6 @@ impl ScalarUDFImpl for SparkAbs { fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { spark_abs(&args.args) } - - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.len() != 1 { - return Err(invalid_arg_count_exec_err("abs", (1, 1), arg_types.len())); - } - match &arg_types[0] { - DataType::Null - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 - | DataType::Decimal128(_, _) - | DataType::Decimal256(_, _) => Ok(vec![arg_types[0].clone()]), - other => { - if other.is_numeric() { - Ok(vec![DataType::Float64]) - } else { - Err(unsupported_data_type_exec_err( - "abs", - "Numeric Type", - &arg_types[0], - )) - } - } - } - } } macro_rules! 
scalar_compute_op { From 9bc101cd6b2409944a7e4ca71d9fce768367baaa Mon Sep 17 00:00:00 2001 From: huanghsiang_cheng Date: Mon, 17 Nov 2025 22:15:47 -0800 Subject: [PATCH 22/25] Null check first in matching --- datafusion/spark/src/function/math/abs.rs | 56 +++++++---------------- 1 file changed, 17 insertions(+), 39 deletions(-) diff --git a/datafusion/spark/src/function/math/abs.rs b/datafusion/spark/src/function/math/abs.rs index 2c44a979b328..f48f8964c28c 100644 --- a/datafusion/spark/src/function/math/abs.rs +++ b/datafusion/spark/src/function/math/abs.rs @@ -15,9 +15,6 @@ // specific language governing permissions and limitations // under the License. -use crate::function::error_utils::{ - invalid_arg_count_exec_err, unsupported_data_type_exec_err, -}; use arrow::array::*; use arrow::datatypes::DataType; use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; @@ -150,42 +147,23 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result Ok(args[0].clone()), - ScalarValue::Int8(a) => match a { - None => Ok(args[0].clone()), - Some(v) => scalar_compute_op!(v, Int8), - }, - ScalarValue::Int16(a) => match a { - None => Ok(args[0].clone()), - Some(v) => scalar_compute_op!(v, Int16), - }, - ScalarValue::Int32(a) => match a { - None => Ok(args[0].clone()), - Some(v) => scalar_compute_op!(v, Int32), - }, - ScalarValue::Int64(a) => match a { - None => Ok(args[0].clone()), - Some(v) => scalar_compute_op!(v, Int64), - }, - ScalarValue::Float32(a) => match a { - None => Ok(args[0].clone()), - Some(v) => Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(v.abs())))), - }, - ScalarValue::Float64(a) => match a { - None => Ok(args[0].clone()), - Some(v) => Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(v.abs())))), - }, - ScalarValue::Decimal128(a, precision, scale) => match a { - None => Ok(args[0].clone()), - Some(v) => { - scalar_compute_op!(v, *precision, *scale, Decimal128) - } - }, - ScalarValue::Decimal256(a, precision, scale) => match a { - None => 
Ok(args[0].clone()), - Some(v) => { - scalar_compute_op!(v, *precision, *scale, Decimal256) - } - }, + sv if sv.is_null() => Ok(args[0].clone()), + ScalarValue::Int8(Some(v)) => scalar_compute_op!(v, Int8), + ScalarValue::Int16(Some(v)) => scalar_compute_op!(v, Int16), + ScalarValue::Int32(Some(v)) => scalar_compute_op!(v, Int32), + ScalarValue::Int64(Some(v)) => scalar_compute_op!(v, Int64), + ScalarValue::Float32(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(v.abs())))) + } + ScalarValue::Float64(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(v.abs())))) + } + ScalarValue::Decimal128(Some(v), precision, scale) => { + scalar_compute_op!(v, *precision, *scale, Decimal128) + } + ScalarValue::Decimal256(Some(v), precision, scale) => { + scalar_compute_op!(v, *precision, *scale, Decimal256) + } dt => internal_err!("Not supported datatype for Spark ABS: {dt}"), }, } From b7aff745ff950ca36155c90f8c6126382eea501b Mon Sep 17 00:00:00 2001 From: huanghsiang_cheng Date: Mon, 17 Nov 2025 22:51:41 -0800 Subject: [PATCH 23/25] Record 2 GitHub issues --- datafusion/sqllogictest/test_files/spark/math/abs.slt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datafusion/sqllogictest/test_files/spark/math/abs.slt b/datafusion/sqllogictest/test_files/spark/math/abs.slt index 5e70fe6c3eae..d44b1ef33429 100644 --- a/datafusion/sqllogictest/test_files/spark/math/abs.slt +++ b/datafusion/sqllogictest/test_files/spark/math/abs.slt @@ -31,6 +31,7 @@ SELECT abs(-127::TINYINT), abs(-32767::SMALLINT), abs(-2147483647::INT), abs(-92 127 32767 2147483647 9223372036854775807 NULL +# See https://github.com/apache/datafusion/issues/18794 for operator precedence # abs: signed int minimal values query IIII select abs((-128)::TINYINT), abs((-32768)::SMALLINT), abs((-2147483648)::INT), abs((-9223372036854775808)::BIGINT) @@ -93,3 +94,4 @@ drop table test_nullable_decimal ## PySpark 3.5.5 Result: {"abs(INTERVAL '-1-1' YEAR TO MONTH)": 13, "typeof(abs(INTERVAL 
'-1-1' YEAR TO MONTH))": 'interval year to month', "typeof(INTERVAL '-1-1' YEAR TO MONTH)": 'interval year to month'} #query #SELECT abs(INTERVAL '-1-1' YEAR TO MONTH::interval year to month); +# See GitHub issue for ANSI interval support: https://github.com/apache/datafusion/issues/18793 From 0a06d1c82873c35027ece2ebc6cd7f97dd85c234 Mon Sep 17 00:00:00 2001 From: huanghsiang_cheng Date: Tue, 18 Nov 2025 21:48:00 -0800 Subject: [PATCH 24/25] Test -0, inf, -inf --- .../sqllogictest/test_files/spark/math/abs.slt | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/datafusion/sqllogictest/test_files/spark/math/abs.slt b/datafusion/sqllogictest/test_files/spark/math/abs.slt index d44b1ef33429..46e4dcebd72a 100644 --- a/datafusion/sqllogictest/test_files/spark/math/abs.slt +++ b/datafusion/sqllogictest/test_files/spark/math/abs.slt @@ -39,16 +39,16 @@ select abs((-128)::TINYINT), abs((-32768)::SMALLINT), abs((-2147483648)::INT), a -128 -32768 -2147483648 -9223372036854775808 # abs: floats, NULL and NaN -query RRRR -SELECT abs(-1.0::FLOAT), abs(0.::FLOAT), abs(NULL::FLOAT), abs('NaN'::FLOAT) +query RRRRRRRRRRRR +SELECT abs(-1.0::FLOAT), abs(0.::FLOAT), abs(-0.::FLOAT), abs(-0::FLOAT), abs(NULL::FLOAT), abs('NaN'::FLOAT), abs('inf'::FLOAT), abs('+inf'::FLOAT), abs('-inf'::FLOAT), abs('infinity'::FLOAT), abs('+infinity'::FLOAT), abs('-infinity'::FLOAT) ---- -1 0 NULL NaN +1 0 0 0 NULL NaN Infinity Infinity Infinity Infinity Infinity Infinity # abs: doubles, NULL and NaN -query RRRR -SELECT abs(-1.0::DOUBLE), abs(0.::DOUBLE), abs(NULL::DOUBLE), abs('NaN'::DOUBLE) +query RRRRRRRRRRRR +SELECT abs(-1.0::DOUBLE), abs(0.::DOUBLE), abs(-0.::DOUBLE), abs(-0::DOUBLE), abs(NULL::DOUBLE), abs('NaN'::DOUBLE), abs('inf'::DOUBLE), abs('+inf'::DOUBLE), abs('-inf'::DOUBLE), abs('infinity'::DOUBLE), abs('+infinity'::DOUBLE), abs('-infinity'::DOUBLE) ---- -1 0 NULL NaN +1 0 0 0 NULL NaN Infinity Infinity Infinity Infinity Infinity Infinity # abs: decimal128 and 
decimal256 statement ok From 3349f1c2914db14edcc7cf2fde4dd7c8797beb2d Mon Sep 17 00:00:00 2001 From: huanghsiang_cheng Date: Tue, 18 Nov 2025 22:00:00 -0800 Subject: [PATCH 25/25] Fix test comments --- datafusion/sqllogictest/test_files/spark/math/abs.slt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/sqllogictest/test_files/spark/math/abs.slt b/datafusion/sqllogictest/test_files/spark/math/abs.slt index 46e4dcebd72a..19ca902ea3de 100644 --- a/datafusion/sqllogictest/test_files/spark/math/abs.slt +++ b/datafusion/sqllogictest/test_files/spark/math/abs.slt @@ -38,13 +38,13 @@ select abs((-128)::TINYINT), abs((-32768)::SMALLINT), abs((-2147483648)::INT), a ---- -128 -32768 -2147483648 -9223372036854775808 -# abs: floats, NULL and NaN +# abs: floats, NULL, NaN, -0, infinity, -infinity query RRRRRRRRRRRR SELECT abs(-1.0::FLOAT), abs(0.::FLOAT), abs(-0.::FLOAT), abs(-0::FLOAT), abs(NULL::FLOAT), abs('NaN'::FLOAT), abs('inf'::FLOAT), abs('+inf'::FLOAT), abs('-inf'::FLOAT), abs('infinity'::FLOAT), abs('+infinity'::FLOAT), abs('-infinity'::FLOAT) ---- 1 0 0 0 NULL NaN Infinity Infinity Infinity Infinity Infinity Infinity -# abs: doubles, NULL and NaN +# abs: doubles, NULL, NaN, -0, infinity, -infinity query RRRRRRRRRRRR SELECT abs(-1.0::DOUBLE), abs(0.::DOUBLE), abs(-0.::DOUBLE), abs(-0::DOUBLE), abs(NULL::DOUBLE), abs('NaN'::DOUBLE), abs('inf'::DOUBLE), abs('+inf'::DOUBLE), abs('-inf'::DOUBLE), abs('infinity'::DOUBLE), abs('+infinity'::DOUBLE), abs('-infinity'::DOUBLE) ----