diff --git a/datafusion/common/src/parsers.rs b/datafusion/common/src/parsers.rs index 94ad5d86be30..9016103d9702 100644 --- a/datafusion/common/src/parsers.rs +++ b/datafusion/common/src/parsers.rs @@ -16,12 +16,52 @@ // under the License. //! Interval parsing logic +use sqlparser::parser::ParserError; + use crate::{DataFusionError, Result, ScalarValue}; use std::str::FromStr; const SECONDS_PER_HOUR: f64 = 3_600_f64; const NANOS_PER_SECOND: f64 = 1_000_000_000_f64; +#[derive(Clone, Copy)] +#[repr(u16)] +enum IntervalType { + Century = 0b_00_0000_0001, + Decade = 0b_00_0000_0010, + Year = 0b_00_0000_0100, + Month = 0b_00_0000_1000, + Week = 0b_00_0001_0000, + Day = 0b_00_0010_0000, + Hour = 0b_00_0100_0000, + Minute = 0b_00_1000_0000, + Second = 0b_01_0000_0000, + Millisecond = 0b_10_0000_0000, +} + +impl FromStr for IntervalType { + type Err = DataFusionError; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "century" | "centuries" => Ok(Self::Century), + "decade" | "decades" => Ok(Self::Decade), + "year" | "years" => Ok(Self::Year), + "month" | "months" => Ok(Self::Month), + "week" | "weeks" => Ok(Self::Week), + "day" | "days" => Ok(Self::Day), + "hour" | "hours" => Ok(Self::Hour), + "minute" | "minutes" => Ok(Self::Minute), + "second" | "seconds" => Ok(Self::Second), + "millisecond" | "milliseconds" => Ok(Self::Millisecond), + _ => Err(DataFusionError::NotImplemented(format!( + "Unknown interval type: {}", + s + ))), + } + } +} + /// Parses a string with an interval like `'0.5 MONTH'` to an /// appropriately typed [`ScalarValue`] pub fn parse_interval(leading_field: &str, value: &str) -> Result { @@ -42,8 +82,10 @@ pub fn parse_interval(leading_field: &str, value: &str) -> Result { (month_part as i64, day_part as i64, nanos_part) }; - let calculate_from_part = |interval_period_str: &str, - interval_type: &str| + let mut used_interval_types = 0; + + let mut calculate_from_part = |interval_period_str: &str, + interval_type: &str| -> Result<(i64, i64, f64)> { // @todo It's better to use Decimal in order to protect rounding errors // Wait https://github.com/apache/arrow/pull/9232 @@ -64,33 +106,46 @@ pub fn parse_interval(leading_field: &str, value: &str) -> Result { ))); } - match interval_type.to_lowercase().as_str() { - "century" | "centuries" => { + let it = IntervalType::from_str(interval_type).map_err(|_| { + DataFusionError::NotImplemented(format!( + "Invalid input syntax for type interval: {:?}", + value + )) + })?; + + // Disallow duplicate interval types + if used_interval_types & (it as u16) != 0 { + return Err(DataFusionError::SQL(ParserError::ParserError(format!( + "Invalid input syntax for type interval: {:?}. Repeated type '{}'", + value, interval_type + )))); + } else { + used_interval_types |= it as u16; + } + + match it { + IntervalType::Century => { Ok(align_interval_parts(interval_period * 1200_f64, 0.0, 0.0)) } - "decade" | "decades" => { + IntervalType::Decade => { Ok(align_interval_parts(interval_period * 120_f64, 0.0, 0.0)) } - "year" | "years" => { + IntervalType::Year => { Ok(align_interval_parts(interval_period * 12_f64, 0.0, 0.0)) } - "month" | "months" => Ok(align_interval_parts(interval_period, 0.0, 0.0)), - "week" | "weeks" => { + IntervalType::Month => Ok(align_interval_parts(interval_period, 0.0, 0.0)), + IntervalType::Week => { Ok(align_interval_parts(0.0, interval_period * 7_f64, 0.0)) } - "day" | "days" => Ok(align_interval_parts(0.0, interval_period, 0.0)), - "hour" | "hours" => { + IntervalType::Day => Ok(align_interval_parts(0.0, interval_period, 0.0)), + IntervalType::Hour => { Ok((0, 0, interval_period * SECONDS_PER_HOUR * NANOS_PER_SECOND)) } - "minute" | "minutes" => { + IntervalType::Minute => { Ok((0, 0, interval_period * 60_f64 * NANOS_PER_SECOND)) } - "second" | "seconds" => Ok((0, 0, interval_period * NANOS_PER_SECOND)), - "millisecond" | "milliseconds" => Ok((0, 0, interval_period * 1_000_000f64)), - _ => Err(DataFusionError::NotImplemented(format!( - "Invalid input syntax for type interval: {:?}", - value - ))), + IntervalType::Second => Ok((0, 0, interval_period * NANOS_PER_SECOND)), + IntervalType::Millisecond => Ok((0, 0, interval_period * 1_000_000f64)), } }; @@ -234,4 +289,14 @@ mod test { ScalarValue::new_interval_mdn(12, 1, 1_00 * 1_000) ); } + + #[test] + fn test_duplicate_interval_type() { + let err = parse_interval("months", "1 month 1 second 1 second") + .expect_err("parsing interval should have failed"); + assert_eq!( + r#"SQL(ParserError("Invalid input syntax for type interval: \"1 month 1 second 1 second\". Repeated type 'second'"))"#, + format!("{:?}", err) + ); + } }