Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions datafusion/spark/src/function/datetime/date_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use arrow::error::ArrowError;
use datafusion_common::cast::{
as_date32_array, as_int16_array, as_int32_array, as_int8_array,
};
use datafusion_common::utils::take_function_args;
use datafusion_common::{internal_err, Result};
use datafusion_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
Expand Down Expand Up @@ -87,12 +88,7 @@ impl ScalarUDFImpl for SparkDateAdd {
}

fn spark_date_add(args: &[ArrayRef]) -> Result<ArrayRef> {
let [date_arg, days_arg] = args else {
return internal_err!(
"Spark `date_add` function requires 2 arguments, got {}",
args.len()
);
};
let [date_arg, days_arg] = take_function_args("date_add", args)?;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handled this with `take_function_args` as suggested. I noticed a few other functions that haven't adopted it yet; I'm thinking I'll include those in a separate PR.

let date_array = as_date32_array(date_arg)?;
let result = match days_arg.data_type() {
DataType::Int8 => {
Expand Down
10 changes: 3 additions & 7 deletions datafusion/spark/src/function/datetime/last_day.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use std::sync::Arc;
use arrow::array::{ArrayRef, AsArray, Date32Array};
use arrow::datatypes::{DataType, Date32Type};
use chrono::{Datelike, Duration, NaiveDate};
use datafusion_common::utils::take_function_args;
use datafusion_common::{exec_datafusion_err, internal_err, Result, ScalarValue};
use datafusion_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
Expand Down Expand Up @@ -64,17 +65,12 @@ impl ScalarUDFImpl for SparkLastDay {

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let ScalarFunctionArgs { args, .. } = args;
let [arg] = args.as_slice() else {
return internal_err!(
"Spark `last_day` function requires 1 argument, got {}",
args.len()
);
};
let [arg] = take_function_args("last_day", args)?;
match arg {
ColumnarValue::Scalar(ScalarValue::Date32(days)) => {
if let Some(days) = days {
Ok(ColumnarValue::Scalar(ScalarValue::Date32(Some(
spark_last_day(*days)?,
spark_last_day(days)?,
))))
} else {
Ok(ColumnarValue::Scalar(ScalarValue::Date32(None)))
Expand Down
10 changes: 5 additions & 5 deletions datafusion/spark/src/function/math/factorial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ use arrow::array::{Array, Int64Array};
use arrow::datatypes::DataType;
use arrow::datatypes::DataType::{Int32, Int64};
use datafusion_common::cast::as_int32_array;
use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue};
use datafusion_common::{
exec_err, utils::take_function_args, DataFusionError, Result, ScalarValue,
};
use datafusion_expr::Signature;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility};

Expand Down Expand Up @@ -99,11 +101,9 @@ const FACTORIALS: [i64; 21] = [
];

pub fn spark_factorial(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
if args.len() != 1 {
return internal_err!("`factorial` expects exactly one argument");
}
let [arg] = take_function_args("factorial", args)?;

match &args[0] {
match arg {
ColumnarValue::Scalar(ScalarValue::Int32(value)) => {
let result = compute_factorial(*value);
Ok(ColumnarValue::Scalar(ScalarValue::Int64(result)))
Expand Down
13 changes: 5 additions & 8 deletions datafusion/spark/src/function/math/hex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ use arrow::{
datatypes::Int32Type,
};
use datafusion_common::cast::as_string_view_array;
use datafusion_common::utils::take_function_args;
use datafusion_common::{
cast::{as_binary_array, as_fixed_size_binary_array, as_int64_array},
exec_err, internal_err, DataFusionError,
exec_err, DataFusionError,
};
use datafusion_expr::Signature;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility};
Expand Down Expand Up @@ -184,13 +185,9 @@ pub fn compute_hex(
args: &[ColumnarValue],
lowercase: bool,
) -> Result<ColumnarValue, DataFusionError> {
if args.len() != 1 {
return internal_err!("hex expects exactly one argument");
}

let input = match &args[0] {
ColumnarValue::Scalar(value) => ColumnarValue::Array(value.to_array()?),
ColumnarValue::Array(_) => args[0].clone(),
let input = match take_function_args("hex", args)? {
[ColumnarValue::Scalar(value)] => ColumnarValue::Array(value.to_array()?),
[ColumnarValue::Array(arr)] => ColumnarValue::Array(Arc::clone(arr)),
};

match &input {
Expand Down
12 changes: 5 additions & 7 deletions datafusion/spark/src/function/math/modulus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
use arrow::compute::kernels::numeric::add;
use arrow::compute::kernels::{cmp::lt, numeric::rem, zip::zip};
use arrow::datatypes::DataType;
use datafusion_common::{internal_err, Result, ScalarValue};
use datafusion_common::{
assert_eq_or_internal_err, internal_err, DataFusionError, Result, ScalarValue,
};
use datafusion_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};
Expand All @@ -27,9 +29,7 @@ use std::any::Any;
/// Spark-compatible `mod` function
/// This function directly uses Arrow's arithmetic_op function for modulo operations
pub fn spark_mod(args: &[ColumnarValue]) -> Result<ColumnarValue> {
if args.len() != 2 {
return internal_err!("mod expects exactly two arguments");
}
assert_eq_or_internal_err!(args.len(), 2, "mod expects exactly two arguments");
let args = ColumnarValue::values_to_arrays(args)?;
let result = rem(&args[0], &args[1])?;
Ok(ColumnarValue::Array(result))
Expand All @@ -38,9 +38,7 @@ pub fn spark_mod(args: &[ColumnarValue]) -> Result<ColumnarValue> {
/// Spark-compatible `pmod` function
/// This function directly uses Arrow's arithmetic_op function for modulo operations
pub fn spark_pmod(args: &[ColumnarValue]) -> Result<ColumnarValue> {
if args.len() != 2 {
return internal_err!("pmod expects exactly two arguments");
}
assert_eq_or_internal_err!(args.len(), 2, "pmod expects exactly two arguments");
let args = ColumnarValue::values_to_arrays(args)?;
let left = &args[0];
let right = &args[1];
Expand Down
5 changes: 3 additions & 2 deletions datafusion/spark/src/function/math/rint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use datafusion_common::DataFusionError;
use std::any::Any;
use std::sync::Arc;

Expand All @@ -24,7 +25,7 @@ use arrow::datatypes::DataType::{
Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8,
};
use arrow::datatypes::{DataType, Float32Type, Float64Type};
use datafusion_common::{exec_err, Result};
use datafusion_common::{assert_eq_or_internal_err, exec_err, Result};
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
use datafusion_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
Expand Down Expand Up @@ -84,7 +85,7 @@ impl ScalarUDFImpl for SparkRint {

pub fn spark_rint(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 1 {
return exec_err!("rint expects exactly 1 argument, got {}", args.len());
assert_eq_or_internal_err!(args.len(), 1, "`rint` expects exactly one argument");
}

let array: &dyn Array = args[0].as_ref();
Expand Down