Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 161 additions & 70 deletions crates/runtime/src/datafusion/functions/date_add.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use arrow::array::{Array, ArrayRef};
use arrow::compute::kernels::numeric::add_wrapping;
use arrow::datatypes::DataType;
use arrow::datatypes::DataType::{Date32, Date64, Int64, Time32, Time64, Timestamp, Utf8};
use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second};
use datafusion::common::{plan_err, Result};
use datafusion::logical_expr::TypeSignature::Exact;
use datafusion::logical_expr::{
ColumnarValue, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD,
};
use datafusion::logical_expr::TypeSignature::Coercible;
use datafusion::logical_expr::TypeSignatureClass;
use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use datafusion::scalar::ScalarValue;
use datafusion_common::types::{logical_int64, logical_string};
use std::any::Any;
use std::sync::Arc;

#[derive(Debug)]
pub struct DateAddFunc {
Expand All @@ -28,39 +29,20 @@ impl DateAddFunc {
Self {
signature: Signature::one_of(
vec![
Exact(vec![Utf8, Int64, Date32]),
Exact(vec![Utf8, Int64, Date64]),
Exact(vec![Utf8, Int64, Time32(Second)]),
Exact(vec![Utf8, Int64, Time32(Nanosecond)]),
Exact(vec![Utf8, Int64, Time32(Microsecond)]),
Exact(vec![Utf8, Int64, Time32(Millisecond)]),
Exact(vec![Utf8, Int64, Time64(Second)]),
Exact(vec![Utf8, Int64, Time64(Nanosecond)]),
Exact(vec![Utf8, Int64, Time64(Microsecond)]),
Exact(vec![Utf8, Int64, Time64(Millisecond)]),
Exact(vec![Utf8, Int64, Timestamp(Second, None)]),
Exact(vec![Utf8, Int64, Timestamp(Millisecond, None)]),
Exact(vec![Utf8, Int64, Timestamp(Microsecond, None)]),
Exact(vec![Utf8, Int64, Timestamp(Nanosecond, None)]),
Exact(vec![
Utf8,
Int64,
Timestamp(Second, Some(TIMEZONE_WILDCARD.into())),
Coercible(vec![
TypeSignatureClass::Native(logical_string()),
TypeSignatureClass::Native(logical_int64()),
TypeSignatureClass::Timestamp,
]),
Exact(vec![
Utf8,
Int64,
Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())),
Coercible(vec![
TypeSignatureClass::Native(logical_string()),
TypeSignatureClass::Native(logical_int64()),
TypeSignatureClass::Time,
]),
Exact(vec![
Utf8,
Int64,
Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())),
]),
Exact(vec![
Utf8,
Int64,
Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())),
Coercible(vec![
TypeSignatureClass::Native(logical_string()),
TypeSignatureClass::Native(logical_int64()),
TypeSignatureClass::Date,
]),
],
Volatility::Immutable,
Expand All @@ -75,38 +57,35 @@ impl DateAddFunc {
}
}

fn add_nanoseconds(val: &ScalarValue, nanoseconds: i64) -> Result<ColumnarValue> {
Ok(ColumnarValue::Scalar(
val.add(ScalarValue::DurationNanosecond(Some(nanoseconds)))
.unwrap_or(ScalarValue::DurationNanosecond(Some(0))),
fn add_years(val: &Arc<dyn Array>, years: i64) -> Result<ArrayRef> {
let years = ColumnarValue::Scalar(ScalarValue::new_interval_ym(
i32::try_from(years).unwrap_or(0),
0,
))
.to_array(val.len())?;
Ok(add_wrapping(&val, &years)?)
}
fn add_years(val: &ScalarValue, years: i64) -> Result<ColumnarValue> {
Ok(ColumnarValue::Scalar(
val.add(ScalarValue::new_interval_ym(
i32::try_from(years).unwrap_or(0),
0,
))
.unwrap_or_else(|_| ScalarValue::new_interval_ym(0, 0)),
fn add_months(val: &Arc<dyn Array>, months: i64) -> Result<ArrayRef> {
let months = ColumnarValue::Scalar(ScalarValue::new_interval_ym(
0,
i32::try_from(months).unwrap_or(0),
))
.to_array(val.len())?;
Ok(add_wrapping(&val, &months)?)
}
fn add_months(val: &ScalarValue, months: i64) -> Result<ColumnarValue> {
Ok(ColumnarValue::Scalar(
val.add(ScalarValue::new_interval_ym(
0,
i32::try_from(months).unwrap_or(0),
))
.unwrap_or_else(|_| ScalarValue::new_interval_ym(0, 0)),
fn add_days(val: &Arc<dyn Array>, days: i64) -> Result<ArrayRef> {
let days = ColumnarValue::Scalar(ScalarValue::new_interval_dt(
i32::try_from(days).unwrap_or(0),
0,
))
.to_array(val.len())?;
Ok(add_wrapping(&val, &days)?)
}
fn add_days(val: &ScalarValue, days: i64) -> Result<ColumnarValue> {
Ok(ColumnarValue::Scalar(
val.add(ScalarValue::new_interval_dt(
i32::try_from(days).unwrap_or(0),
0,
))
.unwrap_or_else(|_| ScalarValue::new_interval_dt(0, 0)),
))

/// Adds `nanoseconds` to every element of `val`.
///
/// The amount is wrapped in a month-day-nanosecond interval whose month and day parts are zero
/// (`ScalarValue::new_interval_mdn(0, 0, nanoseconds)`), repeated to the length of `val`, and
/// combined with the input through the `add_wrapping` kernel.
///
/// # Errors
///
/// Propagates any error from materialising the interval array or from the addition kernel.
fn add_nanoseconds(val: &Arc<dyn Array>, nanoseconds: i64) -> Result<ArrayRef> {
    let interval = ScalarValue::new_interval_mdn(0, 0, nanoseconds);
    let rhs = ColumnarValue::Scalar(interval).to_array(val.len())?;
    let shifted = add_wrapping(&val, &rhs)?;
    Ok(shifted)
}
}

Expand Down Expand Up @@ -153,12 +132,11 @@ impl ScalarUDFImpl for DateAddFunc {
}
Ok(arg_types[2].clone())
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
fn invoke_with_args(&self, args: datafusion_expr::ScalarFunctionArgs) -> Result<ColumnarValue> {
let args = args.args;
if args.len() != 3 {
return plan_err!("function requires three arguments");
}

let date_or_time_part = match &args[0] {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(part))) => part.clone(),
_ => return plan_err!("Invalid unit type format"),
Expand All @@ -168,12 +146,12 @@ impl ScalarUDFImpl for DateAddFunc {
ColumnarValue::Scalar(ScalarValue::Int64(Some(val))) => *val,
_ => return plan_err!("Invalid value type"),
};
let date_or_time_expr = match &args[2] {
ColumnarValue::Scalar(val) => val.clone(),
ColumnarValue::Array(array) => ScalarValue::try_from_array(&array, 0)?,
let (is_scalar, date_or_time_expr) = match &args[2] {
ColumnarValue::Scalar(val) => (true, val.to_array()?),
ColumnarValue::Array(array) => (false, array.clone()),
};
//there shouldn't be overflows
match date_or_time_part.as_str() {
let result = match date_or_time_part.as_str() {
//should consider leap year (365-366 days)
"year" | "y" | "yy" | "yyy" | "yyyy" | "yr" | "years" => {
Self::add_years(&date_or_time_expr, value)
Expand Down Expand Up @@ -208,11 +186,124 @@ impl ScalarUDFImpl for DateAddFunc {
"nanosecond" | "ns" | "nsec" | "nanosec" | "nsecond" | "nanoseconds" | "nanosecs"
| "nseconds" => Self::add_nanoseconds(&date_or_time_expr, value),
_ => plan_err!("Invalid date_or_time_part type"),
};
if is_scalar {
let result = result.and_then(|array| ScalarValue::try_from_array(&array, 0));
return result.map(ColumnarValue::Scalar);
}
result.map(ColumnarValue::Array)
}
fn aliases(&self) -> &[String] {
&self.aliases
}
}

super::macros::make_udf_function!(DateAddFunc);
#[cfg(test)]
#[allow(clippy::unwrap_in_result)]
mod tests {
    use super::DateAddFunc;
    use datafusion_common::ScalarValue;
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
    use std::sync::Arc;

    /// Input instant: 2025-01-06T13:00:00Z, in microseconds since the Unix epoch.
    const START_MICROS: i64 = 1_736_168_400_000_000;
    /// `START_MICROS` plus exactly five days (5 * 86_400 s), in microseconds.
    const EXPECTED_MICROS: i64 = 1_736_600_400_000_000;

    /// Builds a microsecond timestamp scalar carrying the `+00` time zone used by every test.
    fn utc_timestamp(micros: i64) -> ScalarValue {
        ScalarValue::TimestampMicrosecond(Some(micros), Some(Arc::from("+00")))
    }

    /// Calls `date_add('days', 5, input)` and returns the raw result.
    ///
    /// `number_rows` is forwarded to `ScalarFunctionArgs` so that it matches the number of rows
    /// actually present in `input`. Panics with the underlying error if the call fails, so a
    /// failing test shows why.
    fn add_five_days(input: ColumnarValue, number_rows: usize) -> ColumnarValue {
        let return_type = arrow_schema::DataType::Timestamp(
            arrow_schema::TimeUnit::Microsecond,
            Some(Arc::from("+00")),
        );
        let fn_args = ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("days")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(5i64))),
                input,
            ],
            number_rows,
            return_type: &return_type,
        };
        match DateAddFunc::new().invoke_with_args(fn_args) {
            Ok(value) => value,
            Err(err) => panic!("date_add failed: {err}"),
        }
    }

    /// A scalar timestamp input yields a scalar shifted by five days.
    #[test]
    fn test_date_add_days_timestamp() {
        let input = ColumnarValue::Scalar(utc_timestamp(START_MICROS));
        match add_five_days(input, 1) {
            ColumnarValue::Scalar(result) => {
                let expected = utc_timestamp(EXPECTED_MICROS);
                assert_eq!(&result, &expected, "date_add created a wrong value");
            }
            ColumnarValue::Array(_) => panic!("scalar input must produce a scalar result"),
        }
    }

    /// A one-element array input yields a one-element array shifted by five days.
    #[test]
    fn test_date_add_days_timestamp_array() {
        let input = ColumnarValue::Array(utc_timestamp(START_MICROS).to_array().unwrap());
        match add_five_days(input, 1) {
            ColumnarValue::Array(result) => {
                let expected = utc_timestamp(EXPECTED_MICROS).to_array().unwrap();
                assert_eq!(&result, &expected, "date_add created a wrong value");
            }
            ColumnarValue::Scalar(_) => panic!("array input must produce an array result"),
        }
    }

    /// Every element of a multi-row array input is shifted by five days.
    #[test]
    fn test_date_add_days_timestamp_array_multiple_values() {
        let input = ColumnarValue::Array(
            ColumnarValue::Scalar(utc_timestamp(START_MICROS))
                .to_array(2)
                .unwrap(),
        );
        match add_five_days(input, 2) {
            ColumnarValue::Array(result) => {
                let expected = ColumnarValue::Scalar(utc_timestamp(EXPECTED_MICROS))
                    .to_array(2)
                    .unwrap();
                assert_eq!(&result, &expected, "date_add created a wrong value");
            }
            ColumnarValue::Scalar(_) => panic!("array input must produce an array result"),
        }
    }
}
Loading