diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index 3896055e9233..3fbcadd3de5a 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -28,6 +28,7 @@ use arrow::{ }, }; use ordered_float::OrderedFloat; +use std::cmp::Ordering; use std::convert::{Infallible, TryInto}; use std::str::FromStr; use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; @@ -156,6 +157,81 @@ impl PartialEq for ScalarValue { } } +// manual implementation of `PartialOrd` that uses OrderedFloat to +// get defined behavior for floating point +impl PartialOrd for ScalarValue { + fn partial_cmp(&self, other: &Self) -> Option { + use ScalarValue::*; + // This purposely doesn't have a catch-all "(_, _)" so that + // any newly added enum variant will require editing this list + // or else face a compile error + match (self, other) { + (Boolean(v1), Boolean(v2)) => v1.partial_cmp(v2), + (Boolean(_), _) => None, + (Float32(v1), Float32(v2)) => { + let v1 = v1.map(OrderedFloat); + let v2 = v2.map(OrderedFloat); + v1.partial_cmp(&v2) + } + (Float32(_), _) => None, + (Float64(v1), Float64(v2)) => { + let v1 = v1.map(OrderedFloat); + let v2 = v2.map(OrderedFloat); + v1.partial_cmp(&v2) + } + (Float64(_), _) => None, + (Int8(v1), Int8(v2)) => v1.partial_cmp(v2), + (Int8(_), _) => None, + (Int16(v1), Int16(v2)) => v1.partial_cmp(v2), + (Int16(_), _) => None, + (Int32(v1), Int32(v2)) => v1.partial_cmp(v2), + (Int32(_), _) => None, + (Int64(v1), Int64(v2)) => v1.partial_cmp(v2), + (Int64(_), _) => None, + (UInt8(v1), UInt8(v2)) => v1.partial_cmp(v2), + (UInt8(_), _) => None, + (UInt16(v1), UInt16(v2)) => v1.partial_cmp(v2), + (UInt16(_), _) => None, + (UInt32(v1), UInt32(v2)) => v1.partial_cmp(v2), + (UInt32(_), _) => None, + (UInt64(v1), UInt64(v2)) => v1.partial_cmp(v2), + (UInt64(_), _) => None, + (Utf8(v1), Utf8(v2)) => v1.partial_cmp(v2), + (Utf8(_), _) => None, + (LargeUtf8(v1), LargeUtf8(v2)) => v1.partial_cmp(v2), + (LargeUtf8(_), _) => None, + (Binary(v1), Binary(v2)) => v1.partial_cmp(v2), + (Binary(_), _) => None, + (LargeBinary(v1), LargeBinary(v2)) => v1.partial_cmp(v2), + (LargeBinary(_), _) => None, + (List(v1, t1), List(v2, t2)) => { + if t1.eq(t2) { + v1.partial_cmp(v2) + } else { + None + } + } + (List(_, _), _) => None, + (Date32(v1), Date32(v2)) => v1.partial_cmp(v2), + (Date32(_), _) => None, + (Date64(v1), Date64(v2)) => v1.partial_cmp(v2), + (Date64(_), _) => None, + (TimestampSecond(v1), TimestampSecond(v2)) => v1.partial_cmp(v2), + (TimestampSecond(_), _) => None, + (TimestampMillisecond(v1), TimestampMillisecond(v2)) => v1.partial_cmp(v2), + (TimestampMillisecond(_), _) => None, + (TimestampMicrosecond(v1), TimestampMicrosecond(v2)) => v1.partial_cmp(v2), + (TimestampMicrosecond(_), _) => None, + (TimestampNanosecond(v1), TimestampNanosecond(v2)) => v1.partial_cmp(v2), + (TimestampNanosecond(_), _) => None, + (IntervalYearMonth(v1), IntervalYearMonth(v2)) => v1.partial_cmp(v2), + (IntervalYearMonth(_), _) => None, + (IntervalDayTime(v1), IntervalDayTime(v2)) => v1.partial_cmp(v2), + (IntervalDayTime(_), _) => None, + } + } +} + impl Eq for ScalarValue {} // manual implementation of `Hash` that uses OrderedFloat to @@ -1577,4 +1653,74 @@ mod tests { // per distinct value. assert_eq!(std::mem::size_of::(), 32); } + + #[test] + fn scalar_partial_ordering() { + use ScalarValue::*; + + assert_eq!( + Int64(Some(33)).partial_cmp(&Int64(Some(0))), + Some(Ordering::Greater) + ); + assert_eq!( + Int64(Some(0)).partial_cmp(&Int64(Some(33))), + Some(Ordering::Less) + ); + assert_eq!( + Int64(Some(33)).partial_cmp(&Int64(Some(33))), + Some(Ordering::Equal) + ); + // For different data type, `partial_cmp` returns None. + assert_eq!(Int64(Some(33)).partial_cmp(&Int32(Some(33))), None); + assert_eq!(Int32(Some(33)).partial_cmp(&Int64(Some(33))), None); + + assert_eq!( + List( + Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])), + Box::new(DataType::Int32) + ) + .partial_cmp(&List( + Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])), + Box::new(DataType::Int32) + )), + Some(Ordering::Equal) + ); + + assert_eq!( + List( + Some(Box::new(vec![Int32(Some(10)), Int32(Some(5))])), + Box::new(DataType::Int32) + ) + .partial_cmp(&List( + Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])), + Box::new(DataType::Int32) + )), + Some(Ordering::Greater) + ); + + assert_eq!( + List( + Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])), + Box::new(DataType::Int32) + ) + .partial_cmp(&List( + Some(Box::new(vec![Int32(Some(10)), Int32(Some(5))])), + Box::new(DataType::Int32) + )), + Some(Ordering::Less) + ); + + // For different data type, `partial_cmp` returns None. + assert_eq!( + List( + Some(Box::new(vec![Int64(Some(1)), Int64(Some(5))])), + Box::new(DataType::Int64) + ) + .partial_cmp(&List( + Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])), + Box::new(DataType::Int32) + )), + None + ); + } }