Skip to content

Commit

Permalink
Add datetime/interval/duration into dyn scalar comparison (#3730)
Browse files Browse the repository at this point in the history
* Add datatime/interval/duration into comparison

* Add some tests
  • Loading branch information
viirya committed Feb 21, 2023
1 parent 18388b2 commit 61ea9f2
Showing 1 changed file with 219 additions and 0 deletions.
219 changes: 219 additions & 0 deletions arrow-ord/src/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,91 @@ macro_rules! dyn_compare_scalar {
let left = as_primitive_array::<Decimal128Type>($LEFT);
$OP::<Decimal128Type>(left, right)
}
DataType::Date32 => {
let right = try_to_type!($RIGHT, to_i32)?;
let left = as_primitive_array::<Date32Type>($LEFT);
$OP::<Date32Type>(left, right)
}
DataType::Date64 => {
let right = try_to_type!($RIGHT, to_i64)?;
let left = as_primitive_array::<Date64Type>($LEFT);
$OP::<Date64Type>(left, right)
}
DataType::Timestamp(TimeUnit::Nanosecond, _) => {
let right = try_to_type!($RIGHT, to_i64)?;
let left = as_primitive_array::<TimestampNanosecondType>($LEFT);
$OP::<TimestampNanosecondType>(left, right)
}
DataType::Timestamp(TimeUnit::Microsecond, _) => {
let right = try_to_type!($RIGHT, to_i64)?;
let left = as_primitive_array::<TimestampMicrosecondType>($LEFT);
$OP::<TimestampMicrosecondType>(left, right)
}
DataType::Timestamp(TimeUnit::Millisecond, _) => {
let right = try_to_type!($RIGHT, to_i64)?;
let left = as_primitive_array::<TimestampMillisecondType>($LEFT);
$OP::<TimestampMillisecondType>(left, right)
}
DataType::Timestamp(TimeUnit::Second, _) => {
let right = try_to_type!($RIGHT, to_i64)?;
let left = as_primitive_array::<TimestampSecondType>($LEFT);
$OP::<TimestampSecondType>(left, right)
}
DataType::Time32(TimeUnit::Second) => {
let right = try_to_type!($RIGHT, to_i32)?;
let left = as_primitive_array::<Time32SecondType>($LEFT);
$OP::<Time32SecondType>(left, right)
}
DataType::Time32(TimeUnit::Millisecond) => {
let right = try_to_type!($RIGHT, to_i32)?;
let left = as_primitive_array::<Time32MillisecondType>($LEFT);
$OP::<Time32MillisecondType>(left, right)
}
DataType::Time64(TimeUnit::Microsecond) => {
let right = try_to_type!($RIGHT, to_i64)?;
let left = as_primitive_array::<Time64MicrosecondType>($LEFT);
$OP::<Time64MicrosecondType>(left, right)
}
DataType::Time64(TimeUnit::Nanosecond) => {
let right = try_to_type!($RIGHT, to_i64)?;
let left = as_primitive_array::<Time64NanosecondType>($LEFT);
$OP::<Time64NanosecondType>(left, right)
}
DataType::Interval(IntervalUnit::YearMonth) => {
let right = try_to_type!($RIGHT, to_i32)?;
let left = as_primitive_array::<IntervalYearMonthType>($LEFT);
$OP::<IntervalYearMonthType>(left, right)
}
DataType::Interval(IntervalUnit::DayTime) => {
let right = try_to_type!($RIGHT, to_i64)?;
let left = as_primitive_array::<IntervalDayTimeType>($LEFT);
$OP::<IntervalDayTimeType>(left, right)
}
DataType::Interval(IntervalUnit::MonthDayNano) => {
let right = try_to_type!($RIGHT, to_i128)?;
let left = as_primitive_array::<IntervalMonthDayNanoType>($LEFT);
$OP::<IntervalMonthDayNanoType>(left, right)
}
DataType::Duration(TimeUnit::Second) => {
let right = try_to_type!($RIGHT, to_i64)?;
let left = as_primitive_array::<DurationSecondType>($LEFT);
$OP::<DurationSecondType>(left, right)
}
DataType::Duration(TimeUnit::Millisecond) => {
let right = try_to_type!($RIGHT, to_i64)?;
let left = as_primitive_array::<DurationMillisecondType>($LEFT);
$OP::<DurationMillisecondType>(left, right)
}
DataType::Duration(TimeUnit::Microsecond) => {
let right = try_to_type!($RIGHT, to_i64)?;
let left = as_primitive_array::<DurationMicrosecondType>($LEFT);
$OP::<DurationMicrosecondType>(left, right)
}
DataType::Duration(TimeUnit::Nanosecond) => {
let right = try_to_type!($RIGHT, to_i64)?;
let left = as_primitive_array::<DurationNanosecondType>($LEFT);
$OP::<DurationNanosecondType>(left, right)
}
_ => Err(ArrowError::ComputeError(format!(
"Unsupported data type {:?} for comparison {} with {:?}",
$LEFT.data_type(),
Expand Down Expand Up @@ -1707,6 +1792,22 @@ macro_rules! typed_compares {
DataType::Interval(IntervalUnit::MonthDayNano),
DataType::Interval(IntervalUnit::MonthDayNano),
) => cmp_primitive_array::<IntervalMonthDayNanoType, _>($LEFT, $RIGHT, $OP),
(
DataType::Duration(TimeUnit::Second),
DataType::Duration(TimeUnit::Second),
) => cmp_primitive_array::<DurationSecondType, _>($LEFT, $RIGHT, $OP),
(
DataType::Duration(TimeUnit::Millisecond),
DataType::Duration(TimeUnit::Millisecond),
) => cmp_primitive_array::<DurationMillisecondType, _>($LEFT, $RIGHT, $OP),
(
DataType::Duration(TimeUnit::Microsecond),
DataType::Duration(TimeUnit::Microsecond),
) => cmp_primitive_array::<DurationMicrosecondType, _>($LEFT, $RIGHT, $OP),
(
DataType::Duration(TimeUnit::Nanosecond),
DataType::Duration(TimeUnit::Nanosecond),
) => cmp_primitive_array::<DurationNanosecondType, _>($LEFT, $RIGHT, $OP),
(t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!(
"Comparing arrays of type {} is not yet implemented",
t1
Expand Down Expand Up @@ -4006,6 +4107,124 @@ mod tests {
);
}

fn test_primitive_dyn_scalar<T: ArrowPrimitiveType>(array: PrimitiveArray<T>) {
let a_eq = eq_dyn_scalar(&array, 8).unwrap();
assert_eq!(
a_eq,
BooleanArray::from(vec![Some(false), None, Some(true), None, Some(false)])
);

let a_eq = gt_eq_dyn_scalar(&array, 8).unwrap();
assert_eq!(
a_eq,
BooleanArray::from(vec![Some(false), None, Some(true), None, Some(true)])
);

let a_eq = gt_dyn_scalar(&array, 8).unwrap();
assert_eq!(
a_eq,
BooleanArray::from(vec![Some(false), None, Some(false), None, Some(true)])
);

let a_eq = lt_eq_dyn_scalar(&array, 8).unwrap();
assert_eq!(
a_eq,
BooleanArray::from(vec![Some(true), None, Some(true), None, Some(false)])
);

let a_eq = lt_dyn_scalar(&array, 8).unwrap();
assert_eq!(
a_eq,
BooleanArray::from(vec![Some(true), None, Some(false), None, Some(false)])
);
}

#[test]
fn test_timestamp_dyn_scalar() {
let array =
TimestampSecondArray::from(vec![Some(1), None, Some(8), None, Some(10)]);
test_primitive_dyn_scalar(array);

let array =
TimestampMicrosecondArray::from(vec![Some(1), None, Some(8), None, Some(10)]);
test_primitive_dyn_scalar(array);

let array =
TimestampMicrosecondArray::from(vec![Some(1), None, Some(8), None, Some(10)]);
test_primitive_dyn_scalar(array);

let array =
TimestampNanosecondArray::from(vec![Some(1), None, Some(8), None, Some(10)]);
test_primitive_dyn_scalar(array);
}

#[test]
fn test_date32_dyn_scalar() {
let array = Date32Array::from(vec![Some(1), None, Some(8), None, Some(10)]);
test_primitive_dyn_scalar(array);
}

#[test]
fn test_date64_dyn_scalar() {
let array = Date64Array::from(vec![Some(1), None, Some(8), None, Some(10)]);
test_primitive_dyn_scalar(array);
}

#[test]
fn test_time32_dyn_scalar() {
let array = Time32SecondArray::from(vec![Some(1), None, Some(8), None, Some(10)]);
test_primitive_dyn_scalar(array);

let array =
Time32MillisecondArray::from(vec![Some(1), None, Some(8), None, Some(10)]);
test_primitive_dyn_scalar(array);
}

#[test]
fn test_time64_dyn_scalar() {
let array =
Time64MicrosecondArray::from(vec![Some(1), None, Some(8), None, Some(10)]);
test_primitive_dyn_scalar(array);

let array =
Time64NanosecondArray::from(vec![Some(1), None, Some(8), None, Some(10)]);
test_primitive_dyn_scalar(array);
}

#[test]
fn test_interval_dyn_scalar() {
let array =
IntervalDayTimeArray::from(vec![Some(1), None, Some(8), None, Some(10)]);
test_primitive_dyn_scalar(array);

let array =
IntervalMonthDayNanoArray::from(vec![Some(1), None, Some(8), None, Some(10)]);
test_primitive_dyn_scalar(array);

let array =
IntervalYearMonthArray::from(vec![Some(1), None, Some(8), None, Some(10)]);
test_primitive_dyn_scalar(array);
}

#[test]
fn test_duration_dyn_scalar() {
let array =
DurationSecondArray::from(vec![Some(1), None, Some(8), None, Some(10)]);
test_primitive_dyn_scalar(array);

let array =
DurationMicrosecondArray::from(vec![Some(1), None, Some(8), None, Some(10)]);
test_primitive_dyn_scalar(array);

let array =
DurationMillisecondArray::from(vec![Some(1), None, Some(8), None, Some(10)]);
test_primitive_dyn_scalar(array);

let array =
DurationNanosecondArray::from(vec![Some(1), None, Some(8), None, Some(10)]);
test_primitive_dyn_scalar(array);
}

#[test]
fn test_lt_eq_dyn_scalar_with_dict() {
let mut builder =
Expand Down

0 comments on commit 61ea9f2

Please sign in to comment.