diff --git a/datafusion/spark/src/function/datetime/date_add.rs b/datafusion/spark/src/function/datetime/date_add.rs
index a00430febcdb..457d4d476dce 100644
--- a/datafusion/spark/src/function/datetime/date_add.rs
+++ b/datafusion/spark/src/function/datetime/date_add.rs
@@ -25,6 +25,7 @@ use arrow::error::ArrowError;
 use datafusion_common::cast::{
     as_date32_array, as_int16_array, as_int32_array, as_int8_array,
 };
+use datafusion_common::utils::take_function_args;
 use datafusion_common::{internal_err, Result};
 use datafusion_expr::{
     ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
@@ -87,12 +88,7 @@ impl ScalarUDFImpl for SparkDateAdd {
 }
 
 fn spark_date_add(args: &[ArrayRef]) -> Result<ArrayRef> {
-    let [date_arg, days_arg] = args else {
-        return internal_err!(
-            "Spark `date_add` function requires 2 arguments, got {}",
-            args.len()
-        );
-    };
+    let [date_arg, days_arg] = take_function_args("date_add", args)?;
     let date_array = as_date32_array(date_arg)?;
     let result = match days_arg.data_type() {
         DataType::Int8 => {
diff --git a/datafusion/spark/src/function/datetime/last_day.rs b/datafusion/spark/src/function/datetime/last_day.rs
index c01a6403649c..b75f10ad5e42 100644
--- a/datafusion/spark/src/function/datetime/last_day.rs
+++ b/datafusion/spark/src/function/datetime/last_day.rs
@@ -21,6 +21,7 @@ use std::sync::Arc;
 use arrow::array::{ArrayRef, AsArray, Date32Array};
 use arrow::datatypes::{DataType, Date32Type};
 use chrono::{Datelike, Duration, NaiveDate};
+use datafusion_common::utils::take_function_args;
 use datafusion_common::{exec_datafusion_err, internal_err, Result, ScalarValue};
 use datafusion_expr::{
     ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
@@ -64,17 +65,12 @@ impl ScalarUDFImpl for SparkLastDay {
 
     fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
         let ScalarFunctionArgs { args, .. } = args;
-        let [arg] = args.as_slice() else {
-            return internal_err!(
-                "Spark `last_day` function requires 1 argument, got {}",
-                args.len()
-            );
-        };
+        let [arg] = take_function_args("last_day", args)?;
         match arg {
             ColumnarValue::Scalar(ScalarValue::Date32(days)) => {
                 if let Some(days) = days {
                     Ok(ColumnarValue::Scalar(ScalarValue::Date32(Some(
-                        spark_last_day(*days)?,
+                        spark_last_day(days)?,
                     ))))
                 } else {
                     Ok(ColumnarValue::Scalar(ScalarValue::Date32(None)))
diff --git a/datafusion/spark/src/function/math/factorial.rs b/datafusion/spark/src/function/math/factorial.rs
index 4921e73d262a..5cf33d6073e5 100644
--- a/datafusion/spark/src/function/math/factorial.rs
+++ b/datafusion/spark/src/function/math/factorial.rs
@@ -22,7 +22,9 @@ use arrow::array::{Array, Int64Array};
 use arrow::datatypes::DataType;
 use arrow::datatypes::DataType::{Int32, Int64};
 use datafusion_common::cast::as_int32_array;
-use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue};
+use datafusion_common::{
+    exec_err, utils::take_function_args, DataFusionError, Result, ScalarValue,
+};
 use datafusion_expr::Signature;
 use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility};
 
@@ -99,11 +101,9 @@ const FACTORIALS: [i64; 21] = [
 ];
 
 pub fn spark_factorial(args: &[ColumnarValue]) -> Result<ColumnarValue> {
-    if args.len() != 1 {
-        return internal_err!("`factorial` expects exactly one argument");
-    }
+    let [arg] = take_function_args("factorial", args)?;
 
-    match &args[0] {
+    match arg {
         ColumnarValue::Scalar(ScalarValue::Int32(value)) => {
             let result = compute_factorial(*value);
             Ok(ColumnarValue::Scalar(ScalarValue::Int64(result)))
diff --git a/datafusion/spark/src/function/math/hex.rs b/datafusion/spark/src/function/math/hex.rs
index cdd13e903326..7029b5e43490 100644
--- a/datafusion/spark/src/function/math/hex.rs
+++ b/datafusion/spark/src/function/math/hex.rs
@@ -28,9 +28,10 @@ use arrow::{
     datatypes::Int32Type,
 };
 use datafusion_common::cast::as_string_view_array;
+use datafusion_common::utils::take_function_args;
 use datafusion_common::{
     cast::{as_binary_array, as_fixed_size_binary_array, as_int64_array},
-    exec_err, internal_err, DataFusionError,
+    exec_err, DataFusionError,
 };
 use datafusion_expr::Signature;
 use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility};
@@ -184,13 +185,9 @@ pub fn compute_hex(
     args: &[ColumnarValue],
     lowercase: bool,
 ) -> Result<ColumnarValue> {
-    if args.len() != 1 {
-        return internal_err!("hex expects exactly one argument");
-    }
-
-    let input = match &args[0] {
-        ColumnarValue::Scalar(value) => ColumnarValue::Array(value.to_array()?),
-        ColumnarValue::Array(_) => args[0].clone(),
+    let input = match take_function_args("hex", args)? {
+        [ColumnarValue::Scalar(value)] => ColumnarValue::Array(value.to_array()?),
+        [ColumnarValue::Array(arr)] => ColumnarValue::Array(Arc::clone(arr)),
     };
 
     match &input {
diff --git a/datafusion/spark/src/function/math/modulus.rs b/datafusion/spark/src/function/math/modulus.rs
index fea0297a7ae9..b894c8cad521 100644
--- a/datafusion/spark/src/function/math/modulus.rs
+++ b/datafusion/spark/src/function/math/modulus.rs
@@ -18,7 +18,9 @@
 use arrow::compute::kernels::numeric::add;
 use arrow::compute::kernels::{cmp::lt, numeric::rem, zip::zip};
 use arrow::datatypes::DataType;
-use datafusion_common::{internal_err, Result, ScalarValue};
+use datafusion_common::{
+    assert_eq_or_internal_err, internal_err, DataFusionError, Result, ScalarValue,
+};
 use datafusion_expr::{
     ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
 };
@@ -27,9 +29,7 @@ use std::any::Any;
 /// Spark-compatible `mod` function
 /// This function directly uses Arrow's arithmetic_op function for modulo operations
 pub fn spark_mod(args: &[ColumnarValue]) -> Result<ColumnarValue> {
-    if args.len() != 2 {
-        return internal_err!("mod expects exactly two arguments");
-    }
+    assert_eq_or_internal_err!(args.len(), 2, "mod expects exactly two arguments");
     let args = ColumnarValue::values_to_arrays(args)?;
     let result = rem(&args[0], &args[1])?;
     Ok(ColumnarValue::Array(result))
@@ -38,9 +38,7 @@ pub fn spark_mod(args: &[ColumnarValue]) -> Result<ColumnarValue> {
 /// Spark-compatible `pmod` function
 /// This function directly uses Arrow's arithmetic_op function for modulo operations
 pub fn spark_pmod(args: &[ColumnarValue]) -> Result<ColumnarValue> {
-    if args.len() != 2 {
-        return internal_err!("pmod expects exactly two arguments");
-    }
+    assert_eq_or_internal_err!(args.len(), 2, "pmod expects exactly two arguments");
     let args = ColumnarValue::values_to_arrays(args)?;
     let left = &args[0];
     let right = &args[1];
diff --git a/datafusion/spark/src/function/math/rint.rs b/datafusion/spark/src/function/math/rint.rs
index 9b61529c5bc4..aae5455df0e1 100644
--- a/datafusion/spark/src/function/math/rint.rs
+++ b/datafusion/spark/src/function/math/rint.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use datafusion_common::DataFusionError;
 use std::any::Any;
 use std::sync::Arc;
 
@@ -24,7 +25,7 @@ use arrow::datatypes::DataType::{
     Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8,
 };
 use arrow::datatypes::{DataType, Float32Type, Float64Type};
-use datafusion_common::{exec_err, Result};
+use datafusion_common::{assert_eq_or_internal_err, exec_err, Result};
 use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
 use datafusion_expr::{
     ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
@@ -84,7 +85,5 @@ impl ScalarUDFImpl for SparkRint {
 
 pub fn spark_rint(args: &[ArrayRef]) -> Result<ArrayRef> {
-    if args.len() != 1 {
-        return exec_err!("rint expects exactly 1 argument, got {}", args.len());
-    }
+    assert_eq_or_internal_err!(args.len(), 1, "`rint` expects exactly one argument");
 
     let array: &dyn Array = args[0].as_ref();