From 45ffb4a4afdcd9b003f54c5d7af749662cd10fa9 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Thu, 6 Oct 2022 16:03:43 +0300 Subject: [PATCH 01/13] Support for non u64 values inside window frames --- datafusion/common/Cargo.toml | 3 +- datafusion/common/src/lib.rs | 2 + .../src/interval.rs => common/src/parsers.rs} | 33 +++- datafusion/common/src/scalar.rs | 23 +++ datafusion/core/Cargo.toml | 3 +- datafusion/core/src/physical_plan/planner.rs | 129 ++++++++++++- .../core/src/physical_plan/windows/mod.rs | 10 +- datafusion/core/tests/sql/window.rs | 82 ++++++++ datafusion/expr/Cargo.toml | 2 +- datafusion/expr/src/window_frame.rs | 181 +++++++++++++----- .../physical-expr/src/window/aggregate.rs | 159 +++++++-------- datafusion/proto/proto/datafusion.proto | 2 +- datafusion/proto/src/from_proto.rs | 4 +- datafusion/proto/src/to_proto.rs | 33 +++- datafusion/sql/Cargo.toml | 2 +- datafusion/sql/src/lib.rs | 1 - datafusion/sql/src/planner.rs | 37 +--- datafusion/sql/src/utils.rs | 2 +- 18 files changed, 514 insertions(+), 194 deletions(-) rename datafusion/{sql/src/interval.rs => common/src/parsers.rs} (88%) diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index 382f66e5dd98..9a72ac25dbdb 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -46,4 +46,5 @@ object_store = { version = "0.5.0", default-features = false, optional = true } ordered-float = "3.0" parquet = { version = "24.0.0", default-features = false, optional = true } pyo3 = { version = "0.17.1", optional = true } -sqlparser = "0.25" +# sqlparser = "0.25" +sqlparser = { git = "https://github.com/synnada-ai/sqlparser-rs", branch = "main" } diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 864172d96a01..ba0f37631011 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -20,6 +20,7 @@ mod column; mod dfschema; mod error; pub mod from_slice; +pub mod parsers; #[cfg(feature = "pyarrow")] mod pyarrow; pub mod scalar; @@ -27,6 +28,7 @@ pub mod scalar; pub use column::Column; pub use dfschema::{DFField, DFSchema, DFSchemaRef, ExprSchema, ToDFSchema}; pub use error::{field_not_found, DataFusionError, Result, SchemaError}; +pub use parsers::parse_interval; pub use scalar::{ScalarType, ScalarValue}; /// Downcast an Arrow Array to a concrete type, return an `DataFusionError::Internal` if the cast is diff --git a/datafusion/sql/src/interval.rs b/datafusion/common/src/parsers.rs similarity index 88% rename from datafusion/sql/src/interval.rs rename to datafusion/common/src/parsers.rs index dbdd038aec71..0318060e5b4f 100644 --- a/datafusion/sql/src/interval.rs +++ b/datafusion/common/src/parsers.rs @@ -16,7 +16,7 @@ // under the License. //! Interval parsing logic -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use crate::{DataFusionError, Result, ScalarValue}; use std::str::FromStr; const SECONDS_PER_HOUR: f32 = 3_600_f32; @@ -24,7 +24,7 @@ const MILLIS_PER_SECOND: f32 = 1_000_f32; /// Parses a string with an interval like `'0.5 MONTH'` to an /// appropriately typed [`ScalarValue`] -pub(crate) fn parse_interval(leading_field: &str, value: &str) -> Result { +pub fn parse_interval(leading_field: &str, value: &str) -> Result { // We are storing parts as integers, it's why we need to align parts fractional // INTERVAL '0.5 MONTH' = 15 days, INTERVAL '1.5 MONTH' = 1 month 15 days // INTERVAL '0.5 DAY' = 12 hours, INTERVAL '1.5 DAY' = 1 day 12 hours @@ -144,9 +144,9 @@ pub(crate) fn parse_interval(leading_field: &str, value: &str) -> Result Result) +#[macro_export] +macro_rules! assert_contains { + ($ACTUAL: expr, $EXPECTED: expr) => { + let actual_value: String = $ACTUAL.into(); + let expected_value: String = $EXPECTED.into(); + assert!( + actual_value.contains(&expected_value), + "Can not find expected in actual.\n\nExpected:\n{}\n\nActual:\n{}", + expected_value, + actual_value + ); + }; +} + #[cfg(test)] mod test { use crate::assert_contains; diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index c3f91dd9b1d1..6cbe7713d3e7 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -111,6 +111,21 @@ pub enum ScalarValue { Dictionary(Box, Box), } +pub fn get_scalar_value(number: &str) -> Result { + let my_string = number.to_string(); + if my_string.contains('.') { + let res = my_string.parse::().map_err(|_| { + DataFusionError::Internal(format!("couldn\'t parse {}", my_string)) + })?; + Ok(ScalarValue::Float64(Some(res))) + } else { + let res = my_string.parse::().map_err(|_| { + DataFusionError::Internal(format!("couldn\'t parse {}", my_string)) + })?; + Ok(ScalarValue::UInt64(Some(res))) + } +} + // manual implementation of `PartialEq` that uses OrderedFloat to // get defined behavior for floating point impl PartialEq for ScalarValue { @@ -490,6 +505,10 @@ macro_rules! impl_op { macro_rules! impl_distinct_cases_op { ($LHS:expr, $RHS:expr, +) => { match ($LHS, $RHS) { + ( + ScalarValue::TimestampNanosecond(Some(lhs), None), + ScalarValue::TimestampNanosecond(Some(rhs), None), + ) => Ok(ScalarValue::TimestampNanosecond(Some(lhs + rhs), None)), e => Err(DataFusionError::Internal(format!( "Addition is not implemented for {:?}", e @@ -498,6 +517,10 @@ macro_rules! impl_distinct_cases_op { }; ($LHS:expr, $RHS:expr, -) => { match ($LHS, $RHS) { + ( + ScalarValue::TimestampNanosecond(Some(lhs), None), + ScalarValue::TimestampNanosecond(Some(rhs), None), + ) => Ok(ScalarValue::TimestampNanosecond(Some(lhs - rhs), None)), e => Err(DataFusionError::Internal(format!( "Subtraction is not implemented for {:?}", e diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 44a5cbe1afb4..e99b9ca11ba9 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -89,7 +89,8 @@ pyo3 = { version = "0.17.1", optional = true } rand = "0.8" rayon = { version = "1.5", optional = true } smallvec = { version = "1.6", features = ["union"] } -sqlparser = "0.25" +# sqlparser = "0.25" +sqlparser = { git = "https://github.com/synnada-ai/sqlparser-rs", branch = "main" } tempfile = "3" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } tokio-stream = "0.1" diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 8bb1d95a48a6..8ad2f8ca95df 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -57,12 +57,13 @@ use crate::{ physical_plan::displayable, }; use arrow::compute::SortOptions; -use arrow::datatypes::{Schema, SchemaRef}; +use arrow::datatypes::{Schema, SchemaRef, TimeUnit}; use async_trait::async_trait; +use datafusion_common::scalar::TryFromValue; use datafusion_common::ScalarValue; use datafusion_expr::expr::GroupingSet; use datafusion_expr::utils::{expand_wildcard, expr_to_columns}; -use datafusion_expr::WindowFrameUnits; +use datafusion_expr::{WindowFrameBound, WindowFrameUnits}; use datafusion_physical_expr::expressions::Literal; use datafusion_sql::utils::window_expr_common_partition_keys; use futures::future::BoxFuture; @@ -1369,6 +1370,97 @@ fn get_physical_expr_pair( Ok((physical_expr, physical_name)) } +fn convert_to_column_type( + order_by: &[PhysicalSortExpr], + physical_input_schema: &Schema, + in_scalar: &ScalarValue, +) -> Result { + // Below query may produce error, it will not produce error + // if ScalarValues is not Null + let column_type = order_by + .first() + .ok_or_else(|| { + DataFusionError::Internal("Order By column cannot be empty".to_string()) + })? + .expr + .data_type(physical_input_schema); + match in_scalar { + ScalarValue::Null => Ok(ScalarValue::Null), + ScalarValue::UInt64(Some(val)) => { + ScalarValue::try_from_value(&column_type?, *val) + } + ScalarValue::Float64(Some(val)) => { + ScalarValue::try_from_value(&column_type?, *val) + } + ScalarValue::IntervalDayTime(Some(val)) => { + // source val in millisecond precision + // convert IntervalDayTime to days and millisecond part + // TODO: below operation can be done with shift operationa + let denom = 2_i64.pow(32) as i64; + let days = val / denom; + let milli = val % denom; + println!("days: {}, milli :{}", days, milli); + let interval_in_milli = days * 24 * 60 * 60 * 1000 + milli; + Ok(match &column_type? { + arrow::datatypes::DataType::Timestamp(TimeUnit::Second, tz_opt) => { + ScalarValue::TimestampSecond( + Some(interval_in_milli / 1000), + tz_opt.clone(), + ) + } + arrow::datatypes::DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { + ScalarValue::TimestampMillisecond( + Some(interval_in_milli), + tz_opt.clone(), + ) + } + arrow::datatypes::DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { + ScalarValue::TimestampMicrosecond( + Some(interval_in_milli * 1000), + tz_opt.clone(), + ) + } + arrow::datatypes::DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { + ScalarValue::TimestampNanosecond( + Some(interval_in_milli * 1000000), + tz_opt.clone(), + ) + } + datatype => { + return Err(DataFusionError::NotImplemented(format!( + "Can't create a scalar from data_type \"{:?}\"", + datatype + ))); + } + }) + } + // TODO: Add handling for other valid ScalarValue types + // For now we can either get ScalarValue::Uint64 or None from PRECEDING AND + // FOLLOWING fields. When sql parser supports datetime types in the window + // range queries extend below to support datetime types inside the window. + unexpected => Err(DataFusionError::Internal(format!( + "unexpected: {:?}", + unexpected + ))), + } +} + +fn convert_range_bound_to_column_type( + order_by: &[PhysicalSortExpr], + physical_input_schema: &Schema, + bound: &WindowFrameBound, +) -> Result { + Ok(match bound { + WindowFrameBound::Preceding(val) => WindowFrameBound::Preceding( + convert_to_column_type(order_by, physical_input_schema, val)?, + ), + WindowFrameBound::CurrentRow => WindowFrameBound::CurrentRow, + WindowFrameBound::Following(val) => WindowFrameBound::Following( + convert_to_column_type(order_by, physical_input_schema, val)?, + ), + }) +} + /// Create a window expression with a name from a logical expression pub fn create_window_expr_with_name( e: &Expr, @@ -1430,21 +1522,38 @@ pub fn create_window_expr_with_name( )), }) .collect::>>()?; - if window_frame.is_some() - && window_frame.unwrap().units == WindowFrameUnits::Groups - { - return Err(DataFusionError::NotImplemented( - "Window frame definitions involving GROUPS are not supported yet" - .to_string(), - )); + let mut new_window_frame = window_frame.clone(); + if let Some(window_frame) = window_frame { + if window_frame.units == WindowFrameUnits::Groups { + return Err(DataFusionError::NotImplemented( + "Window frame definitions involving GROUPS are not supported yet" + .to_string(), + )); + } + if window_frame.units == WindowFrameUnits::Range { + new_window_frame.as_mut().unwrap().start_bound = + convert_range_bound_to_column_type( + &order_by, + physical_input_schema, + &window_frame.start_bound, + )?; + new_window_frame.as_mut().unwrap().end_bound = + convert_range_bound_to_column_type( + &order_by, + physical_input_schema, + &window_frame.end_bound, + )?; + } } + + let new_window_frame = new_window_frame.map(Arc::new); windows::create_window_expr( fun, name, &args, &partition_by, &order_by, - *window_frame, + new_window_frame, physical_input_schema, ) } diff --git a/datafusion/core/src/physical_plan/windows/mod.rs b/datafusion/core/src/physical_plan/windows/mod.rs index 26cb14fe33a9..be9421a9de84 100644 --- a/datafusion/core/src/physical_plan/windows/mod.rs +++ b/datafusion/core/src/physical_plan/windows/mod.rs @@ -51,7 +51,7 @@ pub fn create_window_expr( args: &[Arc], partition_by: &[Arc], order_by: &[PhysicalSortExpr], - window_frame: Option, + window_frame: Option>, input_schema: &Schema, ) -> Result> { Ok(match fun { @@ -186,7 +186,7 @@ mod tests { &[col("c3", &schema)?], &[], &[], - Some(WindowFrame::default()), + Some(Arc::new(WindowFrame::default())), schema.as_ref(), )?, create_window_expr( @@ -195,7 +195,7 @@ mod tests { &[col("c3", &schema)?], &[], &[], - Some(WindowFrame::default()), + Some(Arc::new(WindowFrame::default())), schema.as_ref(), )?, create_window_expr( @@ -204,7 +204,7 @@ mod tests { &[col("c3", &schema)?], &[], &[], - Some(WindowFrame::default()), + Some(Arc::new(WindowFrame::default())), schema.as_ref(), )?, ], @@ -250,7 +250,7 @@ mod tests { &[col("a", &schema)?], &[], &[], - Some(WindowFrame::default()), + Some(Arc::new(WindowFrame::default())), schema.as_ref(), )?], blocking_exec, diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index 2708d91e0d75..4d7f23beb1c4 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -1075,6 +1075,88 @@ async fn window_frame_partition_by_order_by_desc() -> Result<()> { Ok(()) } +#[tokio::test] +async fn window_frame_range_float() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT SUM(c12) OVER (ORDER BY C12 RANGE BETWEEN 0.2 PRECEDING AND 0.2 FOLLOWING) +FROM aggregate_test_100 +ORDER BY C9 +LIMIT 5;"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-----------------------------+", + "| SUM(aggregate_test_100.c12) |", + "+-----------------------------+", + "| 2.5476701803634296 |", + "| 10.6299412548214 |", + "| 2.5476701803634296 |", + "| 20.349518503437288 |", + "| 21.408674363507753 |", + "+-----------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn window_frame_ranges_timestamp() -> Result<()> { + // define a schema. + let schema = Arc::new(Schema::new(vec![Field::new( + "ts", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + )])); + + // define data in two partitions + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(TimestampNanosecondArray::from_slice(&[ + 1664264591000000000, + 1664264592000000000, + 1664264593000000000, + 1664264594000000000, + 1664364594000000000, + 1664464594000000000, + 1664564594000000000, + ]))], + ) + .unwrap(); + + let ctx = SessionContext::new(); + // declare a new context. In spark API, this corresponds to a new spark SQLsession + // declare a table in memory. In spark API, this corresponds to createDataFrame(...). + let provider = MemTable::try_new(schema, vec![vec![batch]]).unwrap(); + // Register table + ctx.register_table("t", Arc::new(provider)).unwrap(); + + // execute the query + let df = ctx + .sql( + // "SELECT COUNT(*) OVER (ORDER BY ts) FROM t;" + "SELECT ts, COUNT(*) OVER (ORDER BY ts RANGE BETWEEN '1 DAY' PRECEDING AND '2 DAY' FOLLOWING) FROM t;" + // "SELECT a FROM t;" + ) + .await?; + + let actual = df.collect().await?; + let expected = vec![ + "+---------------------+-----------------+", + "| ts | COUNT(UInt8(1)) |", + "+---------------------+-----------------+", + "| 2022-09-27 07:43:11 | 5 |", + "| 2022-09-27 07:43:12 | 5 |", + "| 2022-09-27 07:43:13 | 5 |", + "| 2022-09-27 07:43:14 | 5 |", + "| 2022-09-28 11:29:54 | 2 |", + "| 2022-09-29 15:16:34 | 2 |", + "| 2022-09-30 19:03:14 | 1 |", + "+---------------------+-----------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + #[tokio::test] async fn window_frame_ranges_unbounded_preceding_err() -> Result<()> { let ctx = SessionContext::new(); diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 3280628a42eb..6b25a0ddedca 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -38,4 +38,4 @@ path = "src/lib.rs" ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] } arrow = { version = "24.0.0", default-features = false } datafusion-common = { path = "../common", version = "13.0.0" } -sqlparser = "0.25" +sqlparser = { git = "https://github.com/synnada-ai/sqlparser-rs", branch = "main" } diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index 7f9afd0b51a8..300793384e7b 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -23,7 +23,8 @@ //! - An ending frame boundary, //! - An EXCLUDE clause. -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::scalar::get_scalar_value; +use datafusion_common::{DataFusionError, Result, ScalarValue}; use sqlparser::ast; use std::cmp::Ordering; use std::convert::{From, TryFrom}; @@ -35,7 +36,7 @@ use std::hash::{Hash, Hasher}; /// The ending frame boundary can be omitted (if the BETWEEN and AND keywords that surround the /// starting frame boundary are also omitted), in which case the ending frame boundary defaults to /// CURRENT ROW. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub struct WindowFrame { /// A frame type - either ROWS, RANGE or GROUPS pub units: WindowFrameUnits, @@ -66,12 +67,12 @@ impl TryFrom for WindowFrame { .map(WindowFrameBound::from) .unwrap_or(WindowFrameBound::CurrentRow); - if let WindowFrameBound::Following(None) = start_bound { + if let WindowFrameBound::Following(ScalarValue::Null) = start_bound { Err(DataFusionError::Execution( "Invalid window frame: start bound cannot be unbounded following" .to_owned(), )) - } else if let WindowFrameBound::Preceding(None) = end_bound { + } else if let WindowFrameBound::Preceding(ScalarValue::Null) = end_bound { Err(DataFusionError::Execution( "Invalid window frame: end bound cannot be unbounded preceding" .to_owned(), @@ -96,12 +97,41 @@ impl Default for WindowFrame { fn default() -> Self { WindowFrame { units: WindowFrameUnits::Range, - start_bound: WindowFrameBound::Preceding(None), + start_bound: WindowFrameBound::Preceding(ScalarValue::Null), end_bound: WindowFrameBound::CurrentRow, } } } +pub fn convert_range_bound_to_scalar_value(v: ast::RangeBounds) -> Result { + match v { + ast::RangeBounds::Number(number) => get_scalar_value(&number[..]), + ast::RangeBounds::Interval(ast::Expr::Interval { + value, + leading_field: _, + leading_precision: _, + last_field: _, + fractional_seconds_precision: _, + }) => { + let res = match *value { + ast::Expr::Value(ast::Value::SingleQuotedString(elem)) => Ok(elem), + unexpected => Err(DataFusionError::Internal(format!( + "INTERVAL expression cannot be {:?}", + unexpected + ))), + }; + // parse the interval most precise way possible which is millisecond for now + // TODO: Add parser the IntervalMonthDayNano case with nanosecond option + let res = datafusion_common::parsers::parse_interval("millisecond", &res?)?; + Ok(res) + } + unexpected => Err(DataFusionError::Internal(format!( + "RangeBounds cannot be {:?}", + unexpected + ))), + } +} + /// There are five ways to describe starting and ending frame boundaries: /// /// 1. UNBOUNDED PRECEDING @@ -111,7 +141,7 @@ impl Default for WindowFrame { /// 5. UNBOUNDED FOLLOWING /// /// in this implementation we'll only allow to be u64 (i.e. no dynamic boundary) -#[derive(Debug, Clone, Copy, Eq)] +#[derive(Debug, Clone, Eq)] pub enum WindowFrameBound { /// 1. UNBOUNDED PRECEDING /// The frame boundary is the first row in the partition. @@ -119,7 +149,7 @@ pub enum WindowFrameBound { /// 2. PRECEDING /// must be a non-negative constant numeric expression. The boundary is a row that /// is "units" prior to the current row. - Preceding(Option), + Preceding(ScalarValue), /// 3. The current row. /// /// For RANGE and GROUPS frame types, peers of the current row are also @@ -132,14 +162,27 @@ pub enum WindowFrameBound { /// /// 5. UNBOUNDED FOLLOWING /// The frame boundary is the last row in the partition. - Following(Option), + Following(ScalarValue), } impl From for WindowFrameBound { + // TODO: Add handling for other ScalarValue, once sql parser supports other types than literal int + // see https://github.com/sqlparser-rs/sqlparser-rs/issues/631 + // For now we can either get Some(u64) or None from PRECEDING AND + // FOLLOWING fields. When sql parser supports datetime types in the window + // range queries extend below to support datetime types inside the window. fn from(value: ast::WindowFrameBound) -> Self { match value { - ast::WindowFrameBound::Preceding(v) => Self::Preceding(v), - ast::WindowFrameBound::Following(v) => Self::Following(v), + ast::WindowFrameBound::Preceding(Some(v)) => { + let res = convert_range_bound_to_scalar_value(v).unwrap(); + Self::Preceding(res) + } + ast::WindowFrameBound::Preceding(None) => Self::Preceding(ScalarValue::Null), + ast::WindowFrameBound::Following(Some(v)) => { + let res = convert_range_bound_to_scalar_value(v).unwrap(); + Self::Following(res) + } + ast::WindowFrameBound::Following(None) => Self::Following(ScalarValue::Null), ast::WindowFrameBound::CurrentRow => Self::CurrentRow, } } @@ -148,11 +191,15 @@ impl From for WindowFrameBound { impl fmt::Display for WindowFrameBound { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { + WindowFrameBound::Preceding(ScalarValue::Null) => { + f.write_str("UNBOUNDED PRECEDING") + } + WindowFrameBound::Preceding(n) => write!(f, "{} PRECEDING", n), WindowFrameBound::CurrentRow => f.write_str("CURRENT ROW"), - WindowFrameBound::Preceding(None) => f.write_str("UNBOUNDED PRECEDING"), - WindowFrameBound::Following(None) => f.write_str("UNBOUNDED FOLLOWING"), - WindowFrameBound::Preceding(Some(n)) => write!(f, "{} PRECEDING", n), - WindowFrameBound::Following(Some(n)) => write!(f, "{} FOLLOWING", n), + WindowFrameBound::Following(ScalarValue::Null) => { + f.write_str("UNBOUNDED FOLLOWING") + } + WindowFrameBound::Following(n) => write!(f, "{} FOLLOWING", n), } } } @@ -189,13 +236,26 @@ impl WindowFrameBound { /// rank and also for 0 preceding / following it is the same as current row fn get_rank(&self) -> (u8, u64) { match self { - WindowFrameBound::Preceding(None) => (0, 0), - WindowFrameBound::Following(None) => (4, 0), - WindowFrameBound::Preceding(Some(0)) + WindowFrameBound::Preceding(ScalarValue::Null) => (0, 0), + WindowFrameBound::Following(ScalarValue::Null) => (4, 0), + WindowFrameBound::Preceding(ScalarValue::UInt64(Some(0))) | WindowFrameBound::CurrentRow - | WindowFrameBound::Following(Some(0)) => (2, 0), - WindowFrameBound::Preceding(Some(v)) => (1, u64::MAX - *v), - WindowFrameBound::Following(Some(v)) => (3, *v), + | WindowFrameBound::Following(ScalarValue::UInt64(Some(0))) => (2, 0), + WindowFrameBound::Preceding(ScalarValue::UInt64(Some(v))) => { + (1, u64::MAX - *v as u64) + } + WindowFrameBound::Preceding(ScalarValue::Float64(Some(v))) => { + (1, u64::MAX - *v as u64) + } + WindowFrameBound::Following(ScalarValue::UInt64(Some(v))) => (3, *v as u64), + WindowFrameBound::Following(ScalarValue::Float64(Some(v))) => (3, *v as u64), + WindowFrameBound::Preceding(ScalarValue::IntervalDayTime(Some(v))) => { + (1, u64::MAX - *v as u64) + } + WindowFrameBound::Following(ScalarValue::IntervalDayTime(Some(v))) => { + (3, *v as u64) + } + _ => todo!(), } } } @@ -271,8 +331,12 @@ mod tests { let window_frame = ast::WindowFrame { units: ast::WindowFrameUnits::Range, - start_bound: ast::WindowFrameBound::Preceding(Some(1)), - end_bound: Some(ast::WindowFrameBound::Preceding(Some(2))), + start_bound: ast::WindowFrameBound::Preceding(Some( + ast::RangeBounds::Number("1".to_string()), + )), + end_bound: Some(ast::WindowFrameBound::Preceding(Some( + ast::RangeBounds::Number("2".to_string()), + ))), }; let result = WindowFrame::try_from(window_frame); assert_eq!( @@ -282,8 +346,12 @@ mod tests { let window_frame = ast::WindowFrame { units: ast::WindowFrameUnits::Rows, - start_bound: ast::WindowFrameBound::Preceding(Some(2)), - end_bound: Some(ast::WindowFrameBound::Preceding(Some(1))), + start_bound: ast::WindowFrameBound::Preceding(Some( + ast::RangeBounds::Number("2".to_string()), + )), + end_bound: Some(ast::WindowFrameBound::Preceding(Some( + ast::RangeBounds::Number("1".to_string()), + ))), }; let result = WindowFrame::try_from(window_frame); assert!(result.is_ok()); @@ -293,62 +361,77 @@ mod tests { #[test] fn test_eq() { assert_eq!( - WindowFrameBound::Preceding(Some(0)), + WindowFrameBound::Preceding(ScalarValue::UInt64(Some(0))), WindowFrameBound::CurrentRow ); assert_eq!( WindowFrameBound::CurrentRow, - WindowFrameBound::Following(Some(0)) + WindowFrameBound::Following(ScalarValue::UInt64(Some(0))) ); assert_eq!( - WindowFrameBound::Following(Some(2)), - WindowFrameBound::Following(Some(2)) + WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), + WindowFrameBound::Following(ScalarValue::UInt64(Some(2))) ); assert_eq!( - WindowFrameBound::Following(None), - WindowFrameBound::Following(None) + WindowFrameBound::Following(ScalarValue::Null), + WindowFrameBound::Following(ScalarValue::Null) ); assert_eq!( - WindowFrameBound::Preceding(Some(2)), - WindowFrameBound::Preceding(Some(2)) + WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))), + WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))) ); assert_eq!( - WindowFrameBound::Preceding(None), - WindowFrameBound::Preceding(None) + WindowFrameBound::Preceding(ScalarValue::Null), + WindowFrameBound::Preceding(ScalarValue::Null) ); } #[test] fn test_ord() { - assert!(WindowFrameBound::Preceding(Some(1)) < WindowFrameBound::CurrentRow); + assert!( + WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))) + < WindowFrameBound::CurrentRow + ); // ! yes this is correct! assert!( - WindowFrameBound::Preceding(Some(2)) < WindowFrameBound::Preceding(Some(1)) + WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))) + < WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))) ); assert!( - WindowFrameBound::Preceding(Some(u64::MAX)) - < WindowFrameBound::Preceding(Some(u64::MAX - 1)) + WindowFrameBound::Preceding(ScalarValue::UInt64(Some(u64::MAX))) + < WindowFrameBound::Preceding(ScalarValue::UInt64(Some(u64::MAX - 1))) ); assert!( - WindowFrameBound::Preceding(None) - < WindowFrameBound::Preceding(Some(1000000)) + WindowFrameBound::Preceding(ScalarValue::Null) + < WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1000000))) ); assert!( - WindowFrameBound::Preceding(None) - < WindowFrameBound::Preceding(Some(u64::MAX)) + WindowFrameBound::Preceding(ScalarValue::Null) + < WindowFrameBound::Preceding(ScalarValue::UInt64(Some(u64::MAX))) + ); + assert!( + WindowFrameBound::Preceding(ScalarValue::Null) + < WindowFrameBound::Following(ScalarValue::UInt64(Some(0))) + ); + assert!( + WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))) + < WindowFrameBound::Following(ScalarValue::UInt64(Some(1))) + ); + assert!( + WindowFrameBound::CurrentRow + < WindowFrameBound::Following(ScalarValue::UInt64(Some(1))) ); - assert!(WindowFrameBound::Preceding(None) < WindowFrameBound::Following(Some(0))); assert!( - WindowFrameBound::Preceding(Some(1)) < WindowFrameBound::Following(Some(1)) + WindowFrameBound::Following(ScalarValue::UInt64(Some(1))) + < WindowFrameBound::Following(ScalarValue::UInt64(Some(2))) ); - assert!(WindowFrameBound::CurrentRow < WindowFrameBound::Following(Some(1))); assert!( - WindowFrameBound::Following(Some(1)) < WindowFrameBound::Following(Some(2)) + WindowFrameBound::Following(ScalarValue::UInt64(Some(2))) + < WindowFrameBound::Following(ScalarValue::Null) ); - assert!(WindowFrameBound::Following(Some(2)) < WindowFrameBound::Following(None)); assert!( - WindowFrameBound::Following(Some(u64::MAX)) - < WindowFrameBound::Following(None) + WindowFrameBound::Following(ScalarValue::UInt64(Some(u64::MAX))) + < WindowFrameBound::Following(ScalarValue::Null) ); } } diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index b81c0eaf243b..0d78624ca4a3 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -44,7 +44,7 @@ pub struct AggregateWindowExpr { aggregate: Arc, partition_by: Vec>, order_by: Vec, - window_frame: Option, + window_frame: Option>, } impl AggregateWindowExpr { @@ -53,7 +53,7 @@ impl AggregateWindowExpr { aggregate: Arc, partition_by: &[Arc], order_by: &[PhysicalSortExpr], - window_frame: Option, + window_frame: Option>, ) -> Self { Self { aggregate, @@ -66,7 +66,7 @@ impl AggregateWindowExpr { /// create a new accumulator based on the underlying aggregation function fn create_accumulator(&self) -> Result { let accumulator = self.aggregate.create_accumulator()?; - let window_frame = self.window_frame; + let window_frame = self.window_frame.clone(); let partition_by = self.partition_by().to_vec(); let order_by = self.order_by.to_vec(); let field = self.aggregate.field()?; @@ -144,15 +144,13 @@ fn calculate_index_of_row( range_columns: &[ArrayRef], sort_options: &[SortOptions], idx: usize, - delta: u64, + delta: Option<&ScalarValue>, ) -> Result { let current_row_values = range_columns .iter() .map(|col| ScalarValue::try_from_array(col, idx)) .collect::>>()?; - let end_range = if delta == 0 { - current_row_values - } else { + let end_range = if let Some(delta) = delta { let is_descending: bool = sort_options .first() .ok_or_else(|| DataFusionError::Internal("Array is empty".to_string()))? @@ -161,21 +159,20 @@ fn calculate_index_of_row( current_row_values .iter() .map(|value| { - if value.is_null() { - return Ok(value.clone()); - }; - let offset = ScalarValue::try_from_value(&value.get_datatype(), delta)?; - if SEARCH_SIDE == is_descending { - // TODO: Handle positive overflows - value.add(&offset) - } else if value.is_unsigned() && value < &offset { - ScalarValue::try_from_value(&value.get_datatype(), 0) + Ok(if value.is_null() { + value.clone() + } else if SEARCH_SIDE == is_descending { + // TODO: ADD overflow check + value.add(delta)? + } else if value.is_unsigned() && value < delta { + ScalarValue::try_from_value(&value.get_datatype(), 0)? } else { - // TODO: Handle negative overflows - value.sub(&offset) - } + value.sub(delta)? + }) }) .collect::>>()? + } else { + current_row_values }; // `BISECT_SIDE` true means bisect_left, false means bisect_right bisect::(range_columns, &end_range, sort_options) @@ -192,116 +189,122 @@ fn calculate_current_window( ) -> Result<(usize, usize)> { match window_frame.units { WindowFrameUnits::Range => { - let start = match window_frame.start_bound { + let start = match &window_frame.start_bound { // UNBOUNDED PRECEDING - WindowFrameBound::Preceding(None) => Ok(0), - WindowFrameBound::Preceding(Some(n)) => { - calculate_index_of_row::( - range_columns, - sort_options, - idx, - n, - ) - } + WindowFrameBound::Preceding(ScalarValue::Null) => Ok(0), + WindowFrameBound::Preceding(n) => calculate_index_of_row::( + range_columns, + sort_options, + idx, + Some(n), + ), WindowFrameBound::CurrentRow => calculate_index_of_row::( range_columns, sort_options, idx, - 0, + None, ), - WindowFrameBound::Following(Some(n)) => { - calculate_index_of_row::( - range_columns, - sort_options, - idx, - n, - ) - } // UNBOUNDED FOLLOWING - WindowFrameBound::Following(None) => { + WindowFrameBound::Following(ScalarValue::Null) => { Err(DataFusionError::Internal(format!( "Frame start cannot be UNBOUNDED FOLLOWING '{:?}'", window_frame ))) } + WindowFrameBound::Following(n) => calculate_index_of_row::( + range_columns, + sort_options, + idx, + Some(n), + ), }; - let end = match window_frame.end_bound { + let end = match &window_frame.end_bound { // UNBOUNDED PRECEDING - WindowFrameBound::Preceding(None) => { + WindowFrameBound::Preceding(ScalarValue::Null) => { Err(DataFusionError::Internal(format!( "Frame end cannot be UNBOUNDED PRECEDING '{:?}'", window_frame ))) } - WindowFrameBound::Preceding(Some(n)) => { - calculate_index_of_row::( - range_columns, - sort_options, - idx, - n, - ) - } + WindowFrameBound::Preceding(n) => calculate_index_of_row::( + range_columns, + sort_options, + idx, + Some(n), + ), WindowFrameBound::CurrentRow => calculate_index_of_row::( range_columns, sort_options, idx, - 0, + None, ), - WindowFrameBound::Following(Some(n)) => { - calculate_index_of_row::( - range_columns, - sort_options, - idx, - n, - ) - } // UNBOUNDED FOLLOWING - WindowFrameBound::Following(None) => Ok(length), + WindowFrameBound::Following(ScalarValue::Null) => Ok(length), + WindowFrameBound::Following(n) => calculate_index_of_row::( + range_columns, + sort_options, + idx, + Some(n), + ), }; Ok((start?, end?)) } WindowFrameUnits::Rows => { - let start = match window_frame.start_bound { + let start = match &window_frame.start_bound { // UNBOUNDED PRECEDING - WindowFrameBound::Preceding(None) => Ok(0), - WindowFrameBound::Preceding(Some(n)) => { - if idx >= n as usize { - Ok(idx - n as usize) + WindowFrameBound::Preceding(ScalarValue::Null) => Ok(0), + WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => { + if idx >= *n as usize { + Ok(idx - *n as usize) } else { Ok(0) } } + WindowFrameBound::Preceding(_) => { + Err(DataFusionError::Internal("Rows should be Uint".to_string())) + } WindowFrameBound::CurrentRow => Ok(idx), - WindowFrameBound::Following(Some(n)) => Ok(min(idx + n as usize, length)), // UNBOUNDED FOLLOWING - WindowFrameBound::Following(None) => { + WindowFrameBound::Following(ScalarValue::Null) => { Err(DataFusionError::Internal(format!( "Frame start cannot be UNBOUNDED FOLLOWING '{:?}'", window_frame ))) } + WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => { + Ok(min(idx + *n as usize, length)) + } + WindowFrameBound::Following(_) => { + Err(DataFusionError::Internal("Rows should be Uint".to_string())) + } }; let end = match window_frame.end_bound { // UNBOUNDED PRECEDING - WindowFrameBound::Preceding(None) => { + WindowFrameBound::Preceding(ScalarValue::Null) => { Err(DataFusionError::Internal(format!( "Frame end cannot be UNBOUNDED PRECEDING '{:?}'", window_frame ))) } - WindowFrameBound::Preceding(Some(n)) => { + WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => { if idx >= n as usize { Ok(idx - n as usize + 1) } else { Ok(0) } } + WindowFrameBound::Preceding(_) => { + Err(DataFusionError::Internal("Rows should be Uint".to_string())) + } WindowFrameBound::CurrentRow => Ok(idx + 1), - WindowFrameBound::Following(Some(n)) => { + // UNBOUNDED FOLLOWING + WindowFrameBound::Following(ScalarValue::Null) => Ok(length), + WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => { Ok(min(idx + n as usize + 1, length)) } - // UNBOUNDED FOLLOWING - WindowFrameBound::Following(None) => Ok(length), + WindowFrameBound::Following(_) => { + Err(DataFusionError::Internal("Rows should be Uint".to_string())) + } }; Ok((start?, end?)) } @@ -317,7 +320,7 @@ fn calculate_current_window( #[derive(Debug)] struct AggregateWindowAccumulator { accumulator: Box, - window_frame: Option, + window_frame: Option>, partition_by: Vec>, order_by: Vec, field: Field, @@ -325,12 +328,12 @@ struct AggregateWindowAccumulator { impl AggregateWindowAccumulator { /// This function constructs a simple window frame with a single ORDER BY. - fn implicit_order_by_window() -> WindowFrame { - WindowFrame { + fn implicit_order_by_window() -> Arc { + Arc::new(WindowFrame { units: WindowFrameUnits::Range, - start_bound: WindowFrameBound::Preceding(None), - end_bound: WindowFrameBound::Following(Some(0)), - } + start_bound: WindowFrameBound::Preceding(ScalarValue::Null), + end_bound: WindowFrameBound::CurrentRow, + }) } /// This function calculates the aggregation on all rows in `value_slice`. /// Returns an array of size `len`. @@ -433,7 +436,7 @@ impl AggregateWindowAccumulator { .map(|v| v.slice(value_range.start, length)) .collect::>(); let order_columns = &order_bys[self.partition_by.len()..order_bys.len()].to_vec(); - match (order_columns.len(), self.window_frame) { + match (order_columns.len(), &self.window_frame) { (0, None) => { // OVER () case self.calculate_whole_table(&value_slice, length) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 9f2cc2d07839..79676a9165ca 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -632,7 +632,7 @@ message WindowFrameBound { // "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see https://github.com/tokio-rs/prost/issues/430 and https://github.com/tokio-rs/prost/pull/455) // this syntax is ugly but is binary compatible with the "optional" keyword (see https://stackoverflow.com/questions/42622015/how-to-define-an-optional-field-in-protobuf-3) oneof bound_value { - uint64 value = 2; + ScalarValue value = 2; } } diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs index fe1fdfaa0c6a..0111ae200561 100644 --- a/datafusion/proto/src/from_proto.rs +++ b/datafusion/proto/src/from_proto.rs @@ -1280,12 +1280,12 @@ impl TryFrom for WindowFrameBound { protobuf::WindowFrameBoundType::Preceding => { // FIXME implement bound value parsing // https://github.com/apache/arrow-datafusion/issues/361 - Ok(Self::Preceding(Some(1))) + Ok(Self::Preceding(ScalarValue::UInt64(Some(1)))) } protobuf::WindowFrameBoundType::Following => { // FIXME implement bound value parsing // https://github.com/apache/arrow-datafusion/issues/361 - Ok(Self::Following(Some(1))) + Ok(Self::Following(ScalarValue::UInt64(Some(1)))) } } } diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs index 3c4f01075d2b..032542af4063 100644 --- a/datafusion/proto/src/to_proto.rs +++ b/datafusion/proto/src/to_proto.rs @@ -397,14 +397,26 @@ impl From for protobuf::WindowFrameBound { .into(), bound_value: None, }, - WindowFrameBound::Preceding(v) => Self { - window_frame_bound_type: protobuf::WindowFrameBoundType::Preceding.into(), - bound_value: v.map(protobuf::window_frame_bound::BoundValue::Value), - }, - WindowFrameBound::Following(v) => Self { - window_frame_bound_type: protobuf::WindowFrameBoundType::Following.into(), - bound_value: v.map(protobuf::window_frame_bound::BoundValue::Value), - }, + WindowFrameBound::Preceding(v) => { + let pb_value: protobuf::ScalarValue = (&v).try_into().unwrap(); + Self { + window_frame_bound_type: protobuf::WindowFrameBoundType::Preceding + .into(), + bound_value: Some(protobuf::window_frame_bound::BoundValue::Value( + pb_value, + )), + } + } + WindowFrameBound::Following(v) => { + let pb_value: protobuf::ScalarValue = (&v).try_into().unwrap(); + Self { + window_frame_bound_type: protobuf::WindowFrameBoundType::Following + .into(), + bound_value: Some(protobuf::window_frame_bound::BoundValue::Value( + pb_value, + )), + } + } } } } @@ -543,8 +555,9 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { .iter() .map(|e| e.try_into()) .collect::, _>>()?; - let window_frame = window_frame.map(|window_frame| { - protobuf::window_expr_node::WindowFrame::Frame(window_frame.into()) + let window_frame: Option = window_frame.as_ref().map(|window_frame| { + let pb_value: protobuf::WindowFrame = window_frame.clone().into(); + protobuf::window_expr_node::WindowFrame::Frame(pb_value) }); let window_expr = Box::new(protobuf::WindowExprNode { expr: arg_expr, diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index 4633e225ae51..b3b79c7cf23f 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -40,4 +40,4 @@ unicode_expressions = [] arrow = { version = "24.0.0", default-features = false } datafusion-common = { path = "../common", version = "13.0.0" } datafusion-expr = { path = "../expr", version = "13.0.0" } -sqlparser = "0.25" +sqlparser = { git = "https://github.com/synnada-ai/sqlparser-rs", branch = "main" } diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs index 20d3ed2f72be..19404419b26c 100644 --- a/datafusion/sql/src/lib.rs +++ b/datafusion/sql/src/lib.rs @@ -18,7 +18,6 @@ //! This module provides a SQL parser that translates SQL queries into an abstract syntax //! tree (AST), and a SQL query planner that creates a logical plan from the AST. -mod interval; pub mod parser; pub mod planner; mod table_reference; diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index b75efd6f4723..9c3950fa2000 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -17,9 +17,9 @@ //! SQL Query Planner (produces logical plan from SQL AST) -use crate::interval::parse_interval; use crate::parser::{CreateExternalTable, DescribeTable, Statement as DFStatement}; use arrow::datatypes::*; +use datafusion_common::parsers::parse_interval; use datafusion_common::{context, ToDFSchema}; use datafusion_expr::expr_rewriter::normalize_col; use datafusion_expr::expr_rewriter::normalize_col_with_schemas; @@ -60,8 +60,10 @@ use sqlparser::ast::{ FunctionArgExpr, Ident, Join, JoinConstraint, JoinOperator, ObjectName, Offset as SQLOffset, Query, Select, SelectItem, SetExpr, SetOperator, ShowCreateObject, ShowStatementFilter, TableAlias, TableFactor, TableWithJoins, - TimezoneInfo, TrimWhereField, UnaryOperator, Value, Values as SQLValues, + TrimWhereField, UnaryOperator, Value, Values as SQLValues, }; +// use sqlparser::ast::TimezoneInfo; +use sqlparser::ast::TimezoneInfo; use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption}; use sqlparser::ast::{ObjectType, OrderByExpr, Statement}; use sqlparser::parser::ParserError::ParserError; @@ -154,6 +156,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { analyze, format: _, describe_alias: _, + .. } => self.explain_statement_to_plan(verbose, analyze, *statement), Statement::Query(query) => self.query_to_plan(*query, &mut HashMap::new()), Statement::ShowVariable { variable } => self.show_variable_to_plan(&variable), @@ -1749,7 +1752,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }), SQLExpr::Array(arr) => self.sql_array_literal(arr.elem, schema), - SQLExpr::Interval { value, leading_field, @@ -1763,7 +1765,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { last_field, fractional_seconds_precision, ), - SQLExpr::Identifier(id) => { if id.value.starts_with('@') { // TODO: figure out if ScalarVariables should be insensitive. @@ -2773,6 +2774,9 @@ pub fn convert_simple_data_type(sql_type: &SQLDataType) -> Result { | SQLDataType::Set(_) | SQLDataType::MediumInt(_) | SQLDataType::UnsignedMediumInt(_) + | SQLDataType::Character(_) + | SQLDataType::CharacterVarying(_) + | SQLDataType::CharVarying(_) | SQLDataType::Clob(_) => Err(DataFusionError::NotImplemented(format!( "Unsupported SQL type {:?}", sql_type @@ -2812,7 +2816,7 @@ fn parse_sql_number(n: &str) -> Result { #[cfg(test)] mod tests { use super::*; - use crate::assert_contains; + use datafusion_common::assert_contains; use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; use std::any::Any; @@ -5361,29 +5365,6 @@ mod tests { } } - /// A macro to assert that one string is contained within another with - /// a nice error message if they are not. - /// - /// Usage: `assert_contains!(actual, expected)` - /// - /// Is a macro so test error - /// messages are on the same line as the failure; - /// - /// Both arguments must be convertable into Strings (Into) - #[macro_export] - macro_rules! assert_contains { - ($ACTUAL: expr, $EXPECTED: expr) => { - let actual_value: String = $ACTUAL.into(); - let expected_value: String = $EXPECTED.into(); - assert!( - actual_value.contains(&expected_value), - "Can not find expected in actual.\n\nExpected:\n{}\n\nActual:\n{}", - expected_value, - actual_value - ); - }; - } - struct EmptyTable { table_schema: SchemaRef, } diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 952ef31106fd..65a397689e0e 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -193,7 +193,7 @@ where .iter() .map(|e| clone_with_replacement(e, replacement_fn)) .collect::>>()?, - window_frame: *window_frame, + window_frame: window_frame.clone(), }), Expr::AggregateUDF { fun, args, filter } => Ok(Expr::AggregateUDF { fun: fun.clone(), From 10204c8467ec066223776dc6b21fc49f338f24c9 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Fri, 7 Oct 2022 19:08:56 +0300 Subject: [PATCH 02/13] timestamp handling is added --- datafusion/common/Cargo.toml | 1 + datafusion/common/src/datetime.rs | 159 ++++++++++++++ .../src/expressions => common/src}/delta.rs | 4 +- datafusion/common/src/lib.rs | 3 + datafusion/common/src/parsers.rs | 26 +-- datafusion/common/src/scalar.rs | 86 +++----- datafusion/common/src/test_util.rs | 64 ++++++ .../src/physical_plan/file_format/parquet.rs | 3 +- datafusion/core/src/physical_plan/planner.rs | 188 ++++++++-------- datafusion/core/src/test_util.rs | 46 ---- datafusion/core/tests/sql/idenfifers.rs | 3 +- datafusion/core/tests/sql/mod.rs | 6 +- datafusion/core/tests/sql/window.rs | 14 +- datafusion/expr/Cargo.toml | 1 + datafusion/expr/src/window_frame.rs | 122 ++++++++--- .../physical-expr/src/expressions/datetime.rs | 201 ++---------------- .../physical-expr/src/expressions/mod.rs | 1 - .../physical-expr/src/window/aggregate.rs | 37 ++-- datafusion/proto/proto/datafusion.proto | 6 +- datafusion/proto/src/to_proto.rs | 22 +- datafusion/sql/Cargo.toml | 1 + datafusion/sql/src/planner.rs | 3 +- 22 files changed, 503 insertions(+), 494 deletions(-) create mode 100644 datafusion/common/src/datetime.rs rename datafusion/{physical-expr/src/expressions => common/src}/delta.rs (98%) create mode 100644 datafusion/common/src/test_util.rs diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index 9a72ac25dbdb..ee9148c803a8 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -41,6 +41,7 @@ pyarrow = ["pyo3", "arrow/pyarrow"] [dependencies] apache-avro = { version = "0.14", default-features = false, features = ["snappy"], optional = true } arrow = { version = "24.0.0", default-features = false } +chrono = { version = "0.4", default-features = false } cranelift-module = { version = "0.88.0", optional = true } object_store = { version = "0.5.0", default-features = false, optional = true } ordered-float = "3.0" diff --git a/datafusion/common/src/datetime.rs b/datafusion/common/src/datetime.rs new file mode 100644 index 000000000000..80407a902c0b --- /dev/null +++ b/datafusion/common/src/datetime.rs @@ -0,0 +1,159 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::delta::shift_months; +use crate::Result; +use crate::{DataFusionError, ScalarValue}; +use arrow::datatypes::{IntervalDayTimeType, IntervalMonthDayNanoType}; +use chrono::{Datelike, Duration, NaiveDate, NaiveDateTime}; +use std::ops::{Add, Sub}; + +pub fn evaluate_scalar( + operand: ScalarValue, + sign: i32, + scalar: &ScalarValue, +) -> Result { + let res = match operand { + ScalarValue::Date32(Some(days)) => { + let value = date32_add(days, scalar, sign)?; + ScalarValue::Date32(Some(value)) + } + ScalarValue::Date64(Some(ms)) => { + let value = date64_add(ms, scalar, sign)?; + ScalarValue::Date64(Some(value)) + } + ScalarValue::TimestampSecond(Some(ts_s), zone) => { + let value = seconds_add(ts_s, scalar, sign)?; + ScalarValue::TimestampSecond(Some(value), zone) + } + ScalarValue::TimestampMillisecond(Some(ts_ms), zone) => { + let value = milliseconds_add(ts_ms, scalar, sign)?; + ScalarValue::TimestampMillisecond(Some(value), zone) + } + ScalarValue::TimestampMicrosecond(Some(ts_us), zone) => { + let value = microseconds_add(ts_us, scalar, sign)?; + ScalarValue::TimestampMicrosecond(Some(value), zone) + } + ScalarValue::TimestampNanosecond(Some(ts_ns), zone) => { + let value = nanoseconds_add(ts_ns, scalar, sign)?; + ScalarValue::TimestampNanosecond(Some(value), zone) + } + _ => Err(DataFusionError::Execution(format!( + "Invalid lhs type {} for DateIntervalExpr", + operand.get_datatype() + )))?, + }; + Ok(res) +} + +#[inline] +pub fn date32_add(days: i32, scalar: &ScalarValue, sign: i32) -> Result { + let epoch = NaiveDate::from_ymd(1970, 1, 1); + let prior = epoch.add(Duration::days(days as i64)); + let posterior = do_date_math(prior, scalar, sign)?; + Ok(posterior.sub(epoch).num_days() as i32) +} + +#[inline] +pub fn date64_add(ms: i64, scalar: &ScalarValue, sign: i32) -> Result { + let epoch = NaiveDate::from_ymd(1970, 1, 1); + let prior = epoch.add(Duration::milliseconds(ms)); + let posterior = do_date_math(prior, scalar, sign)?; + Ok(posterior.sub(epoch).num_milliseconds()) +} + +#[inline] +pub fn seconds_add(ts_s: i64, scalar: &ScalarValue, sign: i32) -> Result { + Ok(do_date_time_math(ts_s, 0, scalar, sign)?.timestamp()) +} + +#[inline] +pub fn milliseconds_add(ts_ms: i64, scalar: &ScalarValue, sign: i32) -> Result { + let secs = ts_ms / 1000; + let nsecs = ((ts_ms % 1000) * 1_000_000) as u32; + Ok(do_date_time_math(secs, nsecs, scalar, sign)?.timestamp_millis()) +} + +#[inline] +pub fn microseconds_add(ts_us: i64, scalar: &ScalarValue, sign: i32) -> Result { + let secs = ts_us / 1_000_000; + let nsecs = ((ts_us % 1_000_000) * 1000) as u32; + Ok(do_date_time_math(secs, nsecs, scalar, sign)?.timestamp_nanos() / 1000) +} + +#[inline] +pub fn nanoseconds_add(ts_ns: i64, scalar: &ScalarValue, sign: i32) -> Result { + let secs = ts_ns / 1_000_000_000; + let nsecs = (ts_ns % 1_000_000_000) as u32; + Ok(do_date_time_math(secs, nsecs, scalar, sign)?.timestamp_nanos()) +} + +#[inline] +fn do_date_time_math( + secs: i64, + nsecs: u32, + scalar: &ScalarValue, + sign: i32, +) -> Result { + let prior = NaiveDateTime::from_timestamp(secs, nsecs); + do_date_math(prior, scalar, sign) +} + +fn do_date_math(prior: D, scalar: &ScalarValue, sign: i32) -> Result +where + D: Datelike + Add, +{ + Ok(match scalar { + ScalarValue::IntervalDayTime(Some(i)) => add_day_time(prior, *i, sign), + ScalarValue::IntervalYearMonth(Some(i)) => shift_months(prior, *i * sign), + ScalarValue::IntervalMonthDayNano(Some(i)) => add_m_d_nano(prior, *i, sign), + other => Err(DataFusionError::Execution(format!( + "DateIntervalExpr does not support non-interval type {:?}", + other + )))?, + }) +} + +// Can remove once https://github.com/apache/arrow-rs/pull/2031 is released +fn add_m_d_nano(prior: D, interval: i128, sign: i32) -> D +where + D: Datelike + Add, +{ + // let interval = interval as u128; + // let nanos = (interval >> 64) as i64 * sign as i64; + // let days = (interval >> 32) as i32 * sign; + // let months = interval as i32 * sign; + let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(interval); + let months = months * sign; + let days = days * sign; + let nanos = nanos * sign as i64; + let a = shift_months(prior, months); + let b = a.add(Duration::days(days as i64)); + b.add(Duration::nanoseconds(nanos)) +} + +// Can remove once https://github.com/apache/arrow-rs/pull/2031 is released +fn add_day_time(prior: D, interval: i64, sign: i32) -> D +where + D: Datelike + Add, +{ + let (days, ms) = IntervalDayTimeType::to_parts(interval); + let days = days * sign; + let ms = ms * sign; + let intermediate = prior.add(Duration::days(days as i64)); + intermediate.add(Duration::milliseconds(ms as i64)) +} diff --git a/datafusion/physical-expr/src/expressions/delta.rs b/datafusion/common/src/delta.rs similarity index 98% rename from datafusion/physical-expr/src/expressions/delta.rs rename to datafusion/common/src/delta.rs index b7efdab0a48d..1de0836fc3ec 100644 --- a/datafusion/physical-expr/src/expressions/delta.rs +++ b/datafusion/common/src/delta.rs @@ -27,7 +27,7 @@ use chrono::Datelike; /// Returns true if the year is a leap-year, as naively defined in the Gregorian calendar. #[inline] -pub(crate) fn is_leap_year(year: i32) -> bool { +fn is_leap_year(year: i32) -> bool { year % 4 == 0 && (year % 100 != 0 || year % 400 == 0) } @@ -49,7 +49,7 @@ fn normalise_day(year: i32, month: u32, day: u32) -> u32 { /// Shift a date by the given number of months. /// Ambiguous month-ends are shifted backwards as necessary. -pub(crate) fn shift_months(date: D, months: i32) -> D { +pub fn shift_months(date: D, months: i32) -> D { let mut year = date.year() + (date.month() as i32 + months) / 12; let mut month = (date.month() as i32 + months) % 12; let mut day = date.day(); diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index ba0f37631011..b5bd5fba4ead 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -17,6 +17,8 @@ pub mod bisect; mod column; +pub mod datetime; +pub mod delta; mod dfschema; mod error; pub mod from_slice; @@ -24,6 +26,7 @@ pub mod parsers; #[cfg(feature = "pyarrow")] mod pyarrow; pub mod scalar; +pub mod test_util; pub use column::Column; pub use dfschema::{DFField, DFSchema, DFSchemaRef, ExprSchema, ToDFSchema}; diff --git a/datafusion/common/src/parsers.rs b/datafusion/common/src/parsers.rs index 0318060e5b4f..7b78c92bcbdf 100644 --- a/datafusion/common/src/parsers.rs +++ b/datafusion/common/src/parsers.rs @@ -160,34 +160,10 @@ pub fn parse_interval(leading_field: &str, value: &str) -> Result { Ok(ScalarValue::IntervalDayTime(Some(result))) } -/// A macro to assert that one string is contained within another with -/// a nice error message if they are not. -/// -/// Usage: `assert_contains!(actual, expected)` -/// -/// Is a macro so test error -/// messages are on the same line as the failure; -/// -/// Both arguments must be convertable into Strings (Into) -#[macro_export] -macro_rules! assert_contains { - ($ACTUAL: expr, $EXPECTED: expr) => { - let actual_value: String = $ACTUAL.into(); - let expected_value: String = $EXPECTED.into(); - assert!( - actual_value.contains(&expected_value), - "Can not find expected in actual.\n\nExpected:\n{}\n\nActual:\n{}", - expected_value, - actual_value - ); - }; -} - #[cfg(test)] mod test { - use crate::assert_contains; - use super::*; + use crate::assert_contains; #[test] fn test_parse_ym() { diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 6cbe7713d3e7..4db54b97ba74 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -38,6 +38,7 @@ use arrow::{ }; use ordered_float::OrderedFloat; +use crate::datetime::evaluate_scalar; use crate::error::{DataFusionError, Result}; /// Represents a dynamically typed, nullable single value. @@ -111,21 +112,6 @@ pub enum ScalarValue { Dictionary(Box, Box), } -pub fn get_scalar_value(number: &str) -> Result { - let my_string = number.to_string(); - if my_string.contains('.') { - let res = my_string.parse::().map_err(|_| { - DataFusionError::Internal(format!("couldn\'t parse {}", my_string)) - })?; - Ok(ScalarValue::Float64(Some(res))) - } else { - let res = my_string.parse::().map_err(|_| { - DataFusionError::Internal(format!("couldn\'t parse {}", my_string)) - })?; - Ok(ScalarValue::UInt64(Some(res))) - } -} - // manual implementation of `PartialEq` that uses OrderedFloat to // get defined behavior for floating point impl PartialEq for ScalarValue { @@ -505,10 +491,18 @@ macro_rules! impl_op { macro_rules! impl_distinct_cases_op { ($LHS:expr, $RHS:expr, +) => { match ($LHS, $RHS) { - ( - ScalarValue::TimestampNanosecond(Some(lhs), None), - ScalarValue::TimestampNanosecond(Some(rhs), None), - ) => Ok(ScalarValue::TimestampNanosecond(Some(lhs + rhs), None)), + (ScalarValue::TimestampNanosecond(_, _), ScalarValue::IntervalDayTime(_)) + | ( + ScalarValue::TimestampNanosecond(_, _), + ScalarValue::IntervalYearMonth(_), + ) + | ( + ScalarValue::TimestampNanosecond(_, _), + ScalarValue::IntervalMonthDayNano(_), + ) => { + // 1 means addition + evaluate_scalar($LHS.clone(), 1, &$RHS) + } e => Err(DataFusionError::Internal(format!( "Addition is not implemented for {:?}", e @@ -517,10 +511,18 @@ macro_rules! impl_distinct_cases_op { }; ($LHS:expr, $RHS:expr, -) => { match ($LHS, $RHS) { - ( - ScalarValue::TimestampNanosecond(Some(lhs), None), - ScalarValue::TimestampNanosecond(Some(rhs), None), - ) => Ok(ScalarValue::TimestampNanosecond(Some(lhs - rhs), None)), + (ScalarValue::TimestampNanosecond(_, _), ScalarValue::IntervalDayTime(_)) + | ( + ScalarValue::TimestampNanosecond(_, _), + ScalarValue::IntervalYearMonth(_), + ) + | ( + ScalarValue::TimestampNanosecond(_, _), + ScalarValue::IntervalMonthDayNano(_), + ) => { + // -1 means subtraction + evaluate_scalar($LHS.clone(), -1, &$RHS) + } e => Err(DataFusionError::Internal(format!( "Subtraction is not implemented for {:?}", e @@ -2227,44 +2229,6 @@ impl TryFrom<&DataType> for ScalarValue { } } -// TODO: Remove these coercions once the hardcoded "u64" offset is changed to a -// ScalarValue in WindowFrameBound. -pub trait TryFromValue { - fn try_from_value(datatype: &DataType, value: T) -> Result; -} - -macro_rules! impl_try_from_value { - ($NATIVE:ty, [$([$SCALAR:ident, $PRIMITIVE:ty]),+]) => { - impl TryFromValue<$NATIVE> for ScalarValue { - fn try_from_value(datatype: &DataType, value: $NATIVE) -> Result { - match datatype { - $(DataType::$SCALAR => Ok(ScalarValue::$SCALAR(Some(value as $PRIMITIVE))),)+ - _ => { - let msg = format!("Can't create a scalar from data_type \"{:?}\"", datatype); - Err(DataFusionError::NotImplemented(msg)) - } - } - } - } - }; -} - -impl_try_from_value!( - u64, - [ - [Float64, f64], - [Float32, f32], - [UInt64, u64], - [UInt32, u32], - [UInt16, u16], - [UInt8, u8], - [Int64, i64], - [Int32, i32], - [Int16, i16], - [Int8, i8] - ] -); - macro_rules! format_option { ($F:expr, $EXPR:expr) => {{ match $EXPR { diff --git a/datafusion/common/src/test_util.rs b/datafusion/common/src/test_util.rs new file mode 100644 index 000000000000..3545fd270a76 --- /dev/null +++ b/datafusion/common/src/test_util.rs @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utility functions to make testing DataFusion based crates easier + +/// A macro to assert that one string is contained within another with +/// a nice error message if they are not. +/// +/// Usage: `assert_contains!(actual, expected)` +/// +/// Is a macro so test error +/// messages are on the same line as the failure; +/// +/// Both arguments must be convertable into Strings (Into) +#[macro_export] +macro_rules! assert_contains { + ($ACTUAL: expr, $EXPECTED: expr) => { + let actual_value: String = $ACTUAL.into(); + let expected_value: String = $EXPECTED.into(); + assert!( + actual_value.contains(&expected_value), + "Can not find expected in actual.\n\nExpected:\n{}\n\nActual:\n{}", + expected_value, + actual_value + ); + }; +} + +/// A macro to assert that one string is NOT contained within another with +/// a nice error message if they are are. +/// +/// Usage: `assert_not_contains!(actual, unexpected)` +/// +/// Is a macro so test error +/// messages are on the same line as the failure; +/// +/// Both arguments must be convertable into Strings (Into) +#[macro_export] +macro_rules! assert_not_contains { + ($ACTUAL: expr, $UNEXPECTED: expr) => { + let actual_value: String = $ACTUAL.into(); + let unexpected_value: String = $UNEXPECTED.into(); + assert!( + !actual_value.contains(&unexpected_value), + "Found unexpected in actual.\n\nUnexpected:\n{}\n\nActual:\n{}", + unexpected_value, + actual_value + ); + }; +} diff --git a/datafusion/core/src/physical_plan/file_format/parquet.rs b/datafusion/core/src/physical_plan/file_format/parquet.rs index 5f72c7acc8b5..b71783113519 100644 --- a/datafusion/core/src/physical_plan/file_format/parquet.rs +++ b/datafusion/core/src/physical_plan/file_format/parquet.rs @@ -887,7 +887,7 @@ mod tests { use crate::prelude::{ParquetReadOptions, SessionConfig, SessionContext}; use crate::test::object_store::local_unpartitioned_file; use crate::{ - assert_batches_sorted_eq, assert_contains, + assert_batches_sorted_eq, datasource::file_format::{parquet::ParquetFormat, FileFormat}, physical_plan::collect, }; @@ -899,6 +899,7 @@ mod tests { datatypes::{DataType, Field}, }; use chrono::{TimeZone, Utc}; + use datafusion_common::assert_contains; use datafusion_expr::{cast, col, lit}; use futures::StreamExt; use object_store::local::LocalFileSystem; diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 8ad2f8ca95df..a4204a7d8fcf 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -57,13 +57,12 @@ use crate::{ physical_plan::displayable, }; use arrow::compute::SortOptions; -use arrow::datatypes::{Schema, SchemaRef, TimeUnit}; +use arrow::datatypes::{DataType, Schema, SchemaRef}; use async_trait::async_trait; -use datafusion_common::scalar::TryFromValue; -use datafusion_common::ScalarValue; +use datafusion_common::{parse_interval, ScalarValue}; use datafusion_expr::expr::GroupingSet; use datafusion_expr::utils::{expand_wildcard, expr_to_columns}; -use datafusion_expr::{WindowFrameBound, WindowFrameUnits}; +use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits}; use datafusion_physical_expr::expressions::Literal; use datafusion_sql::utils::window_expr_common_partition_keys; use futures::future::BoxFuture; @@ -1369,75 +1368,34 @@ fn get_physical_expr_pair( let physical_name = physical_name(expr)?; Ok((physical_expr, physical_name)) } - +/// Casts the ScalarValue: `in_scalar` to column type once we have schema information +/// The resulting type is not necessarily same type with the `column_type`. For instance +/// if `column_type` is Timestamp the result is casted to Interval type. The reason is that +/// Operation between Timestamps is not meaningful, However operation between Timestamp and +/// Interval is valid. For basic types `column_type` is indeed the resulting type. fn convert_to_column_type( - order_by: &[PhysicalSortExpr], - physical_input_schema: &Schema, + column_type: arrow::datatypes::DataType, in_scalar: &ScalarValue, ) -> Result { - // Below query may produce error, it will not produce error - // if ScalarValues is not Null - let column_type = order_by - .first() - .ok_or_else(|| { - DataFusionError::Internal("Order By column cannot be empty".to_string()) - })? - .expr - .data_type(physical_input_schema); match in_scalar { - ScalarValue::Null => Ok(ScalarValue::Null), - ScalarValue::UInt64(Some(val)) => { - ScalarValue::try_from_value(&column_type?, *val) - } - ScalarValue::Float64(Some(val)) => { - ScalarValue::try_from_value(&column_type?, *val) - } - ScalarValue::IntervalDayTime(Some(val)) => { - // source val in millisecond precision - // convert IntervalDayTime to days and millisecond part - // TODO: below operation can be done with shift operationa - let denom = 2_i64.pow(32) as i64; - let days = val / denom; - let milli = val % denom; - println!("days: {}, milli :{}", days, milli); - let interval_in_milli = days * 24 * 60 * 60 * 1000 + milli; - Ok(match &column_type? { - arrow::datatypes::DataType::Timestamp(TimeUnit::Second, tz_opt) => { - ScalarValue::TimestampSecond( - Some(interval_in_milli / 1000), - tz_opt.clone(), - ) - } - arrow::datatypes::DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { - ScalarValue::TimestampMillisecond( - Some(interval_in_milli), - tz_opt.clone(), - ) - } - arrow::datatypes::DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { - ScalarValue::TimestampMicrosecond( - Some(interval_in_milli * 1000), - tz_opt.clone(), - ) - } - arrow::datatypes::DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { - ScalarValue::TimestampNanosecond( - Some(interval_in_milli * 1000000), - tz_opt.clone(), - ) - } - datatype => { - return Err(DataFusionError::NotImplemented(format!( - "Can't create a scalar from data_type \"{:?}\"", - datatype - ))); - } - }) + // In here we can either get ScalarValue::Utf8(None) or + // ScalarValue::Utf8(Some(val)). The reason is that we convert the sqlparser result + // to the Utf8 for all possible cases, since we have no schema information during conversion. + // Here we have schema information, hence we can cast the appropriate ScalarValue Type. + ScalarValue::Utf8(None) => Ok(ScalarValue::Utf8(None)), + ScalarValue::Utf8(Some(val)) => { + if let DataType::Timestamp(..) = column_type { + // TODO: When the query is like ... '3' DAYS PRECEDING ..., "val" is "3 DAYS". + // In this case, the leading_field argument is unused and the code below works. + // When the query is like ... '3 DAYS' PRECEDING ..., "val" is "3". In this case, + // the code assumes it 3 milliseconds and produces wrong results. + // + // I'm not sure, but we may need to fix our sqlparser code as we try to fix this. + parse_interval("millisecond", val) + } else { + ScalarValue::try_from_string(val.clone(), &column_type) + } } - // TODO: Add handling for other valid ScalarValue types - // For now we can either get ScalarValue::Uint64 or None from PRECEDING AND - // FOLLOWING fields. When sql parser supports datetime types in the window - // range queries extend below to support datetime types inside the window. unexpected => Err(DataFusionError::Internal(format!( "unexpected: {:?}", unexpected @@ -1446,21 +1404,34 @@ fn convert_to_column_type( } fn convert_range_bound_to_column_type( - order_by: &[PhysicalSortExpr], - physical_input_schema: &Schema, + column_type: arrow::datatypes::DataType, bound: &WindowFrameBound, ) -> Result { Ok(match bound { - WindowFrameBound::Preceding(val) => WindowFrameBound::Preceding( - convert_to_column_type(order_by, physical_input_schema, val)?, - ), + WindowFrameBound::Preceding(val) => { + WindowFrameBound::Preceding(convert_to_column_type(column_type, val)?) + } WindowFrameBound::CurrentRow => WindowFrameBound::CurrentRow, - WindowFrameBound::Following(val) => WindowFrameBound::Following( - convert_to_column_type(order_by, physical_input_schema, val)?, - ), + WindowFrameBound::Following(val) => { + WindowFrameBound::Following(convert_to_column_type(column_type, val)?) + } }) } - +/// Check if window bounds are valid after schema information is available, and +/// window_frame bounds are casted to the corresponding column type. +/// queries like: +/// OVER (ORDER BY a RANGES BETWEEN 3 PRECEDING AND 5 PRECEDING) +/// OVER (ORDER BY a RANGES BETWEEN INTERVAL '3 DAY' PRECEDING AND '5 DAY' PRECEDING) are rejected +pub fn is_window_valid(window_frame: &Arc) -> Result<()> { + if window_frame.start_bound > window_frame.end_bound { + Err(DataFusionError::Execution(format!( + "Invalid window frame: start bound ({}) cannot be larger than end bound ({})", + window_frame.start_bound, window_frame.end_bound + ))) + } else { + Ok(()) + } +} /// Create a window expression with a name from a logical expression pub fn create_window_expr_with_name( e: &Expr, @@ -1523,30 +1494,54 @@ pub fn create_window_expr_with_name( }) .collect::>>()?; let mut new_window_frame = window_frame.clone(); + // Below query may produce error. We are calling its ? method only when it will + // not produce error logically (Such as when WindowFrameUnits is Range). + let order_by_column = order_by.first().ok_or_else(|| { + DataFusionError::Internal("Order By column cannot be empty".to_string()) + }); if let Some(window_frame) = window_frame { - if window_frame.units == WindowFrameUnits::Groups { - return Err(DataFusionError::NotImplemented( + match window_frame.units { + WindowFrameUnits::Groups => { + return Err(DataFusionError::NotImplemented( "Window frame definitions involving GROUPS are not supported yet" - .to_string(), - )); - } - if window_frame.units == WindowFrameUnits::Range { - new_window_frame.as_mut().unwrap().start_bound = - convert_range_bound_to_column_type( - &order_by, - physical_input_schema, - &window_frame.start_bound, - )?; - new_window_frame.as_mut().unwrap().end_bound = - convert_range_bound_to_column_type( - &order_by, - physical_input_schema, - &window_frame.end_bound, - )?; + .to_string(), + )); + } + WindowFrameUnits::Range => { + let column_type = + order_by_column?.expr.data_type(physical_input_schema)?; + new_window_frame.as_mut().unwrap().start_bound = + convert_range_bound_to_column_type( + column_type.clone(), + &window_frame.start_bound, + )?; + new_window_frame.as_mut().unwrap().end_bound = + convert_range_bound_to_column_type( + column_type, + &window_frame.end_bound, + )?; + } + WindowFrameUnits::Rows => { + // ROWS should have type usize which is Uint64 for our case + let column_type = arrow::datatypes::DataType::UInt64; + new_window_frame.as_mut().unwrap().start_bound = + convert_range_bound_to_column_type( + column_type.clone(), + &window_frame.start_bound, + )?; + new_window_frame.as_mut().unwrap().end_bound = + convert_range_bound_to_column_type( + column_type, + &window_frame.end_bound, + )?; + } } } let new_window_frame = new_window_frame.map(Arc::new); + if let Some(ref window_frame) = new_window_frame { + is_window_valid(window_frame)?; + } windows::create_window_expr( fun, name, @@ -1785,7 +1780,7 @@ fn tuple_err(value: (Result, Result)) -> Result<(T, R)> { #[cfg(test)] mod tests { use super::*; - use crate::assert_contains; + // use crate::assert_contains; use crate::datasource::MemTable; use crate::execution::context::TaskContext; use crate::execution::options::CsvReadOptions; @@ -1803,6 +1798,7 @@ mod tests { use arrow::array::{ArrayRef, DictionaryArray, Int32Array}; use arrow::datatypes::{DataType, Field, Int32Type, SchemaRef}; use arrow::record_batch::RecordBatch; + use datafusion_common::assert_contains; use datafusion_common::{DFField, DFSchema, DFSchemaRef}; use datafusion_expr::expr::GroupingSet; use datafusion_expr::sum; diff --git a/datafusion/core/src/test_util.rs b/datafusion/core/src/test_util.rs index ad27ea3c1aaf..36ec759f14e1 100644 --- a/datafusion/core/src/test_util.rs +++ b/datafusion/core/src/test_util.rs @@ -98,52 +98,6 @@ macro_rules! assert_batches_sorted_eq { }; } -/// A macro to assert that one string is contained within another with -/// a nice error message if they are not. -/// -/// Usage: `assert_contains!(actual, expected)` -/// -/// Is a macro so test error -/// messages are on the same line as the failure; -/// -/// Both arguments must be convertable into Strings (Into) -#[macro_export] -macro_rules! assert_contains { - ($ACTUAL: expr, $EXPECTED: expr) => { - let actual_value: String = $ACTUAL.into(); - let expected_value: String = $EXPECTED.into(); - assert!( - actual_value.contains(&expected_value), - "Can not find expected in actual.\n\nExpected:\n{}\n\nActual:\n{}", - expected_value, - actual_value - ); - }; -} - -/// A macro to assert that one string is NOT contained within another with -/// a nice error message if they are are. -/// -/// Usage: `assert_not_contains!(actual, unexpected)` -/// -/// Is a macro so test error -/// messages are on the same line as the failure; -/// -/// Both arguments must be convertable into Strings (Into) -#[macro_export] -macro_rules! assert_not_contains { - ($ACTUAL: expr, $UNEXPECTED: expr) => { - let actual_value: String = $ACTUAL.into(); - let unexpected_value: String = $UNEXPECTED.into(); - assert!( - !actual_value.contains(&unexpected_value), - "Found unexpected in actual.\n\nUnexpected:\n{}\n\nActual:\n{}", - unexpected_value, - actual_value - ); - }; -} - /// Returns the arrow test data directory, which is by default stored /// in a git submodule rooted at `testing/data`. /// diff --git a/datafusion/core/tests/sql/idenfifers.rs b/datafusion/core/tests/sql/idenfifers.rs index d50e5989cc98..e2fde56e959e 100644 --- a/datafusion/core/tests/sql/idenfifers.rs +++ b/datafusion/core/tests/sql/idenfifers.rs @@ -18,7 +18,8 @@ use std::sync::Arc; use arrow::{array::StringArray, record_batch::RecordBatch}; -use datafusion::{assert_batches_sorted_eq, assert_contains, prelude::*}; +use datafusion::{assert_batches_sorted_eq, prelude::*}; +use datafusion_common::assert_contains; use crate::sql::plan_and_collect; diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index e0167b5c4f3a..5ac6f003dbe9 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -25,10 +25,6 @@ use arrow::{ use chrono::prelude::*; use chrono::Duration; -use datafusion::assert_batches_eq; -use datafusion::assert_batches_sorted_eq; -use datafusion::assert_contains; -use datafusion::assert_not_contains; use datafusion::datasource::TableProvider; use datafusion::from_slice::FromSlice; use datafusion::logical_expr::{Aggregate, LogicalPlan, Projection, TableScan}; @@ -37,12 +33,14 @@ use datafusion::physical_plan::ExecutionPlan; use datafusion::physical_plan::ExecutionPlanVisitor; use datafusion::prelude::*; use datafusion::test_util; +use datafusion::{assert_batches_eq, assert_batches_sorted_eq}; use datafusion::{datasource::MemTable, physical_plan::collect}; use datafusion::{ error::{DataFusionError, Result}, physical_plan::ColumnarValue, }; use datafusion::{execution::context::SessionContext, physical_plan::displayable}; +use datafusion_common::{assert_contains, assert_not_contains}; use datafusion_expr::Volatility; use object_store::path::Path; use std::fs::File; diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index 4d7f23beb1c4..1b5344eb88cb 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -523,6 +523,7 @@ async fn window_frame_rows_preceding() -> Result<()> { assert_batches_eq!(expected, &actual); Ok(()) } + #[tokio::test] async fn window_frame_rows_preceding_with_partition_unique_order_by() -> Result<()> { let ctx = SessionContext::new(); @@ -1079,10 +1080,11 @@ async fn window_frame_partition_by_order_by_desc() -> Result<()> { async fn window_frame_range_float() -> Result<()> { let ctx = SessionContext::new(); register_aggregate_csv(&ctx).await?; - let sql = "SELECT SUM(c12) OVER (ORDER BY C12 RANGE BETWEEN 0.2 PRECEDING AND 0.2 FOLLOWING) -FROM aggregate_test_100 -ORDER BY C9 -LIMIT 5;"; + let sql = "SELECT + SUM(c12) OVER (ORDER BY C12 RANGE BETWEEN 0.2 PRECEDING AND 0.2 FOLLOWING) + FROM aggregate_test_100 + ORDER BY C9 + LIMIT 5"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------------------------+", @@ -1133,9 +1135,7 @@ async fn window_frame_ranges_timestamp() -> Result<()> { // execute the query let df = ctx .sql( - // "SELECT COUNT(*) OVER (ORDER BY ts) FROM t;" - "SELECT ts, COUNT(*) OVER (ORDER BY ts RANGE BETWEEN '1 DAY' PRECEDING AND '2 DAY' FOLLOWING) FROM t;" - // "SELECT a FROM t;" + "SELECT ts, COUNT(*) OVER (ORDER BY ts RANGE BETWEEN INTERVAL '1' DAY PRECEDING AND INTERVAL '2 DAY' FOLLOWING) FROM t;" ) .await?; diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 6b25a0ddedca..bce56da99e8a 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -38,4 +38,5 @@ path = "src/lib.rs" ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] } arrow = { version = "24.0.0", default-features = false } datafusion-common = { path = "../common", version = "13.0.0" } +# sqlparser = "0.25" sqlparser = { git = "https://github.com/synnada-ai/sqlparser-rs", branch = "main" } diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index 300793384e7b..eb90bb863ee1 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -23,7 +23,6 @@ //! - An ending frame boundary, //! - An EXCLUDE clause. -use datafusion_common::scalar::get_scalar_value; use datafusion_common::{DataFusionError, Result, ScalarValue}; use sqlparser::ast; use std::cmp::Ordering; @@ -67,21 +66,16 @@ impl TryFrom for WindowFrame { .map(WindowFrameBound::from) .unwrap_or(WindowFrameBound::CurrentRow); - if let WindowFrameBound::Following(ScalarValue::Null) = start_bound { + if let WindowFrameBound::Following(ScalarValue::Utf8(None)) = start_bound { Err(DataFusionError::Execution( "Invalid window frame: start bound cannot be unbounded following" .to_owned(), )) - } else if let WindowFrameBound::Preceding(ScalarValue::Null) = end_bound { + } else if let WindowFrameBound::Preceding(ScalarValue::Utf8(None)) = end_bound { Err(DataFusionError::Execution( "Invalid window frame: end bound cannot be unbounded preceding" .to_owned(), )) - } else if start_bound > end_bound { - Err(DataFusionError::Execution(format!( - "Invalid window frame: start bound ({}) cannot be larger than end bound ({})", - start_bound, end_bound - ))) } else { let units = value.units.into(); Ok(Self { @@ -97,7 +91,7 @@ impl Default for WindowFrame { fn default() -> Self { WindowFrame { units: WindowFrameUnits::Range, - start_bound: WindowFrameBound::Preceding(ScalarValue::Null), + start_bound: WindowFrameBound::Preceding(ScalarValue::Utf8(None)), end_bound: WindowFrameBound::CurrentRow, } } @@ -105,25 +99,25 @@ impl Default for WindowFrame { pub fn convert_range_bound_to_scalar_value(v: ast::RangeBounds) -> Result { match v { - ast::RangeBounds::Number(number) => get_scalar_value(&number[..]), + ast::RangeBounds::Number(number) => Ok(ScalarValue::Utf8(Some(number))), ast::RangeBounds::Interval(ast::Expr::Interval { value, - leading_field: _, + leading_field, leading_precision: _, last_field: _, fractional_seconds_precision: _, }) => { - let res = match *value { + let mut res = match *value { ast::Expr::Value(ast::Value::SingleQuotedString(elem)) => Ok(elem), unexpected => Err(DataFusionError::Internal(format!( "INTERVAL expression cannot be {:?}", unexpected ))), }; - // parse the interval most precise way possible which is millisecond for now - // TODO: Add parser the IntervalMonthDayNano case with nanosecond option - let res = datafusion_common::parsers::parse_interval("millisecond", &res?)?; - Ok(res) + if let Some(leading_field) = leading_field { + res = Ok(format!("{} {}", res?, leading_field)); + }; + Ok(ScalarValue::Utf8(Some(res?))) } unexpected => Err(DataFusionError::Internal(format!( "RangeBounds cannot be {:?}", @@ -177,12 +171,16 @@ impl From for WindowFrameBound { let res = convert_range_bound_to_scalar_value(v).unwrap(); Self::Preceding(res) } - ast::WindowFrameBound::Preceding(None) => Self::Preceding(ScalarValue::Null), + ast::WindowFrameBound::Preceding(None) => { + Self::Preceding(ScalarValue::Utf8(None)) + } ast::WindowFrameBound::Following(Some(v)) => { let res = convert_range_bound_to_scalar_value(v).unwrap(); Self::Following(res) } - ast::WindowFrameBound::Following(None) => Self::Following(ScalarValue::Null), + ast::WindowFrameBound::Following(None) => { + Self::Following(ScalarValue::Utf8(None)) + } ast::WindowFrameBound::CurrentRow => Self::CurrentRow, } } @@ -191,12 +189,12 @@ impl From for WindowFrameBound { impl fmt::Display for WindowFrameBound { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - WindowFrameBound::Preceding(ScalarValue::Null) => { + WindowFrameBound::Preceding(ScalarValue::Utf8(None)) => { f.write_str("UNBOUNDED PRECEDING") } WindowFrameBound::Preceding(n) => write!(f, "{} PRECEDING", n), WindowFrameBound::CurrentRow => f.write_str("CURRENT ROW"), - WindowFrameBound::Following(ScalarValue::Null) => { + WindowFrameBound::Following(ScalarValue::Utf8(None)) => { f.write_str("UNBOUNDED FOLLOWING") } WindowFrameBound::Following(n) => write!(f, "{} FOLLOWING", n), @@ -236,25 +234,75 @@ impl WindowFrameBound { /// rank and also for 0 preceding / following it is the same as current row fn get_rank(&self) -> (u8, u64) { match self { - WindowFrameBound::Preceding(ScalarValue::Null) => (0, 0), - WindowFrameBound::Following(ScalarValue::Null) => (4, 0), + WindowFrameBound::Preceding(ScalarValue::Utf8(None)) => (0, 0), + WindowFrameBound::Following(ScalarValue::Utf8(None)) => (4, 0), WindowFrameBound::Preceding(ScalarValue::UInt64(Some(0))) + | WindowFrameBound::Preceding(ScalarValue::UInt32(Some(0))) + | WindowFrameBound::Preceding(ScalarValue::UInt16(Some(0))) + | WindowFrameBound::Preceding(ScalarValue::Int64(Some(0))) + | WindowFrameBound::Preceding(ScalarValue::Int32(Some(0))) + | WindowFrameBound::Preceding(ScalarValue::Int16(Some(0))) | WindowFrameBound::CurrentRow - | WindowFrameBound::Following(ScalarValue::UInt64(Some(0))) => (2, 0), + | WindowFrameBound::Following(ScalarValue::UInt64(Some(0))) + | WindowFrameBound::Following(ScalarValue::UInt32(Some(0))) + | WindowFrameBound::Following(ScalarValue::UInt16(Some(0))) + | WindowFrameBound::Following(ScalarValue::Int64(Some(0))) + | WindowFrameBound::Following(ScalarValue::Int32(Some(0))) + | WindowFrameBound::Following(ScalarValue::Int16(Some(0))) => (2, 0), WindowFrameBound::Preceding(ScalarValue::UInt64(Some(v))) => { (1, u64::MAX - *v as u64) } + WindowFrameBound::Preceding(ScalarValue::UInt32(Some(v))) => { + (1, u64::MAX - *v as u64) + } + WindowFrameBound::Preceding(ScalarValue::UInt16(Some(v))) => { + (1, u64::MAX - *v as u64) + } + WindowFrameBound::Preceding(ScalarValue::Int64(Some(v))) => { + (1, u64::MAX - *v as u64) + } + WindowFrameBound::Preceding(ScalarValue::Int32(Some(v))) => { + (1, u64::MAX - *v as u64) + } + WindowFrameBound::Preceding(ScalarValue::Int16(Some(v))) => { + (1, u64::MAX - *v as u64) + } WindowFrameBound::Preceding(ScalarValue::Float64(Some(v))) => { (1, u64::MAX - *v as u64) } - WindowFrameBound::Following(ScalarValue::UInt64(Some(v))) => (3, *v as u64), - WindowFrameBound::Following(ScalarValue::Float64(Some(v))) => (3, *v as u64), + WindowFrameBound::Preceding(ScalarValue::Float32(Some(v))) => { + (1, u64::MAX - *v as u64) + } WindowFrameBound::Preceding(ScalarValue::IntervalDayTime(Some(v))) => { (1, u64::MAX - *v as u64) } + WindowFrameBound::Following(ScalarValue::UInt64(Some(v))) => (3, *v as u64), + WindowFrameBound::Following(ScalarValue::UInt32(Some(v))) => (3, *v as u64), + WindowFrameBound::Following(ScalarValue::UInt16(Some(v))) => (3, *v as u64), + WindowFrameBound::Following(ScalarValue::Int64(Some(v))) => (3, *v as u64), + WindowFrameBound::Following(ScalarValue::Int32(Some(v))) => (3, *v as u64), + WindowFrameBound::Following(ScalarValue::Int16(Some(v))) => (3, *v as u64), + WindowFrameBound::Following(ScalarValue::Float64(Some(v))) => (3, *v as u64), + WindowFrameBound::Following(ScalarValue::Float32(Some(v))) => (3, *v as u64), WindowFrameBound::Following(ScalarValue::IntervalDayTime(Some(v))) => { (3, *v as u64) } + WindowFrameBound::Preceding(ScalarValue::Utf8(Some(v))) => { + match v.as_ref() { + "0" => (2, 0), + // After schema information we do not have string type for WindowFrameBound + // hence we can cast it to a arbitrary point, TODO: fix here for better handling + _elem => (1, u64::MAX - 1), + } + } + WindowFrameBound::Following(ScalarValue::Utf8(Some(v))) => { + match v.as_ref() { + "0" => (2, 0), + // After schema information we do not have string type for WindowFrameBound + // hence we can cast it to a arbitrary point, TODO: fix here for better handling + _elem => (3, 1), + } + } _ => todo!(), } } @@ -304,6 +352,12 @@ mod tests { use super::*; #[test] + #[ignore] + // We no longer check for validity of the preceding, following during window frame creation + // Since we are accepting different kind of types during creation, validity of that check can be only + // done when schema information is available. The is no trivial way to reject PRECEDING '1 MONTH' is + // later than PRECEDING "40 DAYS". Hence this test is ignored, However they are rejected during physical + // plan creation once schema information is available. fn test_window_frame_creation() -> Result<()> { let window_frame = ast::WindowFrame { units: ast::WindowFrameUnits::Range, @@ -373,16 +427,16 @@ mod tests { WindowFrameBound::Following(ScalarValue::UInt64(Some(2))) ); assert_eq!( - WindowFrameBound::Following(ScalarValue::Null), - WindowFrameBound::Following(ScalarValue::Null) + WindowFrameBound::Following(ScalarValue::Utf8(None)), + WindowFrameBound::Following(ScalarValue::Utf8(None)) ); assert_eq!( WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))), WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))) ); assert_eq!( - WindowFrameBound::Preceding(ScalarValue::Null), - WindowFrameBound::Preceding(ScalarValue::Null) + WindowFrameBound::Preceding(ScalarValue::Utf8(None)), + WindowFrameBound::Preceding(ScalarValue::Utf8(None)) ); } @@ -402,15 +456,15 @@ mod tests { < WindowFrameBound::Preceding(ScalarValue::UInt64(Some(u64::MAX - 1))) ); assert!( - WindowFrameBound::Preceding(ScalarValue::Null) + WindowFrameBound::Preceding(ScalarValue::Utf8(None)) < WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1000000))) ); assert!( - WindowFrameBound::Preceding(ScalarValue::Null) + WindowFrameBound::Preceding(ScalarValue::Utf8(None)) < WindowFrameBound::Preceding(ScalarValue::UInt64(Some(u64::MAX))) ); assert!( - WindowFrameBound::Preceding(ScalarValue::Null) + WindowFrameBound::Preceding(ScalarValue::Utf8(None)) < WindowFrameBound::Following(ScalarValue::UInt64(Some(0))) ); assert!( @@ -427,11 +481,11 @@ mod tests { ); assert!( WindowFrameBound::Following(ScalarValue::UInt64(Some(2))) - < WindowFrameBound::Following(ScalarValue::Null) + < WindowFrameBound::Following(ScalarValue::Utf8(None)) ); assert!( WindowFrameBound::Following(ScalarValue::UInt64(Some(u64::MAX))) - < WindowFrameBound::Following(ScalarValue::Null) + < WindowFrameBound::Following(ScalarValue::Utf8(None)) ); } } diff --git a/datafusion/physical-expr/src/expressions/datetime.rs b/datafusion/physical-expr/src/expressions/datetime.rs index fa021f61a940..a83ce0208b6c 100644 --- a/datafusion/physical-expr/src/expressions/datetime.rs +++ b/datafusion/physical-expr/src/expressions/datetime.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use crate::expressions::delta::shift_months; use crate::PhysicalExpr; use arrow::array::{ Array, ArrayRef, Date32Array, Date64Array, TimestampMicrosecondArray, @@ -27,13 +26,15 @@ use arrow::datatypes::{ TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; use arrow::record_batch::RecordBatch; -use chrono::{Datelike, Duration, NaiveDate, NaiveDateTime}; +use datafusion_common::datetime::{ + date32_add, date64_add, evaluate_scalar, microseconds_add, milliseconds_add, + nanoseconds_add, seconds_add, +}; use datafusion_common::Result; use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::{ColumnarValue, Operator}; use std::any::Any; use std::fmt::{Display, Formatter}; -use std::ops::{Add, Sub}; use std::sync::Arc; /// Perform DATE/TIME/TIMESTAMP +/ INTERVAL math @@ -136,7 +137,9 @@ impl PhysicalExpr for DateTimeIntervalExpr { }; match dates { - ColumnarValue::Scalar(operand) => evaluate_scalar(operand, sign, intervals), + ColumnarValue::Scalar(operand) => Ok(ColumnarValue::Scalar(evaluate_scalar( + operand, sign, intervals, + )?)), ColumnarValue::Array(array) => evaluate_array(array, sign, intervals), } } @@ -214,138 +217,6 @@ pub fn evaluate_array( Ok(ColumnarValue::Array(ret)) } -fn evaluate_scalar( - operand: ScalarValue, - sign: i32, - scalar: &ScalarValue, -) -> Result { - let res = match operand { - ScalarValue::Date32(Some(days)) => { - let value = date32_add(days, scalar, sign)?; - ColumnarValue::Scalar(ScalarValue::Date32(Some(value))) - } - ScalarValue::Date64(Some(ms)) => { - let value = date64_add(ms, scalar, sign)?; - ColumnarValue::Scalar(ScalarValue::Date64(Some(value))) - } - ScalarValue::TimestampSecond(Some(ts_s), zone) => { - let value = seconds_add(ts_s, scalar, sign)?; - ColumnarValue::Scalar(ScalarValue::TimestampSecond(Some(value), zone)) - } - ScalarValue::TimestampMillisecond(Some(ts_ms), zone) => { - let value = milliseconds_add(ts_ms, scalar, sign)?; - ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(Some(value), zone)) - } - ScalarValue::TimestampMicrosecond(Some(ts_us), zone) => { - let value = microseconds_add(ts_us, scalar, sign)?; - ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(value), zone)) - } - ScalarValue::TimestampNanosecond(Some(ts_ns), zone) => { - let value = nanoseconds_add(ts_ns, scalar, sign)?; - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(value), zone)) - } - _ => Err(DataFusionError::Execution(format!( - "Invalid lhs type {} for DateIntervalExpr", - operand.get_datatype() - )))?, - }; - Ok(res) -} - -#[inline] -fn date32_add(days: i32, scalar: &ScalarValue, sign: i32) -> Result { - let epoch = NaiveDate::from_ymd(1970, 1, 1); - let prior = epoch.add(Duration::days(days as i64)); - let posterior = do_date_math(prior, scalar, sign)?; - Ok(posterior.sub(epoch).num_days() as i32) -} - -#[inline] -fn date64_add(ms: i64, scalar: &ScalarValue, sign: i32) -> Result { - let epoch = NaiveDate::from_ymd(1970, 1, 1); - let prior = epoch.add(Duration::milliseconds(ms)); - let posterior = do_date_math(prior, scalar, sign)?; - Ok(posterior.sub(epoch).num_milliseconds()) -} - -#[inline] -fn seconds_add(ts_s: i64, scalar: &ScalarValue, sign: i32) -> Result { - Ok(do_date_time_math(ts_s, 0, scalar, sign)?.timestamp()) -} - -#[inline] -fn milliseconds_add(ts_ms: i64, scalar: &ScalarValue, sign: i32) -> Result { - let secs = ts_ms / 1000; - let nsecs = ((ts_ms % 1000) * 1_000_000) as u32; - Ok(do_date_time_math(secs, nsecs, scalar, sign)?.timestamp_millis()) -} - -#[inline] -fn microseconds_add(ts_us: i64, scalar: &ScalarValue, sign: i32) -> Result { - let secs = ts_us / 1_000_000; - let nsecs = ((ts_us % 1_000_000) * 1000) as u32; - Ok(do_date_time_math(secs, nsecs, scalar, sign)?.timestamp_nanos() / 1000) -} - -#[inline] -fn nanoseconds_add(ts_ns: i64, scalar: &ScalarValue, sign: i32) -> Result { - let secs = ts_ns / 1_000_000_000; - let nsecs = (ts_ns % 1_000_000_000) as u32; - Ok(do_date_time_math(secs, nsecs, scalar, sign)?.timestamp_nanos()) -} - -#[inline] -fn do_date_time_math( - secs: i64, - nsecs: u32, - scalar: &ScalarValue, - sign: i32, -) -> Result { - let prior = NaiveDateTime::from_timestamp(secs, nsecs); - do_date_math(prior, scalar, sign) -} - -fn do_date_math(prior: D, scalar: &ScalarValue, sign: i32) -> Result -where - D: Datelike + Add, -{ - Ok(match scalar { - ScalarValue::IntervalDayTime(Some(i)) => add_day_time(prior, *i, sign), - ScalarValue::IntervalYearMonth(Some(i)) => shift_months(prior, *i * sign), - ScalarValue::IntervalMonthDayNano(Some(i)) => add_m_d_nano(prior, *i, sign), - other => Err(DataFusionError::Execution(format!( - "DateIntervalExpr does not support non-interval type {:?}", - other - )))?, - }) -} - -// Can remove once https://github.com/apache/arrow-rs/pull/2031 is released -fn add_m_d_nano(prior: D, interval: i128, sign: i32) -> D -where - D: Datelike + Add, -{ - let interval = interval as u128; - let nanos = (interval >> 64) as i64 * sign as i64; - let days = (interval >> 32) as i32 * sign; - let months = interval as i32 * sign; - let a = shift_months(prior, months); - let b = a.add(Duration::days(days as i64)); - b.add(Duration::nanoseconds(nanos)) -} - -// Can remove once https://github.com/apache/arrow-rs/pull/2031 is released -fn add_day_time(prior: D, interval: i64, sign: i32) -> D -where - D: Datelike + Add, -{ - let interval = interval as u64; - let days = (interval >> 32) as i32 * sign; - let ms = interval as i32 * sign; - let intermediate = prior.add(Duration::days(days as i64)); - intermediate.add(Duration::milliseconds(ms as i64)) -} - #[cfg(test)] mod tests { use super::*; @@ -353,8 +224,11 @@ mod tests { use crate::execution_props::ExecutionProps; use arrow::array::{ArrayRef, Date32Builder}; use arrow::datatypes::*; + use chrono::{Duration, NaiveDate}; + use datafusion_common::delta::shift_months; use datafusion_common::{Column, Result, ToDFSchema}; use datafusion_expr::Expr; + use std::ops::Add; #[test] fn add_11_months() { @@ -403,8 +277,7 @@ mod tests { // setup let dt = Expr::Literal(ScalarValue::Date32(Some(0))); let op = Operator::Plus; - let interval = create_day_time(1, 0); - let interval = Expr::Literal(ScalarValue::IntervalDayTime(Some(interval))); + let interval = Expr::Literal(ScalarValue::new_interval_dt(1, 0)); // exercise let res = exercise(&dt, op, &interval)?; @@ -454,8 +327,8 @@ mod tests { // setup let dt = Expr::Literal(ScalarValue::Date64(Some(0))); let op = Operator::Plus; - let interval = create_day_time(-15, -24 * 60 * 60 * 1000); - let interval = Expr::Literal(ScalarValue::IntervalDayTime(Some(interval))); + let interval = + Expr::Literal(ScalarValue::new_interval_dt(-15, -24 * 60 * 60 * 1000)); // exercise let res = exercise(&dt, op, &interval)?; @@ -505,10 +378,7 @@ mod tests { // setup let dt = Expr::Literal(ScalarValue::Date32(Some(0))); let op = Operator::Plus; - - let interval = create_month_day_nano(-12, -15, -42); - - let interval = Expr::Literal(ScalarValue::IntervalMonthDayNano(Some(interval))); + let interval = Expr::Literal(ScalarValue::new_interval_mdn(-12, -15, -42)); // exercise let res = exercise(&dt, op, &interval)?; @@ -534,8 +404,7 @@ mod tests { let now_ts_ns = chrono::Utc::now().timestamp_nanos(); let dt = Expr::Literal(ScalarValue::TimestampNanosecond(Some(now_ts_ns), None)); let op = Operator::Plus; - let interval = create_day_time(0, 1); - let interval = Expr::Literal(ScalarValue::IntervalDayTime(Some(interval))); + let interval = Expr::Literal(ScalarValue::new_interval_dt(0, 1)); // exercise let res = exercise(&dt, op, &interval)?; @@ -558,8 +427,7 @@ mod tests { let now_ts_s = chrono::Utc::now().timestamp(); let dt = Expr::Literal(ScalarValue::TimestampSecond(Some(now_ts_s), None)); let op = Operator::Plus; - let interval = create_day_time(0, 2 * 3600 * 1_000); - let interval = Expr::Literal(ScalarValue::IntervalDayTime(Some(interval))); + let interval = Expr::Literal(ScalarValue::new_interval_dt(0, 2 * 3600 * 1_000)); // exercise let res = exercise(&dt, op, &interval)?; @@ -582,8 +450,7 @@ mod tests { let now_ts_s = chrono::Utc::now().timestamp(); let dt = Expr::Literal(ScalarValue::TimestampSecond(Some(now_ts_s), None)); let op = Operator::Minus; - let interval = create_day_time(0, 4 * 3600 * 1_000); - let interval = Expr::Literal(ScalarValue::IntervalDayTime(Some(interval))); + let interval = Expr::Literal(ScalarValue::new_interval_dt(0, 4 * 3600 * 1_000)); // exercise let res = exercise(&dt, op, &interval)?; @@ -606,8 +473,7 @@ mod tests { let now_ts_ns = chrono::Utc::now().timestamp_nanos(); let dt = Expr::Literal(ScalarValue::TimestampNanosecond(Some(now_ts_ns), None)); let op = Operator::Plus; - let interval = create_day_time(8, 0); - let interval = Expr::Literal(ScalarValue::IntervalDayTime(Some(interval))); + let interval = Expr::Literal(ScalarValue::new_interval_dt(8, 0)); // exercise let res = exercise(&dt, op, &interval)?; @@ -630,8 +496,7 @@ mod tests { let now_ts_ns = chrono::Utc::now().timestamp_nanos(); let dt = Expr::Literal(ScalarValue::TimestampNanosecond(Some(now_ts_ns), None)); let op = Operator::Minus; - let interval = create_day_time(16, 0); - let interval = Expr::Literal(ScalarValue::IntervalDayTime(Some(interval))); + let interval = Expr::Literal(ScalarValue::new_interval_dt(16, 0)); // exercise let res = exercise(&dt, op, &interval)?; @@ -660,8 +525,7 @@ mod tests { let props = ExecutionProps::new(); let dt = Expr::Column(Column::from_name("a")); - let interval = create_day_time(26, 0); - let interval = Expr::Literal(ScalarValue::IntervalDayTime(Some(interval))); + let interval = Expr::Literal(ScalarValue::new_interval_dt(26, 0)); let op = Operator::Plus; let lhs = create_physical_expr(&dt, &dfs, &schema, &props)?; @@ -754,29 +618,4 @@ mod tests { let res = cut.evaluate(&batch)?; Ok(res) } - - // Can remove once https://github.com/apache/arrow-rs/pull/2031 is released - - /// Creates an IntervalDayTime given its constituent components - /// - /// https://github.com/apache/arrow-rs/blob/e59b023480437f67e84ba2f827b58f78fd44c3a1/integration-testing/src/lib.rs#L222 - fn create_day_time(days: i32, millis: i32) -> i64 { - let m = millis as u64 & u32::MAX as u64; - let d = (days as u64 & u32::MAX as u64) << 32; - (m | d) as i64 - } - - // Can remove once https://github.com/apache/arrow-rs/pull/2031 is released - /// Creates an IntervalMonthDayNano given its constituent components - /// - /// Source: https://github.com/apache/arrow-rs/blob/e59b023480437f67e84ba2f827b58f78fd44c3a1/integration-testing/src/lib.rs#L340 - /// ((nanoseconds as i128) & 0xFFFFFFFFFFFFFFFF) << 64 - /// | ((days as i128) & 0xFFFFFFFF) << 32 - /// | ((months as i128) & 0xFFFFFFFF); - fn create_month_day_nano(months: i32, days: i32, nanos: i64) -> i128 { - let m = months as u128 & u32::MAX as u128; - let d = (days as u128 & u32::MAX as u128) << 32; - let n = (nanos as u128) << 64; - (m | d | n) as i128 - } } diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 208e6d0b51fb..ffbefbd3fd52 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -23,7 +23,6 @@ mod case; mod cast; mod column; mod datetime; -mod delta; mod get_indexed_field; mod in_list; mod is_not_null; diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index 0d78624ca4a3..4ace0b94eeff 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -29,7 +29,6 @@ use arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; use datafusion_common::bisect::bisect; -use datafusion_common::scalar::TryFromValue; use datafusion_common::Result; use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::{Accumulator, WindowFrameBound}; @@ -159,16 +158,20 @@ fn calculate_index_of_row( current_row_values .iter() .map(|value| { - Ok(if value.is_null() { - value.clone() - } else if SEARCH_SIDE == is_descending { + if value.is_null() { + return Ok(value.clone()); + } + if SEARCH_SIDE == is_descending { // TODO: ADD overflow check - value.add(delta)? + value.add(delta) } else if value.is_unsigned() && value < delta { - ScalarValue::try_from_value(&value.get_datatype(), 0)? + // NOTE: This gets a polymorphic zero without having long coercion code for ScalarValue. + // If we decide to implement a "default" construction mechanism for ScalarValue, + // change the following statement to use that. + value.sub(value) } else { - value.sub(delta)? - }) + value.sub(delta) + } }) .collect::>>()? } else { @@ -191,7 +194,7 @@ fn calculate_current_window( WindowFrameUnits::Range => { let start = match &window_frame.start_bound { // UNBOUNDED PRECEDING - WindowFrameBound::Preceding(ScalarValue::Null) => Ok(0), + WindowFrameBound::Preceding(ScalarValue::Utf8(None)) => Ok(0), WindowFrameBound::Preceding(n) => calculate_index_of_row::( range_columns, sort_options, @@ -205,7 +208,7 @@ fn calculate_current_window( None, ), // UNBOUNDED FOLLOWING - WindowFrameBound::Following(ScalarValue::Null) => { + WindowFrameBound::Following(ScalarValue::Utf8(None)) => { Err(DataFusionError::Internal(format!( "Frame start cannot be UNBOUNDED FOLLOWING '{:?}'", window_frame @@ -220,7 +223,7 @@ fn calculate_current_window( }; let end = match &window_frame.end_bound { // UNBOUNDED PRECEDING - WindowFrameBound::Preceding(ScalarValue::Null) => { + WindowFrameBound::Preceding(ScalarValue::Utf8(None)) => { Err(DataFusionError::Internal(format!( "Frame end cannot be UNBOUNDED PRECEDING '{:?}'", window_frame @@ -239,7 +242,7 @@ fn calculate_current_window( None, ), // UNBOUNDED FOLLOWING - WindowFrameBound::Following(ScalarValue::Null) => Ok(length), + WindowFrameBound::Following(ScalarValue::Utf8(None)) => Ok(length), WindowFrameBound::Following(n) => calculate_index_of_row::( range_columns, sort_options, @@ -252,7 +255,7 @@ fn calculate_current_window( WindowFrameUnits::Rows => { let start = match &window_frame.start_bound { // UNBOUNDED PRECEDING - WindowFrameBound::Preceding(ScalarValue::Null) => Ok(0), + WindowFrameBound::Preceding(ScalarValue::Utf8(None)) => Ok(0), WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => { if idx >= *n as usize { Ok(idx - *n as usize) @@ -265,7 +268,7 @@ fn calculate_current_window( } WindowFrameBound::CurrentRow => Ok(idx), // UNBOUNDED FOLLOWING - WindowFrameBound::Following(ScalarValue::Null) => { + WindowFrameBound::Following(ScalarValue::Utf8(None)) => { Err(DataFusionError::Internal(format!( "Frame start cannot be UNBOUNDED FOLLOWING '{:?}'", window_frame @@ -280,7 +283,7 @@ fn calculate_current_window( }; let end = match window_frame.end_bound { // UNBOUNDED PRECEDING - WindowFrameBound::Preceding(ScalarValue::Null) => { + WindowFrameBound::Preceding(ScalarValue::Utf8(None)) => { Err(DataFusionError::Internal(format!( "Frame end cannot be UNBOUNDED PRECEDING '{:?}'", window_frame @@ -298,7 +301,7 @@ fn calculate_current_window( } WindowFrameBound::CurrentRow => Ok(idx + 1), // UNBOUNDED FOLLOWING - WindowFrameBound::Following(ScalarValue::Null) => Ok(length), + WindowFrameBound::Following(ScalarValue::Utf8(None)) => Ok(length), WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => { Ok(min(idx + n as usize + 1, length)) } @@ -331,7 +334,7 @@ impl AggregateWindowAccumulator { fn implicit_order_by_window() -> Arc { Arc::new(WindowFrame { units: WindowFrameUnits::Range, - start_bound: WindowFrameBound::Preceding(ScalarValue::Null), + start_bound: WindowFrameBound::Preceding(ScalarValue::Utf8(None)), end_bound: WindowFrameBound::CurrentRow, }) } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 79676a9165ca..1ec22d9d7bdb 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -629,11 +629,7 @@ enum WindowFrameBoundType { message WindowFrameBound { WindowFrameBoundType window_frame_bound_type = 1; - // "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see https://github.com/tokio-rs/prost/issues/430 and https://github.com/tokio-rs/prost/pull/455) - // this syntax is ugly but is binary compatible with the "optional" keyword (see https://stackoverflow.com/questions/42622015/how-to-define-an-optional-field-in-protobuf-3) - oneof bound_value { - ScalarValue value = 2; - } + ScalarValue bound_value = 2; } /////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs index 032542af4063..330fe3e3ee30 100644 --- a/datafusion/proto/src/to_proto.rs +++ b/datafusion/proto/src/to_proto.rs @@ -392,19 +392,21 @@ impl From for protobuf::WindowFrameUnits { impl From for protobuf::WindowFrameBound { fn from(bound: WindowFrameBound) -> Self { match bound { - WindowFrameBound::CurrentRow => Self { - window_frame_bound_type: protobuf::WindowFrameBoundType::CurrentRow - .into(), - bound_value: None, - }, + WindowFrameBound::CurrentRow => { + let pb_value: protobuf::ScalarValue = + (&ScalarValue::Utf8(None)).try_into().unwrap(); + Self { + window_frame_bound_type: protobuf::WindowFrameBoundType::CurrentRow + .into(), + bound_value: Some(pb_value), + } + } WindowFrameBound::Preceding(v) => { let pb_value: protobuf::ScalarValue = (&v).try_into().unwrap(); Self { window_frame_bound_type: protobuf::WindowFrameBoundType::Preceding .into(), - bound_value: Some(protobuf::window_frame_bound::BoundValue::Value( - pb_value, - )), + bound_value: Some(pb_value), } } WindowFrameBound::Following(v) => { @@ -412,9 +414,7 @@ impl From for protobuf::WindowFrameBound { Self { window_frame_bound_type: protobuf::WindowFrameBoundType::Following .into(), - bound_value: Some(protobuf::window_frame_bound::BoundValue::Value( - pb_value, - )), + bound_value: Some(pb_value), } } } diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index b3b79c7cf23f..ea03cc20277b 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -40,4 +40,5 @@ unicode_expressions = [] arrow = { version = "24.0.0", default-features = false } datafusion-common = { path = "../common", version = "13.0.0" } datafusion-expr = { path = "../expr", version = "13.0.0" } +# sqlparser = "0.25" sqlparser = { git = "https://github.com/synnada-ai/sqlparser-rs", branch = "main" } diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 9c3950fa2000..b2b4e66e99c2 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -55,6 +55,7 @@ use datafusion_expr::expr::{Case, GroupingSet}; use datafusion_expr::logical_plan::builder::project_with_alias; use datafusion_expr::logical_plan::{Filter, Subquery}; use datafusion_expr::Expr::Alias; +use sqlparser::ast::TimezoneInfo; use sqlparser::ast::{ BinaryOperator, DataType as SQLDataType, DateTimeField, Expr as SQLExpr, FunctionArg, FunctionArgExpr, Ident, Join, JoinConstraint, JoinOperator, ObjectName, @@ -62,8 +63,6 @@ use sqlparser::ast::{ ShowCreateObject, ShowStatementFilter, TableAlias, TableFactor, TableWithJoins, TrimWhereField, UnaryOperator, Value, Values as SQLValues, }; -// use sqlparser::ast::TimezoneInfo; -use sqlparser::ast::TimezoneInfo; use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption}; use sqlparser::ast::{ObjectType, OrderByExpr, Statement}; use sqlparser::parser::ParserError::ParserError; From 964f8c76a337758687568ad72169d62973755833 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Mon, 10 Oct 2022 18:25:38 +0300 Subject: [PATCH 03/13] get_rank removed --- datafusion/core/src/physical_plan/planner.rs | 38 +++- datafusion/core/tests/sql/window.rs | 11 +- datafusion/expr/src/window_frame.rs | 197 +++---------------- 3 files changed, 60 insertions(+), 186 deletions(-) diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index a4204a7d8fcf..57bec2a3e020 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -1385,12 +1385,6 @@ fn convert_to_column_type( ScalarValue::Utf8(None) => Ok(ScalarValue::Utf8(None)), ScalarValue::Utf8(Some(val)) => { if let DataType::Timestamp(..) = column_type { - // TODO: When the query is like ... '3' DAYS PRECEDING ..., "val" is "3 DAYS". - // In this case, the leading_field argument is unused and the code below works. - // When the query is like ... '3 DAYS' PRECEDING ..., "val" is "3". In this case, - // the code assumes it 3 milliseconds and produces wrong results. - // - // I'm not sure, but we may need to fix our sqlparser code as we try to fix this. parse_interval("millisecond", val) } else { ScalarValue::try_from_string(val.clone(), &column_type) @@ -1423,7 +1417,36 @@ fn convert_range_bound_to_column_type( /// OVER (ORDER BY a RANGES BETWEEN 3 PRECEDING AND 5 PRECEDING) /// OVER (ORDER BY a RANGES BETWEEN INTERVAL '3 DAY' PRECEDING AND '5 DAY' PRECEDING) are rejected pub fn is_window_valid(window_frame: &Arc) -> Result<()> { - if window_frame.start_bound > window_frame.end_bound { + let is_valid = match (&window_frame.start_bound, &window_frame.end_bound) { + ( + WindowFrameBound::Preceding(_), + // UNBOUNDED PRECEDING + WindowFrameBound::Preceding(ScalarValue::Utf8(None)), + ) + | ( + // UNBOUNDED FOLLOWING + WindowFrameBound::Following(ScalarValue::Utf8(None)), + WindowFrameBound::Following(_), + ) => false, + ( + WindowFrameBound::Preceding(ScalarValue::Utf8(None)), + WindowFrameBound::Preceding(_), + ) + | ( + WindowFrameBound::Following(_), + WindowFrameBound::Following(ScalarValue::Utf8(None)), + ) => true, + (WindowFrameBound::Preceding(lhs), WindowFrameBound::Preceding(rhs)) => { + lhs >= rhs + } + (WindowFrameBound::Following(lhs), WindowFrameBound::Following(rhs)) => { + lhs <= rhs + } + (WindowFrameBound::Following(_), WindowFrameBound::Preceding(_)) + | (WindowFrameBound::Following(_), WindowFrameBound::CurrentRow) => false, + _ => true, + }; + if !is_valid { Err(DataFusionError::Execution(format!( "Invalid window frame: start bound ({}) cannot be larger than end bound ({})", window_frame.start_bound, window_frame.end_bound @@ -1780,7 +1803,6 @@ fn tuple_err(value: (Result, Result)) -> Result<(T, R)> { #[cfg(test)] mod tests { use super::*; - // use crate::assert_contains; use crate::datasource::MemTable; use crate::execution::context::TaskContext; use crate::execution::options::CsvReadOptions; diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index 1b5344eb88cb..e82c9be0e0af 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -1116,6 +1116,7 @@ async fn window_frame_ranges_timestamp() -> Result<()> { vec![Arc::new(TimestampNanosecondArray::from_slice(&[ 1664264591000000000, 1664264592000000000, + 1664264592000000000, 1664264593000000000, 1664264594000000000, 1664364594000000000, @@ -1136,6 +1137,7 @@ async fn window_frame_ranges_timestamp() -> Result<()> { let df = ctx .sql( "SELECT ts, COUNT(*) OVER (ORDER BY ts RANGE BETWEEN INTERVAL '1' DAY PRECEDING AND INTERVAL '2 DAY' FOLLOWING) FROM t;" + // "SELECT ts, COUNT(*) OVER (ORDER BY ts RANGE BETWEEN '0 DAY' PRECEDING AND '0 DAY' FOLLOWING) FROM t;" ) .await?; @@ -1144,10 +1146,11 @@ async fn window_frame_ranges_timestamp() -> Result<()> { "+---------------------+-----------------+", "| ts | COUNT(UInt8(1)) |", "+---------------------+-----------------+", - "| 2022-09-27 07:43:11 | 5 |", - "| 2022-09-27 07:43:12 | 5 |", - "| 2022-09-27 07:43:13 | 5 |", - "| 2022-09-27 07:43:14 | 5 |", + "| 2022-09-27 07:43:11 | 6 |", + "| 2022-09-27 07:43:12 | 6 |", + "| 2022-09-27 07:43:12 | 6 |", + "| 2022-09-27 07:43:13 | 6 |", + "| 2022-09-27 07:43:14 | 6 |", "| 2022-09-28 11:29:54 | 2 |", "| 2022-09-29 15:16:34 | 2 |", "| 2022-09-30 19:03:14 | 1 |", diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index eb90bb863ee1..0b97ce19c104 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -25,17 +25,16 @@ use datafusion_common::{DataFusionError, Result, ScalarValue}; use sqlparser::ast; -use std::cmp::Ordering; use std::convert::{From, TryFrom}; use std::fmt; -use std::hash::{Hash, Hasher}; +use std::hash::Hash; /// The frame-spec determines which output rows are read by an aggregate window function. /// /// The ending frame boundary can be omitted (if the BETWEEN and AND keywords that surround the /// starting frame boundary are also omitted), in which case the ending frame boundary defaults to /// CURRENT ROW. -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct WindowFrame { /// A frame type - either ROWS, RANGE or GROUPS pub units: WindowFrameUnits, @@ -134,8 +133,7 @@ pub fn convert_range_bound_to_scalar_value(v: ast::RangeBounds) -> Result FOLLOWING /// 5. UNBOUNDED FOLLOWING /// -/// in this implementation we'll only allow to be u64 (i.e. no dynamic boundary) -#[derive(Debug, Clone, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum WindowFrameBound { /// 1. UNBOUNDED PRECEDING /// The frame boundary is the first row in the partition. @@ -160,11 +158,6 @@ pub enum WindowFrameBound { } impl From for WindowFrameBound { - // TODO: Add handling for other ScalarValue, once sql parser supports other types than literal int - // see https://github.com/sqlparser-rs/sqlparser-rs/issues/631 - // For now we can either get Some(u64) or None from PRECEDING AND - // FOLLOWING fields. When sql parser supports datetime types in the window - // range queries extend below to support datetime types inside the window. fn from(value: ast::WindowFrameBound) -> Self { match value { ast::WindowFrameBound::Preceding(Some(v)) => { @@ -202,112 +195,6 @@ impl fmt::Display for WindowFrameBound { } } -impl PartialEq for WindowFrameBound { - fn eq(&self, other: &Self) -> bool { - self.cmp(other) == Ordering::Equal - } -} - -impl PartialOrd for WindowFrameBound { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for WindowFrameBound { - fn cmp(&self, other: &Self) -> Ordering { - self.get_rank().cmp(&other.get_rank()) - } -} - -impl Hash for WindowFrameBound { - fn hash(&self, state: &mut H) { - self.get_rank().hash(state) - } -} - -impl WindowFrameBound { - /// get the rank of this window frame bound. - /// - /// the rank is a tuple of (u8, u64) because we'll firstly compare the kind and then the value - /// which requires special handling e.g. with preceding the larger the value the smaller the - /// rank and also for 0 preceding / following it is the same as current row - fn get_rank(&self) -> (u8, u64) { - match self { - WindowFrameBound::Preceding(ScalarValue::Utf8(None)) => (0, 0), - WindowFrameBound::Following(ScalarValue::Utf8(None)) => (4, 0), - WindowFrameBound::Preceding(ScalarValue::UInt64(Some(0))) - | WindowFrameBound::Preceding(ScalarValue::UInt32(Some(0))) - | WindowFrameBound::Preceding(ScalarValue::UInt16(Some(0))) - | WindowFrameBound::Preceding(ScalarValue::Int64(Some(0))) - | WindowFrameBound::Preceding(ScalarValue::Int32(Some(0))) - | WindowFrameBound::Preceding(ScalarValue::Int16(Some(0))) - | WindowFrameBound::CurrentRow - | WindowFrameBound::Following(ScalarValue::UInt64(Some(0))) - | WindowFrameBound::Following(ScalarValue::UInt32(Some(0))) - | WindowFrameBound::Following(ScalarValue::UInt16(Some(0))) - | WindowFrameBound::Following(ScalarValue::Int64(Some(0))) - | WindowFrameBound::Following(ScalarValue::Int32(Some(0))) - | WindowFrameBound::Following(ScalarValue::Int16(Some(0))) => (2, 0), - WindowFrameBound::Preceding(ScalarValue::UInt64(Some(v))) => { - (1, u64::MAX - *v as u64) - } - WindowFrameBound::Preceding(ScalarValue::UInt32(Some(v))) => { - (1, u64::MAX - *v as u64) - } - WindowFrameBound::Preceding(ScalarValue::UInt16(Some(v))) => { - (1, u64::MAX - *v as u64) - } - WindowFrameBound::Preceding(ScalarValue::Int64(Some(v))) => { - (1, u64::MAX - *v as u64) - } - WindowFrameBound::Preceding(ScalarValue::Int32(Some(v))) => { - (1, u64::MAX - *v as u64) - } - WindowFrameBound::Preceding(ScalarValue::Int16(Some(v))) => { - (1, u64::MAX - *v as u64) - } - WindowFrameBound::Preceding(ScalarValue::Float64(Some(v))) => { - (1, u64::MAX - *v as u64) - } - WindowFrameBound::Preceding(ScalarValue::Float32(Some(v))) => { - (1, u64::MAX - *v as u64) - } - WindowFrameBound::Preceding(ScalarValue::IntervalDayTime(Some(v))) => { - (1, u64::MAX - *v as u64) - } - WindowFrameBound::Following(ScalarValue::UInt64(Some(v))) => (3, *v as u64), - WindowFrameBound::Following(ScalarValue::UInt32(Some(v))) => (3, *v as u64), - WindowFrameBound::Following(ScalarValue::UInt16(Some(v))) => (3, *v as u64), - WindowFrameBound::Following(ScalarValue::Int64(Some(v))) => (3, *v as u64), - WindowFrameBound::Following(ScalarValue::Int32(Some(v))) => (3, *v as u64), - WindowFrameBound::Following(ScalarValue::Int16(Some(v))) => (3, *v as u64), - WindowFrameBound::Following(ScalarValue::Float64(Some(v))) => (3, *v as u64), - WindowFrameBound::Following(ScalarValue::Float32(Some(v))) => (3, *v as u64), - WindowFrameBound::Following(ScalarValue::IntervalDayTime(Some(v))) => { - (3, *v as u64) - } - WindowFrameBound::Preceding(ScalarValue::Utf8(Some(v))) => { - match v.as_ref() { - "0" => (2, 0), - // After schema information we do not have string type for WindowFrameBound - // hence we can cast it to a arbitrary point, TODO: fix here for better handling - _elem => (1, u64::MAX - 1), - } - } - WindowFrameBound::Following(ScalarValue::Utf8(Some(v))) => { - match v.as_ref() { - "0" => (2, 0), - // After schema information we do not have string type for WindowFrameBound - // hence we can cast it to a arbitrary point, TODO: fix here for better handling - _elem => (3, 1), - } - } - _ => todo!(), - } - } -} - /// There are three frame types: ROWS, GROUPS, and RANGE. The frame type determines how the /// starting and ending boundaries of the frame are measured. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)] @@ -355,7 +242,7 @@ mod tests { #[ignore] // We no longer check for validity of the preceding, following during window frame creation // Since we are accepting different kind of types during creation, validity of that check can be only - // done when schema information is available. The is no trivial way to reject PRECEDING '1 MONTH' is + // done when schema information is available. There is no trivial way to reject PRECEDING '1 MONTH' is // later than PRECEDING "40 DAYS". Hence this test is ignored, However they are rejected during physical // plan creation once schema information is available. fn test_window_frame_creation() -> Result<()> { @@ -413,15 +300,26 @@ mod tests { } #[test] + #[ignore] + // We now uses default PartialEq, Eq trait for the WindowFrameBound + // equality between WindowFrameBound::Preceding(ScalarValue::UInt64(Some(0))), + // and WindowFrameBound::CurrentRow will return false. fn test_eq() { - assert_eq!( - WindowFrameBound::Preceding(ScalarValue::UInt64(Some(0))), - WindowFrameBound::CurrentRow - ); - assert_eq!( - WindowFrameBound::CurrentRow, - WindowFrameBound::Following(ScalarValue::UInt64(Some(0))) - ); + // // Commented tests will not pass, this doesn't affect working of the + // // window frame calculation all tests pass + // assert_eq!( + // WindowFrameBound::Preceding(ScalarValue::UInt64(Some(0))), + // WindowFrameBound::CurrentRow + // ); + // assert_eq!( + // WindowFrameBound::Preceding(ScalarValue::IntervalMonthDayNano(Some(0))), + // WindowFrameBound::CurrentRow + // ); + // assert_eq!( + // WindowFrameBound::CurrentRow, + // WindowFrameBound::Following(ScalarValue::UInt64(Some(0))) + // ); + // // assert_eq!( WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), WindowFrameBound::Following(ScalarValue::UInt64(Some(2))) @@ -439,53 +337,4 @@ mod tests { WindowFrameBound::Preceding(ScalarValue::Utf8(None)) ); } - - #[test] - fn test_ord() { - assert!( - WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))) - < WindowFrameBound::CurrentRow - ); - // ! yes this is correct! - assert!( - WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))) - < WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))) - ); - assert!( - WindowFrameBound::Preceding(ScalarValue::UInt64(Some(u64::MAX))) - < WindowFrameBound::Preceding(ScalarValue::UInt64(Some(u64::MAX - 1))) - ); - assert!( - WindowFrameBound::Preceding(ScalarValue::Utf8(None)) - < WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1000000))) - ); - assert!( - WindowFrameBound::Preceding(ScalarValue::Utf8(None)) - < WindowFrameBound::Preceding(ScalarValue::UInt64(Some(u64::MAX))) - ); - assert!( - WindowFrameBound::Preceding(ScalarValue::Utf8(None)) - < WindowFrameBound::Following(ScalarValue::UInt64(Some(0))) - ); - assert!( - WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))) - < WindowFrameBound::Following(ScalarValue::UInt64(Some(1))) - ); - assert!( - WindowFrameBound::CurrentRow - < WindowFrameBound::Following(ScalarValue::UInt64(Some(1))) - ); - assert!( - WindowFrameBound::Following(ScalarValue::UInt64(Some(1))) - < WindowFrameBound::Following(ScalarValue::UInt64(Some(2))) - ); - assert!( - WindowFrameBound::Following(ScalarValue::UInt64(Some(2))) - < WindowFrameBound::Following(ScalarValue::Utf8(None)) - ); - assert!( - WindowFrameBound::Following(ScalarValue::UInt64(Some(u64::MAX))) - < WindowFrameBound::Following(ScalarValue::Utf8(None)) - ); - } } From 2d8a52f08b2234318b2474408b360bb2076a3563 Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Tue, 11 Oct 2022 19:58:57 -0400 Subject: [PATCH 04/13] Tidy up datetime arithmetic, get rid of unwraps --- datafusion/common/src/datetime.rs | 159 --------------- datafusion/common/src/lib.rs | 1 - datafusion/common/src/scalar.rs | 184 +++++++++++++----- datafusion/core/src/physical_plan/planner.rs | 133 ++++++------- datafusion/core/tests/sql/window.rs | 33 ++-- datafusion/expr/src/window_frame.rs | 114 +++-------- .../physical-expr/src/expressions/datetime.rs | 29 +-- .../physical-expr/src/window/aggregate.rs | 106 +++++----- datafusion/proto/src/to_proto.rs | 73 ++++--- 9 files changed, 334 insertions(+), 498 deletions(-) delete mode 100644 datafusion/common/src/datetime.rs diff --git a/datafusion/common/src/datetime.rs b/datafusion/common/src/datetime.rs deleted file mode 100644 index 80407a902c0b..000000000000 --- a/datafusion/common/src/datetime.rs +++ /dev/null @@ -1,159 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::delta::shift_months; -use crate::Result; -use crate::{DataFusionError, ScalarValue}; -use arrow::datatypes::{IntervalDayTimeType, IntervalMonthDayNanoType}; -use chrono::{Datelike, Duration, NaiveDate, NaiveDateTime}; -use std::ops::{Add, Sub}; - -pub fn evaluate_scalar( - operand: ScalarValue, - sign: i32, - scalar: &ScalarValue, -) -> Result { - let res = match operand { - ScalarValue::Date32(Some(days)) => { - let value = date32_add(days, scalar, sign)?; - ScalarValue::Date32(Some(value)) - } - ScalarValue::Date64(Some(ms)) => { - let value = date64_add(ms, scalar, sign)?; - ScalarValue::Date64(Some(value)) - } - ScalarValue::TimestampSecond(Some(ts_s), zone) => { - let value = seconds_add(ts_s, scalar, sign)?; - ScalarValue::TimestampSecond(Some(value), zone) - } - ScalarValue::TimestampMillisecond(Some(ts_ms), zone) => { - let value = milliseconds_add(ts_ms, scalar, sign)?; - ScalarValue::TimestampMillisecond(Some(value), zone) - } - ScalarValue::TimestampMicrosecond(Some(ts_us), zone) => { - let value = microseconds_add(ts_us, scalar, sign)?; - ScalarValue::TimestampMicrosecond(Some(value), zone) - } - ScalarValue::TimestampNanosecond(Some(ts_ns), zone) => { - let value = nanoseconds_add(ts_ns, scalar, sign)?; - ScalarValue::TimestampNanosecond(Some(value), zone) - } - _ => Err(DataFusionError::Execution(format!( - "Invalid lhs type {} for DateIntervalExpr", - operand.get_datatype() - )))?, - }; - Ok(res) -} - -#[inline] -pub fn date32_add(days: i32, scalar: &ScalarValue, sign: i32) -> Result { - let epoch = NaiveDate::from_ymd(1970, 1, 1); - let prior = epoch.add(Duration::days(days as i64)); - let posterior = do_date_math(prior, scalar, sign)?; - Ok(posterior.sub(epoch).num_days() as i32) -} - -#[inline] -pub fn date64_add(ms: i64, scalar: &ScalarValue, sign: i32) -> Result { - let epoch = NaiveDate::from_ymd(1970, 1, 1); - let prior = epoch.add(Duration::milliseconds(ms)); - let posterior = do_date_math(prior, scalar, sign)?; - Ok(posterior.sub(epoch).num_milliseconds()) -} - -#[inline] -pub fn seconds_add(ts_s: i64, scalar: &ScalarValue, sign: i32) -> Result { - Ok(do_date_time_math(ts_s, 0, scalar, sign)?.timestamp()) -} - -#[inline] -pub fn milliseconds_add(ts_ms: i64, scalar: &ScalarValue, sign: i32) -> Result { - let secs = ts_ms / 1000; - let nsecs = ((ts_ms % 1000) * 1_000_000) as u32; - Ok(do_date_time_math(secs, nsecs, scalar, sign)?.timestamp_millis()) -} - -#[inline] -pub fn microseconds_add(ts_us: i64, scalar: &ScalarValue, sign: i32) -> Result { - let secs = ts_us / 1_000_000; - let nsecs = ((ts_us % 1_000_000) * 1000) as u32; - Ok(do_date_time_math(secs, nsecs, scalar, sign)?.timestamp_nanos() / 1000) -} - -#[inline] -pub fn nanoseconds_add(ts_ns: i64, scalar: &ScalarValue, sign: i32) -> Result { - let secs = ts_ns / 1_000_000_000; - let nsecs = (ts_ns % 1_000_000_000) as u32; - Ok(do_date_time_math(secs, nsecs, scalar, sign)?.timestamp_nanos()) -} - -#[inline] -fn do_date_time_math( - secs: i64, - nsecs: u32, - scalar: &ScalarValue, - sign: i32, -) -> Result { - let prior = NaiveDateTime::from_timestamp(secs, nsecs); - do_date_math(prior, scalar, sign) -} - -fn do_date_math(prior: D, scalar: &ScalarValue, sign: i32) -> Result -where - D: Datelike + Add, -{ - Ok(match scalar { - ScalarValue::IntervalDayTime(Some(i)) => add_day_time(prior, *i, sign), - ScalarValue::IntervalYearMonth(Some(i)) => shift_months(prior, *i * sign), - ScalarValue::IntervalMonthDayNano(Some(i)) => add_m_d_nano(prior, *i, sign), - other => Err(DataFusionError::Execution(format!( - "DateIntervalExpr does not support non-interval type {:?}", - other - )))?, - }) -} - -// Can remove once https://github.com/apache/arrow-rs/pull/2031 is released -fn add_m_d_nano(prior: D, interval: i128, sign: i32) -> D -where - D: Datelike + Add, -{ - // let interval = interval as u128; - // let nanos = (interval >> 64) as i64 * sign as i64; - // let days = (interval >> 32) as i32 * sign; - // let months = interval as i32 * sign; - let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(interval); - let months = months * sign; - let days = days * sign; - let nanos = nanos * sign as i64; - let a = shift_months(prior, months); - let b = a.add(Duration::days(days as i64)); - b.add(Duration::nanoseconds(nanos)) -} - -// Can remove once https://github.com/apache/arrow-rs/pull/2031 is released -fn add_day_time(prior: D, interval: i64, sign: i32) -> D -where - D: Datelike + Add, -{ - let (days, ms) = IntervalDayTimeType::to_parts(interval); - let days = days * sign; - let ms = ms * sign; - let intermediate = prior.add(Duration::days(days as i64)); - intermediate.add(Duration::milliseconds(ms as i64)) -} diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index b5bd5fba4ead..e631b4dcac0d 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -17,7 +17,6 @@ pub mod bisect; mod column; -pub mod datetime; pub mod delta; mod dfschema; mod error; diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 4db54b97ba74..eac8effacd0a 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -20,6 +20,7 @@ use std::borrow::Borrow; use std::cmp::{max, Ordering}; use std::convert::{Infallible, TryInto}; +use std::ops::{Add, Sub}; use std::str::FromStr; use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; @@ -36,9 +37,10 @@ use arrow::{ }, util::decimal::Decimal128, }; +use chrono::{Datelike, Duration, NaiveDate, NaiveDateTime}; use ordered_float::OrderedFloat; -use crate::datetime::evaluate_scalar; +use crate::delta::shift_months; use crate::error::{DataFusionError, Result}; /// Represents a dynamically typed, nullable single value. @@ -442,6 +444,7 @@ macro_rules! unsigned_subtraction_error { macro_rules! impl_op { ($LHS:expr, $RHS:expr, $OPERATION:tt) => { match ($LHS, $RHS) { + // Binary operations on arguments with the same type: ( ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2), @@ -478,59 +481,144 @@ macro_rules! impl_op { (ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => { primitive_op!(lhs, rhs, Int8, $OPERATION) } - _ => { - impl_distinct_cases_op!($LHS, $RHS, $OPERATION) + // Binary operations on arguments with different types: + (ScalarValue::Date32(Some(days)), _) => { + let value = date32_add(*days, $RHS, get_sign!($OPERATION))?; + Ok(ScalarValue::Date32(Some(value))) + } + (ScalarValue::Date64(Some(ms)), _) => { + let value = date64_add(*ms, $RHS, get_sign!($OPERATION))?; + Ok(ScalarValue::Date64(Some(value))) + } + (ScalarValue::TimestampSecond(Some(ts_s), zone), _) => { + let value = seconds_add(*ts_s, $RHS, get_sign!($OPERATION))?; + Ok(ScalarValue::TimestampSecond(Some(value), zone.clone())) + } + (ScalarValue::TimestampMillisecond(Some(ts_ms), zone), _) => { + let value = milliseconds_add(*ts_ms, $RHS, get_sign!($OPERATION))?; + Ok(ScalarValue::TimestampMillisecond(Some(value), zone.clone())) + } + (ScalarValue::TimestampMicrosecond(Some(ts_us), zone), _) => { + let value = microseconds_add(*ts_us, $RHS, get_sign!($OPERATION))?; + Ok(ScalarValue::TimestampMicrosecond(Some(value), zone.clone())) + } + (ScalarValue::TimestampNanosecond(Some(ts_ns), zone), _) => { + let value = nanoseconds_add(*ts_ns, $RHS, get_sign!($OPERATION))?; + Ok(ScalarValue::TimestampNanosecond(Some(value), zone.clone())) } + _ => Err(DataFusionError::Internal(format!( + "Operator {} is not implemented for types {:?} and {:?}", + stringify!($OPERATION), + $LHS, + $RHS + ))), } }; } -// If we want a special implementation for an operation this is the place to implement it. -// For instance, in the future we may want to implement subtraction for dates but not addition. -// We can implement such special cases here. -macro_rules! impl_distinct_cases_op { - ($LHS:expr, $RHS:expr, +) => { - match ($LHS, $RHS) { - (ScalarValue::TimestampNanosecond(_, _), ScalarValue::IntervalDayTime(_)) - | ( - ScalarValue::TimestampNanosecond(_, _), - ScalarValue::IntervalYearMonth(_), - ) - | ( - ScalarValue::TimestampNanosecond(_, _), - ScalarValue::IntervalMonthDayNano(_), - ) => { - // 1 means addition - evaluate_scalar($LHS.clone(), 1, &$RHS) - } - e => Err(DataFusionError::Internal(format!( - "Addition is not implemented for {:?}", - e - ))), - } +macro_rules! get_sign { + (+) => { + 1 }; - ($LHS:expr, $RHS:expr, -) => { - match ($LHS, $RHS) { - (ScalarValue::TimestampNanosecond(_, _), ScalarValue::IntervalDayTime(_)) - | ( - ScalarValue::TimestampNanosecond(_, _), - ScalarValue::IntervalYearMonth(_), - ) - | ( - ScalarValue::TimestampNanosecond(_, _), - ScalarValue::IntervalMonthDayNano(_), - ) => { - // -1 means subtraction - evaluate_scalar($LHS.clone(), -1, &$RHS) - } - e => Err(DataFusionError::Internal(format!( - "Subtraction is not implemented for {:?}", - e - ))), - } + (-) => { + -1 }; } +#[inline] +pub fn date32_add(days: i32, scalar: &ScalarValue, sign: i32) -> Result { + let epoch = NaiveDate::from_ymd(1970, 1, 1); + let prior = epoch.add(Duration::days(days as i64)); + let posterior = do_date_math(prior, scalar, sign)?; + Ok(posterior.sub(epoch).num_days() as i32) +} + +#[inline] +pub fn date64_add(ms: i64, scalar: &ScalarValue, sign: i32) -> Result { + let epoch = NaiveDate::from_ymd(1970, 1, 1); + let prior = epoch.add(Duration::milliseconds(ms)); + let posterior = do_date_math(prior, scalar, sign)?; + Ok(posterior.sub(epoch).num_milliseconds()) +} + +#[inline] +pub fn seconds_add(ts_s: i64, scalar: &ScalarValue, sign: i32) -> Result { + Ok(do_date_time_math(ts_s, 0, scalar, sign)?.timestamp()) +} + +#[inline] +pub fn milliseconds_add(ts_ms: i64, scalar: &ScalarValue, sign: i32) -> Result { + let secs = ts_ms / 1000; + let nsecs = ((ts_ms % 1000) * 1_000_000) as u32; + Ok(do_date_time_math(secs, nsecs, scalar, sign)?.timestamp_millis()) +} + +#[inline] +pub fn microseconds_add(ts_us: i64, scalar: &ScalarValue, sign: i32) -> Result { + let secs = ts_us / 1_000_000; + let nsecs = ((ts_us % 1_000_000) * 1000) as u32; + Ok(do_date_time_math(secs, nsecs, scalar, sign)?.timestamp_nanos() / 1000) +} + +#[inline] +pub fn nanoseconds_add(ts_ns: i64, scalar: &ScalarValue, sign: i32) -> Result { + let secs = ts_ns / 1_000_000_000; + let nsecs = (ts_ns % 1_000_000_000) as u32; + Ok(do_date_time_math(secs, nsecs, scalar, sign)?.timestamp_nanos()) +} + +#[inline] +fn do_date_time_math( + secs: i64, + nsecs: u32, + scalar: &ScalarValue, + sign: i32, +) -> Result { + let prior = NaiveDateTime::from_timestamp(secs, nsecs); + do_date_math(prior, scalar, sign) +} + +fn do_date_math(prior: D, scalar: &ScalarValue, sign: i32) -> Result +where + D: Datelike + Add, +{ + Ok(match scalar { + ScalarValue::IntervalDayTime(Some(i)) => add_day_time(prior, *i, sign), + ScalarValue::IntervalYearMonth(Some(i)) => shift_months(prior, *i * sign), + ScalarValue::IntervalMonthDayNano(Some(i)) => add_m_d_nano(prior, *i, sign), + other => Err(DataFusionError::Execution(format!( + "DateIntervalExpr does not support non-interval type {:?}", + other + )))?, + }) +} + +// Can remove once https://github.com/apache/arrow-rs/pull/2031 is released +fn add_m_d_nano(prior: D, interval: i128, sign: i32) -> D +where + D: Datelike + Add, +{ + let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(interval); + let months = months * sign; + let days = days * sign; + let nanos = nanos * sign as i64; + let a = shift_months(prior, months); + let b = a.add(Duration::days(days as i64)); + b.add(Duration::nanoseconds(nanos)) +} + +// Can remove once https://github.com/apache/arrow-rs/pull/2031 is released +fn add_day_time(prior: D, interval: i64, sign: i32) -> D +where + D: Datelike + Add, +{ + let (days, ms) = IntervalDayTimeType::to_parts(interval); + let days = days * sign; + let ms = ms * sign; + let intermediate = prior.add(Duration::days(days as i64)); + intermediate.add(Duration::milliseconds(ms as i64)) +} + // manual implementation of `Hash` that uses OrderedFloat to // get defined behavior for floating point impl std::hash::Hash for ScalarValue { @@ -3709,7 +3797,7 @@ mod tests { match lhs.$FUNCTION(&rhs) { Ok(_result) => { panic!( - "Expected summation error between lhs: '{:?}', rhs: {:?}", + "Expected binary operation error between lhs: '{:?}', rhs: {:?}", lhs, rhs ); } @@ -3727,8 +3815,8 @@ mod tests { }; } - expect_operation_error!(expect_add_error, add, "Addition is not implemented"); - expect_operation_error!(expect_sub_error, sub, "Subtraction is not implemented"); + expect_operation_error!(expect_add_error, add, "Operator + is not implemented"); + expect_operation_error!(expect_sub_error, sub, "Operator - is not implemented"); macro_rules! decimal_op_test_cases { ($OPERATION:ident, [$([$L_VALUE:expr, $L_PRECISION:expr, $L_SCALE:expr, $R_VALUE:expr, $R_PRECISION:expr, $R_SCALE:expr, $O_VALUE:expr, $O_PRECISION:expr, $O_SCALE:expr]),+]) => { diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 57bec2a3e020..f4df4306530f 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -1368,37 +1368,37 @@ fn get_physical_expr_pair( let physical_name = physical_name(expr)?; Ok((physical_expr, physical_name)) } -/// Casts the ScalarValue: `in_scalar` to column type once we have schema information +/// Casts the ScalarValue `value` to column type once we have schema information /// The resulting type is not necessarily same type with the `column_type`. For instance /// if `column_type` is Timestamp the result is casted to Interval type. The reason is that /// Operation between Timestamps is not meaningful, However operation between Timestamp and /// Interval is valid. For basic types `column_type` is indeed the resulting type. fn convert_to_column_type( - column_type: arrow::datatypes::DataType, - in_scalar: &ScalarValue, + column_type: &arrow::datatypes::DataType, + value: &ScalarValue, ) -> Result { - match in_scalar { + match value { // In here we can either get ScalarValue::Utf8(None) or // ScalarValue::Utf8(Some(val)). The reason is that we convert the sqlparser result // to the Utf8 for all possible cases, since we have no schema information during conversion. // Here we have schema information, hence we can cast the appropriate ScalarValue Type. - ScalarValue::Utf8(None) => Ok(ScalarValue::Utf8(None)), + ScalarValue::Utf8(None) => ScalarValue::try_from(column_type), ScalarValue::Utf8(Some(val)) => { if let DataType::Timestamp(..) = column_type { parse_interval("millisecond", val) } else { - ScalarValue::try_from_string(val.clone(), &column_type) + ScalarValue::try_from_string(val.clone(), column_type) } } - unexpected => Err(DataFusionError::Internal(format!( - "unexpected: {:?}", - unexpected + s => Err(DataFusionError::Internal(format!( + "Unexpected value: {:?}", + s ))), } } fn convert_range_bound_to_column_type( - column_type: arrow::datatypes::DataType, + column_type: &arrow::datatypes::DataType, bound: &WindowFrameBound, ) -> Result { Ok(match bound { @@ -1411,50 +1411,27 @@ fn convert_range_bound_to_column_type( } }) } + /// Check if window bounds are valid after schema information is available, and /// window_frame bounds are casted to the corresponding column type. /// queries like: /// OVER (ORDER BY a RANGES BETWEEN 3 PRECEDING AND 5 PRECEDING) /// OVER (ORDER BY a RANGES BETWEEN INTERVAL '3 DAY' PRECEDING AND '5 DAY' PRECEDING) are rejected -pub fn is_window_valid(window_frame: &Arc) -> Result<()> { - let is_valid = match (&window_frame.start_bound, &window_frame.end_bound) { - ( - WindowFrameBound::Preceding(_), - // UNBOUNDED PRECEDING - WindowFrameBound::Preceding(ScalarValue::Utf8(None)), - ) - | ( - // UNBOUNDED FOLLOWING - WindowFrameBound::Following(ScalarValue::Utf8(None)), - WindowFrameBound::Following(_), - ) => false, - ( - WindowFrameBound::Preceding(ScalarValue::Utf8(None)), - WindowFrameBound::Preceding(_), - ) - | ( - WindowFrameBound::Following(_), - WindowFrameBound::Following(ScalarValue::Utf8(None)), - ) => true, +pub fn is_window_valid(window_frame: &WindowFrame) -> bool { + match (&window_frame.start_bound, &window_frame.end_bound) { + (WindowFrameBound::Following(_), WindowFrameBound::Preceding(_)) + | (WindowFrameBound::Following(_), WindowFrameBound::CurrentRow) + | (WindowFrameBound::CurrentRow, WindowFrameBound::Preceding(_)) => false, (WindowFrameBound::Preceding(lhs), WindowFrameBound::Preceding(rhs)) => { - lhs >= rhs + !rhs.is_null() && (lhs.is_null() || (lhs >= rhs)) } (WindowFrameBound::Following(lhs), WindowFrameBound::Following(rhs)) => { - lhs <= rhs + !lhs.is_null() && (rhs.is_null() || (lhs <= rhs)) } - (WindowFrameBound::Following(_), WindowFrameBound::Preceding(_)) - | (WindowFrameBound::Following(_), WindowFrameBound::CurrentRow) => false, _ => true, - }; - if !is_valid { - Err(DataFusionError::Execution(format!( - "Invalid window frame: start bound ({}) cannot be larger than end bound ({})", - window_frame.start_bound, window_frame.end_bound - ))) - } else { - Ok(()) } } + /// Create a window expression with a name from a logical expression pub fn create_window_expr_with_name( e: &Expr, @@ -1516,62 +1493,62 @@ pub fn create_window_expr_with_name( )), }) .collect::>>()?; - let mut new_window_frame = window_frame.clone(); - // Below query may produce error. We are calling its ? method only when it will - // not produce error logically (Such as when WindowFrameUnits is Range). - let order_by_column = order_by.first().ok_or_else(|| { - DataFusionError::Internal("Order By column cannot be empty".to_string()) - }); - if let Some(window_frame) = window_frame { + let mut window_frame = window_frame.clone(); + if let Some(ref mut window_frame) = window_frame { match window_frame.units { WindowFrameUnits::Groups => { return Err(DataFusionError::NotImplemented( - "Window frame definitions involving GROUPS are not supported yet" - .to_string(), + "Window frame definitions involving GROUPS are not supported yet" + .to_string(), )); } WindowFrameUnits::Range => { - let column_type = - order_by_column?.expr.data_type(physical_input_schema)?; - new_window_frame.as_mut().unwrap().start_bound = - convert_range_bound_to_column_type( - column_type.clone(), - &window_frame.start_bound, - )?; - new_window_frame.as_mut().unwrap().end_bound = - convert_range_bound_to_column_type( - column_type, - &window_frame.end_bound, - )?; + let column_type = order_by + .first() + .ok_or_else(|| { + DataFusionError::Internal( + "ORDER BY column cannot be empty".to_string(), + ) + })? + .expr + .data_type(physical_input_schema)?; + window_frame.start_bound = convert_range_bound_to_column_type( + &column_type, + &window_frame.start_bound, + )?; + window_frame.end_bound = convert_range_bound_to_column_type( + &column_type, + &window_frame.end_bound, + )?; } WindowFrameUnits::Rows => { // ROWS should have type usize which is Uint64 for our case let column_type = arrow::datatypes::DataType::UInt64; - new_window_frame.as_mut().unwrap().start_bound = - convert_range_bound_to_column_type( - column_type.clone(), - &window_frame.start_bound, - )?; - new_window_frame.as_mut().unwrap().end_bound = - convert_range_bound_to_column_type( - column_type, - &window_frame.end_bound, - )?; + window_frame.start_bound = convert_range_bound_to_column_type( + &column_type, + &window_frame.start_bound, + )?; + window_frame.end_bound = convert_range_bound_to_column_type( + &column_type, + &window_frame.end_bound, + )?; } } + if !is_window_valid(window_frame) { + return Err(DataFusionError::Execution(format!( + "Invalid window frame: start bound ({}) cannot be larger than end bound ({})", + window_frame.start_bound, window_frame.end_bound + ))); + } } - let new_window_frame = new_window_frame.map(Arc::new); - if let Some(ref window_frame) = new_window_frame { - is_window_valid(window_frame)?; - } windows::create_window_expr( fun, name, &args, &partition_by, &order_by, - new_window_frame, + window_frame.map(Arc::new), physical_input_schema, ) } diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index e82c9be0e0af..97bfebda682d 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -1136,25 +1136,30 @@ async fn window_frame_ranges_timestamp() -> Result<()> { // execute the query let df = ctx .sql( - "SELECT ts, COUNT(*) OVER (ORDER BY ts RANGE BETWEEN INTERVAL '1' DAY PRECEDING AND INTERVAL '2 DAY' FOLLOWING) FROM t;" - // "SELECT ts, COUNT(*) OVER (ORDER BY ts RANGE BETWEEN '0 DAY' PRECEDING AND '0 DAY' FOLLOWING) FROM t;" + "SELECT + ts, + COUNT(*) OVER (ORDER BY ts RANGE BETWEEN INTERVAL '1' DAY PRECEDING AND INTERVAL '2 DAY' FOLLOWING), + COUNT(*) OVER (ORDER BY ts RANGE BETWEEN '0 DAY' PRECEDING AND '0' DAY FOLLOWING), + COUNT(*) OVER (ORDER BY ts RANGE BETWEEN '5' SECOND PRECEDING AND CURRENT ROW) + FROM t + ORDER BY ts" ) .await?; let actual = df.collect().await?; let expected = vec![ - "+---------------------+-----------------+", - "| ts | COUNT(UInt8(1)) |", - "+---------------------+-----------------+", - "| 2022-09-27 07:43:11 | 6 |", - "| 2022-09-27 07:43:12 | 6 |", - "| 2022-09-27 07:43:12 | 6 |", - "| 2022-09-27 07:43:13 | 6 |", - "| 2022-09-27 07:43:14 | 6 |", - "| 2022-09-28 11:29:54 | 2 |", - "| 2022-09-29 15:16:34 | 2 |", - "| 2022-09-30 19:03:14 | 1 |", - "+---------------------+-----------------+", + "+---------------------+-----------------+-----------------+-----------------+", + "| ts | COUNT(UInt8(1)) | COUNT(UInt8(1)) | COUNT(UInt8(1)) |", + "+---------------------+-----------------+-----------------+-----------------+", + "| 2022-09-27 07:43:11 | 6 | 1 | 1 |", + "| 2022-09-27 07:43:12 | 6 | 2 | 3 |", + "| 2022-09-27 07:43:12 | 6 | 2 | 3 |", + "| 2022-09-27 07:43:13 | 6 | 1 | 4 |", + "| 2022-09-27 07:43:14 | 6 | 1 | 5 |", + "| 2022-09-28 11:29:54 | 2 | 1 | 1 |", + "| 2022-09-29 15:16:34 | 2 | 1 | 1 |", + "| 2022-09-30 19:03:14 | 1 | 1 | 1 |", + "+---------------------+-----------------+-----------------+-----------------+", ]; assert_batches_eq!(expected, &actual); Ok(()) diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index 0b97ce19c104..69f7e781bef0 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -59,11 +59,11 @@ impl TryFrom for WindowFrame { type Error = DataFusionError; fn try_from(value: ast::WindowFrame) -> Result { - let start_bound = value.start_bound.into(); - let end_bound = value - .end_bound - .map(WindowFrameBound::from) - .unwrap_or(WindowFrameBound::CurrentRow); + let start_bound = value.start_bound.try_into()?; + let end_bound = match value.end_bound { + Some(value) => value.try_into()?, + None => WindowFrameBound::CurrentRow, + }; if let WindowFrameBound::Following(ScalarValue::Utf8(None)) = start_bound { Err(DataFusionError::Execution( @@ -106,17 +106,19 @@ pub fn convert_range_bound_to_scalar_value(v: ast::RangeBounds) -> Result { - let mut res = match *value { - ast::Expr::Value(ast::Value::SingleQuotedString(elem)) => Ok(elem), - unexpected => Err(DataFusionError::Internal(format!( - "INTERVAL expression cannot be {:?}", - unexpected - ))), + let mut result = match *value { + ast::Expr::Value(ast::Value::SingleQuotedString(item)) => item, + e => { + return Err(DataFusionError::Internal(format!( + "INTERVAL expression cannot be {:?}", + e + ))) + } }; if let Some(leading_field) = leading_field { - res = Ok(format!("{} {}", res?, leading_field)); + result = format!("{} {}", result, leading_field); }; - Ok(ScalarValue::Utf8(Some(res?))) + Ok(ScalarValue::Utf8(Some(result))) } unexpected => Err(DataFusionError::Internal(format!( "RangeBounds cannot be {:?}", @@ -157,25 +159,25 @@ pub enum WindowFrameBound { Following(ScalarValue), } -impl From for WindowFrameBound { - fn from(value: ast::WindowFrameBound) -> Self { - match value { +impl TryFrom for WindowFrameBound { + type Error = DataFusionError; + + fn try_from(value: ast::WindowFrameBound) -> Result { + Ok(match value { ast::WindowFrameBound::Preceding(Some(v)) => { - let res = convert_range_bound_to_scalar_value(v).unwrap(); - Self::Preceding(res) + Self::Preceding(convert_range_bound_to_scalar_value(v)?) } ast::WindowFrameBound::Preceding(None) => { Self::Preceding(ScalarValue::Utf8(None)) } ast::WindowFrameBound::Following(Some(v)) => { - let res = convert_range_bound_to_scalar_value(v).unwrap(); - Self::Following(res) + Self::Following(convert_range_bound_to_scalar_value(v)?) } ast::WindowFrameBound::Following(None) => { Self::Following(ScalarValue::Utf8(None)) } ast::WindowFrameBound::CurrentRow => Self::CurrentRow, - } + }) } } @@ -239,12 +241,6 @@ mod tests { use super::*; #[test] - #[ignore] - // We no longer check for validity of the preceding, following during window frame creation - // Since we are accepting different kind of types during creation, validity of that check can be only - // done when schema information is available. There is no trivial way to reject PRECEDING '1 MONTH' is - // later than PRECEDING "40 DAYS". Hence this test is ignored, However they are rejected during physical - // plan creation once schema information is available. fn test_window_frame_creation() -> Result<()> { let window_frame = ast::WindowFrame { units: ast::WindowFrameUnits::Range, @@ -253,10 +249,9 @@ mod tests { }; let result = WindowFrame::try_from(window_frame); assert_eq!( - result.err().unwrap().to_string(), - "Execution error: Invalid window frame: start bound cannot be unbounded following" - .to_owned() - ); + result.err().unwrap().to_string(), + "Execution error: Invalid window frame: start bound cannot be unbounded following".to_owned() + ); let window_frame = ast::WindowFrame { units: ast::WindowFrameUnits::Range, @@ -264,25 +259,9 @@ mod tests { end_bound: Some(ast::WindowFrameBound::Preceding(None)), }; let result = WindowFrame::try_from(window_frame); - assert_eq!( - result.err().unwrap().to_string(), - "Execution error: Invalid window frame: end bound cannot be unbounded preceding" - .to_owned() - ); - - let window_frame = ast::WindowFrame { - units: ast::WindowFrameUnits::Range, - start_bound: ast::WindowFrameBound::Preceding(Some( - ast::RangeBounds::Number("1".to_string()), - )), - end_bound: Some(ast::WindowFrameBound::Preceding(Some( - ast::RangeBounds::Number("2".to_string()), - ))), - }; - let result = WindowFrame::try_from(window_frame); assert_eq!( result.err().unwrap().to_string(), - "Execution error: Invalid window frame: start bound (1 PRECEDING) cannot be larger than end bound (2 PRECEDING)".to_owned() + "Execution error: Invalid window frame: end bound cannot be unbounded preceding".to_owned() ); let window_frame = ast::WindowFrame { @@ -298,43 +277,4 @@ mod tests { assert!(result.is_ok()); Ok(()) } - - #[test] - #[ignore] - // We now uses default PartialEq, Eq trait for the WindowFrameBound - // equality between WindowFrameBound::Preceding(ScalarValue::UInt64(Some(0))), - // and WindowFrameBound::CurrentRow will return false. - fn test_eq() { - // // Commented tests will not pass, this doesn't affect working of the - // // window frame calculation all tests pass - // assert_eq!( - // WindowFrameBound::Preceding(ScalarValue::UInt64(Some(0))), - // WindowFrameBound::CurrentRow - // ); - // assert_eq!( - // WindowFrameBound::Preceding(ScalarValue::IntervalMonthDayNano(Some(0))), - // WindowFrameBound::CurrentRow - // ); - // assert_eq!( - // WindowFrameBound::CurrentRow, - // WindowFrameBound::Following(ScalarValue::UInt64(Some(0))) - // ); - // // - assert_eq!( - WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), - WindowFrameBound::Following(ScalarValue::UInt64(Some(2))) - ); - assert_eq!( - WindowFrameBound::Following(ScalarValue::Utf8(None)), - WindowFrameBound::Following(ScalarValue::Utf8(None)) - ); - assert_eq!( - WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))), - WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))) - ); - assert_eq!( - WindowFrameBound::Preceding(ScalarValue::Utf8(None)), - WindowFrameBound::Preceding(ScalarValue::Utf8(None)) - ); - } } diff --git a/datafusion/physical-expr/src/expressions/datetime.rs b/datafusion/physical-expr/src/expressions/datetime.rs index a83ce0208b6c..97f038068794 100644 --- a/datafusion/physical-expr/src/expressions/datetime.rs +++ b/datafusion/physical-expr/src/expressions/datetime.rs @@ -26,9 +26,9 @@ use arrow::datatypes::{ TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; use arrow::record_batch::RecordBatch; -use datafusion_common::datetime::{ - date32_add, date64_add, evaluate_scalar, microseconds_add, milliseconds_add, - nanoseconds_add, seconds_add, +use datafusion_common::scalar::{ + date32_add, date64_add, microseconds_add, milliseconds_add, nanoseconds_add, + seconds_add, }; use datafusion_common::Result; use datafusion_common::{DataFusionError, ScalarValue}; @@ -118,28 +118,29 @@ impl PhysicalExpr for DateTimeIntervalExpr { // Unwrap interval to add let intervals = match &intervals { ColumnarValue::Scalar(interval) => interval, - _ => Err(DataFusionError::Execution( - "Columnar execution is not yet supported for DateIntervalExpr" - .to_string(), - ))?, + _ => { + let msg = "Columnar execution is not yet supported for DateIntervalExpr"; + return Err(DataFusionError::Execution(msg.to_string())); + } }; // Invert sign for subtraction - let sign = match &self.op { + let sign = match self.op { Operator::Plus => 1, Operator::Minus => -1, _ => { // this should be unreachable because we check the operators in `try_new` - Err(DataFusionError::Execution( - "Invalid operator for DateIntervalExpr".to_string(), - ))? + let msg = "Invalid operator for DateIntervalExpr"; + return Err(DataFusionError::Internal(msg.to_string())); } }; match dates { - ColumnarValue::Scalar(operand) => Ok(ColumnarValue::Scalar(evaluate_scalar( - operand, sign, intervals, - )?)), + ColumnarValue::Scalar(operand) => Ok(ColumnarValue::Scalar(if sign > 0 { + operand.add(intervals)? + } else { + operand.sub(intervals)? + })), ColumnarValue::Array(array) => evaluate_array(array, sign, intervals), } } diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index 4ace0b94eeff..80cb4d10ce1a 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -162,7 +162,7 @@ fn calculate_index_of_row( return Ok(value.clone()); } if SEARCH_SIDE == is_descending { - // TODO: ADD overflow check + // TODO: Handle positive overflows value.add(delta) } else if value.is_unsigned() && value < delta { // NOTE: This gets a polymorphic zero without having long coercion code for ScalarValue. @@ -170,6 +170,7 @@ fn calculate_index_of_row( // change the following statement to use that. value.sub(value) } else { + // TODO: Handle negative overflows value.sub(delta) } }) @@ -193,27 +194,25 @@ fn calculate_current_window( match window_frame.units { WindowFrameUnits::Range => { let start = match &window_frame.start_bound { - // UNBOUNDED PRECEDING - WindowFrameBound::Preceding(ScalarValue::Utf8(None)) => Ok(0), - WindowFrameBound::Preceding(n) => calculate_index_of_row::( - range_columns, - sort_options, - idx, - Some(n), - ), + WindowFrameBound::Preceding(n) => { + if n.is_null() { + // UNBOUNDED PRECEDING + Ok(0) + } else { + calculate_index_of_row::( + range_columns, + sort_options, + idx, + Some(n), + ) + } + } WindowFrameBound::CurrentRow => calculate_index_of_row::( range_columns, sort_options, idx, None, ), - // UNBOUNDED FOLLOWING - WindowFrameBound::Following(ScalarValue::Utf8(None)) => { - Err(DataFusionError::Internal(format!( - "Frame start cannot be UNBOUNDED FOLLOWING '{:?}'", - window_frame - ))) - } WindowFrameBound::Following(n) => calculate_index_of_row::( range_columns, sort_options, @@ -222,13 +221,6 @@ fn calculate_current_window( ), }; let end = match &window_frame.end_bound { - // UNBOUNDED PRECEDING - WindowFrameBound::Preceding(ScalarValue::Utf8(None)) => { - Err(DataFusionError::Internal(format!( - "Frame end cannot be UNBOUNDED PRECEDING '{:?}'", - window_frame - ))) - } WindowFrameBound::Preceding(n) => calculate_index_of_row::( range_columns, sort_options, @@ -241,24 +233,29 @@ fn calculate_current_window( idx, None, ), - // UNBOUNDED FOLLOWING - WindowFrameBound::Following(ScalarValue::Utf8(None)) => Ok(length), - WindowFrameBound::Following(n) => calculate_index_of_row::( - range_columns, - sort_options, - idx, - Some(n), - ), + WindowFrameBound::Following(n) => { + if n.is_null() { + // UNBOUNDED FOLLOWING + Ok(length) + } else { + calculate_index_of_row::( + range_columns, + sort_options, + idx, + Some(n), + ) + } + } }; Ok((start?, end?)) } WindowFrameUnits::Rows => { - let start = match &window_frame.start_bound { + let start = match window_frame.start_bound { // UNBOUNDED PRECEDING - WindowFrameBound::Preceding(ScalarValue::Utf8(None)) => Ok(0), + WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => Ok(0), WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => { - if idx >= *n as usize { - Ok(idx - *n as usize) + if idx >= n as usize { + Ok(idx - n as usize) } else { Ok(0) } @@ -268,14 +265,14 @@ fn calculate_current_window( } WindowFrameBound::CurrentRow => Ok(idx), // UNBOUNDED FOLLOWING - WindowFrameBound::Following(ScalarValue::Utf8(None)) => { + WindowFrameBound::Following(ScalarValue::UInt64(None)) => { Err(DataFusionError::Internal(format!( "Frame start cannot be UNBOUNDED FOLLOWING '{:?}'", window_frame ))) } WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => { - Ok(min(idx + *n as usize, length)) + Ok(min(idx + n as usize, length)) } WindowFrameBound::Following(_) => { Err(DataFusionError::Internal("Rows should be Uint".to_string())) @@ -283,7 +280,7 @@ fn calculate_current_window( }; let end = match window_frame.end_bound { // UNBOUNDED PRECEDING - WindowFrameBound::Preceding(ScalarValue::Utf8(None)) => { + WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => { Err(DataFusionError::Internal(format!( "Frame end cannot be UNBOUNDED PRECEDING '{:?}'", window_frame @@ -301,7 +298,7 @@ fn calculate_current_window( } WindowFrameBound::CurrentRow => Ok(idx + 1), // UNBOUNDED FOLLOWING - WindowFrameBound::Following(ScalarValue::Utf8(None)) => Ok(length), + WindowFrameBound::Following(ScalarValue::UInt64(None)) => Ok(length), WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => { Ok(min(idx + n as usize + 1, length)) } @@ -330,16 +327,8 @@ struct AggregateWindowAccumulator { } impl AggregateWindowAccumulator { - /// This function constructs a simple window frame with a single ORDER BY. - fn implicit_order_by_window() -> Arc { - Arc::new(WindowFrame { - units: WindowFrameUnits::Range, - start_bound: WindowFrameBound::Preceding(ScalarValue::Utf8(None)), - end_bound: WindowFrameBound::CurrentRow, - }) - } /// This function calculates the aggregation on all rows in `value_slice`. - /// Returns an array of size `len`. + /// Returns an array of size `length`. fn calculate_whole_table( &mut self, value_slice: &[ArrayRef], @@ -439,20 +428,23 @@ impl AggregateWindowAccumulator { .map(|v| v.slice(value_range.start, length)) .collect::>(); let order_columns = &order_bys[self.partition_by.len()..order_bys.len()].to_vec(); - match (order_columns.len(), &self.window_frame) { - (0, None) => { + match (&order_columns[..], &self.window_frame) { + ([], None) => { // OVER () case self.calculate_whole_table(&value_slice, length) } - (_n, None) => { + ([column, ..], None) => { // OVER (ORDER BY a) case // We create an implicit window for ORDER BY. - self.window_frame = - Some(AggregateWindowAccumulator::implicit_order_by_window()); - + let empty_bound = ScalarValue::try_from(column.data_type())?; + self.window_frame = Some(Arc::new(WindowFrame { + units: WindowFrameUnits::Range, + start_bound: WindowFrameBound::Preceding(empty_bound), + end_bound: WindowFrameBound::CurrentRow, + })); self.calculate_running_window(&value_slice, order_columns, value_range) } - (0, Some(frame)) => { + ([], Some(frame)) => { match frame.units { WindowFrameUnits::Range => { // OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) case @@ -472,9 +464,7 @@ impl AggregateWindowAccumulator { } } // OVER (ORDER BY a ROWS/RANGE BETWEEN X PRECEDING AND Y FOLLOWING) case - (_n, _) => { - self.calculate_running_window(&value_slice, order_columns, value_range) - } + _ => self.calculate_running_window(&value_slice, order_columns, value_range), } } } diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs index 330fe3e3ee30..36a95e17201c 100644 --- a/datafusion/proto/src/to_proto.rs +++ b/datafusion/proto/src/to_proto.rs @@ -389,47 +389,39 @@ impl From for protobuf::WindowFrameUnits { } } -impl From for protobuf::WindowFrameBound { - fn from(bound: WindowFrameBound) -> Self { - match bound { - WindowFrameBound::CurrentRow => { - let pb_value: protobuf::ScalarValue = - (&ScalarValue::Utf8(None)).try_into().unwrap(); - Self { - window_frame_bound_type: protobuf::WindowFrameBoundType::CurrentRow - .into(), - bound_value: Some(pb_value), - } - } - WindowFrameBound::Preceding(v) => { - let pb_value: protobuf::ScalarValue = (&v).try_into().unwrap(); - Self { - window_frame_bound_type: protobuf::WindowFrameBoundType::Preceding - .into(), - bound_value: Some(pb_value), - } - } - WindowFrameBound::Following(v) => { - let pb_value: protobuf::ScalarValue = (&v).try_into().unwrap(); - Self { - window_frame_bound_type: protobuf::WindowFrameBoundType::Following - .into(), - bound_value: Some(pb_value), - } - } - } +impl TryFrom<&WindowFrameBound> for protobuf::WindowFrameBound { + type Error = Error; + + fn try_from(bound: &WindowFrameBound) -> Result { + Ok(match bound { + WindowFrameBound::CurrentRow => Self { + window_frame_bound_type: protobuf::WindowFrameBoundType::CurrentRow + .into(), + bound_value: None, + }, + WindowFrameBound::Preceding(v) => Self { + window_frame_bound_type: protobuf::WindowFrameBoundType::Preceding.into(), + bound_value: Some(v.try_into()?), + }, + WindowFrameBound::Following(v) => Self { + window_frame_bound_type: protobuf::WindowFrameBoundType::Following.into(), + bound_value: Some(v.try_into()?), + }, + }) } } -impl From for protobuf::WindowFrame { - fn from(window: WindowFrame) -> Self { - Self { +impl TryFrom<&WindowFrame> for protobuf::WindowFrame { + type Error = Error; + + fn try_from(window: &WindowFrame) -> Result { + Ok(Self { window_frame_units: protobuf::WindowFrameUnits::from(window.units).into(), - start_bound: Some(window.start_bound.into()), + start_bound: Some((&window.start_bound).try_into()?), end_bound: Some(protobuf::window_frame::EndBound::Bound( - window.end_bound.into(), + (&window.end_bound).try_into()?, )), - } + }) } } @@ -555,10 +547,13 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { .iter() .map(|e| e.try_into()) .collect::, _>>()?; - let window_frame: Option = window_frame.as_ref().map(|window_frame| { - let pb_value: protobuf::WindowFrame = window_frame.clone().into(); - protobuf::window_expr_node::WindowFrame::Frame(pb_value) - }); + + let window_frame = match window_frame { + Some(frame) => Some( + protobuf::window_expr_node::WindowFrame::Frame(frame.try_into()?) + ), + None => None + }; let window_expr = Box::new(protobuf::WindowExprNode { expr: arg_expr, window_function: Some(window_function), From d900968bd188999771d5f5e819b865161afd79a4 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Thu, 13 Oct 2022 09:25:41 +0300 Subject: [PATCH 05/13] new test for validity, during window frame creation is added --- datafusion/core/tests/sql/window.rs | 47 +++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index 97bfebda682d..630e0ac5ceda 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -1209,3 +1209,50 @@ async fn window_frame_groups_query() -> Result<()> { .contains("Window frame definitions involving GROUPS are not supported yet")); Ok(()) } + +#[tokio::test] +async fn window_frame_creation() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + // execute the query + let df = ctx + .sql( + "SELECT + COUNT(c1) OVER (ORDER BY c2 RANGE BETWEEN 1 PRECEDING AND 2 PRECEDING) + FROM aggregate_test_100;", + ) + .await?; + let results = df.collect().await; + assert_eq!( + results.err().unwrap().to_string(), + "Execution error: Invalid window frame: start bound (1 PRECEDING) cannot be larger than end bound (2 PRECEDING)" + ); + + let df = ctx + .sql( + "SELECT + COUNT(c1) OVER (ORDER BY c2 RANGE BETWEEN 2 FOLLOWING AND 1 FOLLOWING) + FROM aggregate_test_100;", + ) + .await?; + let results = df.collect().await; + assert_eq!( + results.err().unwrap().to_string(), + "Execution error: Invalid window frame: start bound (2 FOLLOWING) cannot be larger than end bound (1 FOLLOWING)" + ); + + let df = ctx + .sql( + "SELECT + COUNT(c1) OVER (ORDER BY c2 RANGE BETWEEN '1 DAY' PRECEDING AND '2 DAY' FOLLOWING) + FROM aggregate_test_100;", + ) + .await?; + let results = df.collect().await; + assert_eq!( + results.err().unwrap().to_string(), + "Arrow error: Cast error: Cannot cast string '1 DAY' to value of UInt32 type" + ); + + Ok(()) +} From 19058b86f3741b6d13d20d9908e2f2cea4a1198c Mon Sep 17 00:00:00 2001 From: Mustafa akur <106137913+mustafasrepo@users.noreply.github.com> Date: Fri, 14 Oct 2022 16:51:11 +0300 Subject: [PATCH 06/13] window_bound is changed to Expr (#6) --- datafusion/common/Cargo.toml | 2 +- datafusion/core/Cargo.toml | 2 +- datafusion/expr/Cargo.toml | 2 +- datafusion/expr/src/window_frame.rs | 27 ++++++++++++++++----------- datafusion/sql/Cargo.toml | 2 +- datafusion/sql/src/planner.rs | 25 ++++++++++++++++++++++++- 6 files changed, 44 insertions(+), 16 deletions(-) diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index ee9148c803a8..f1ee244bd7b6 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -48,4 +48,4 @@ ordered-float = "3.0" parquet = { version = "24.0.0", default-features = false, optional = true } pyo3 = { version = "0.17.1", optional = true } # sqlparser = "0.25" -sqlparser = { git = "https://github.com/synnada-ai/sqlparser-rs", branch = "main" } +sqlparser = { git = "https://github.com/synnada-ai/sqlparser-rs", branch = "feature_window_bound_expr" } diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index e99b9ca11ba9..7332bd686486 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -90,7 +90,7 @@ rand = "0.8" rayon = { version = "1.5", optional = true } smallvec = { version = "1.6", features = ["union"] } # sqlparser = "0.25" -sqlparser = { git = "https://github.com/synnada-ai/sqlparser-rs", branch = "main" } +sqlparser = { git = "https://github.com/synnada-ai/sqlparser-rs", branch = "feature_window_bound_expr" } tempfile = "3" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } tokio-stream = "0.1" diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index bce56da99e8a..5684a7e4c8f1 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -39,4 +39,4 @@ ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] arrow = { version = "24.0.0", default-features = false } datafusion-common = { path = "../common", version = "13.0.0" } # sqlparser = "0.25" -sqlparser = { git = "https://github.com/synnada-ai/sqlparser-rs", branch = "main" } +sqlparser = { git = "https://github.com/synnada-ai/sqlparser-rs", branch = "feature_window_bound_expr" } diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index 69f7e781bef0..11e8b170d78c 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -96,16 +96,18 @@ impl Default for WindowFrame { } } -pub fn convert_range_bound_to_scalar_value(v: ast::RangeBounds) -> Result { +pub fn convert_range_bound_to_scalar_value(v: ast::Expr) -> Result { match v { - ast::RangeBounds::Number(number) => Ok(ScalarValue::Utf8(Some(number))), - ast::RangeBounds::Interval(ast::Expr::Interval { + ast::Expr::Value(ast::Value::Number(number, false)) => { + Ok(ScalarValue::Utf8(Some(number))) + } + ast::Expr::Interval { value, leading_field, leading_precision: _, last_field: _, fractional_seconds_precision: _, - }) => { + } => { let mut result = match *value { ast::Expr::Value(ast::Value::SingleQuotedString(item)) => item, e => { @@ -120,6 +122,9 @@ pub fn convert_range_bound_to_scalar_value(v: ast::RangeBounds) -> Result { + Ok(ScalarValue::Utf8(Some(literal))) + } unexpected => Err(DataFusionError::Internal(format!( "RangeBounds cannot be {:?}", unexpected @@ -165,13 +170,13 @@ impl TryFrom for WindowFrameBound { fn try_from(value: ast::WindowFrameBound) -> Result { Ok(match value { ast::WindowFrameBound::Preceding(Some(v)) => { - Self::Preceding(convert_range_bound_to_scalar_value(v)?) + Self::Preceding(convert_range_bound_to_scalar_value(*v)?) } ast::WindowFrameBound::Preceding(None) => { Self::Preceding(ScalarValue::Utf8(None)) } ast::WindowFrameBound::Following(Some(v)) => { - Self::Following(convert_range_bound_to_scalar_value(v)?) + Self::Following(convert_range_bound_to_scalar_value(*v)?) } ast::WindowFrameBound::Following(None) => { Self::Following(ScalarValue::Utf8(None)) @@ -266,12 +271,12 @@ mod tests { let window_frame = ast::WindowFrame { units: ast::WindowFrameUnits::Rows, - start_bound: ast::WindowFrameBound::Preceding(Some( - ast::RangeBounds::Number("2".to_string()), - )), - end_bound: Some(ast::WindowFrameBound::Preceding(Some( - ast::RangeBounds::Number("1".to_string()), + start_bound: ast::WindowFrameBound::Preceding(Some(Box::new( + ast::Expr::Value(ast::Value::Number("2".to_string(), false)), ))), + end_bound: Some(ast::WindowFrameBound::Preceding(Some(Box::new( + ast::Expr::Value(ast::Value::Number("1".to_string(), false)), + )))), }; let result = WindowFrame::try_from(window_frame); assert!(result.is_ok()); diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index ea03cc20277b..2a0735baec04 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -41,4 +41,4 @@ arrow = { version = "24.0.0", default-features = false } datafusion-common = { path = "../common", version = "13.0.0" } datafusion-expr = { path = "../expr", version = "13.0.0" } # sqlparser = "0.25" -sqlparser = { git = "https://github.com/synnada-ai/sqlparser-rs", branch = "main" } +sqlparser = { git = "https://github.com/synnada-ai/sqlparser-rs", branch = "feature_window_bound_expr" } diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index b2b4e66e99c2..09464c7cda24 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -67,6 +67,8 @@ use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption}; use sqlparser::ast::{ObjectType, OrderByExpr, Statement}; use sqlparser::parser::ParserError::ParserError; +use sqlparser::ast::ExactNumberInfo; + use super::{ parser::DFParser, utils::{ @@ -2246,6 +2248,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } + SQLExpr::Floor{expr, field: _field} => { + let fun = BuiltinScalarFunction::Floor; + let args = vec![self.sql_expr_to_logical_expr(*expr, schema, ctes)?]; + Ok(Expr::ScalarFunction { fun, args }) + } + + SQLExpr::Ceil{expr, field: _field} => { + let fun = BuiltinScalarFunction::Ceil; + let args = vec![self.sql_expr_to_logical_expr(*expr, schema, ctes)?]; + Ok(Expr::ScalarFunction { fun, args }) + } + SQLExpr::Nested(e) => self.sql_expr_to_logical_expr(*e, schema, ctes), SQLExpr::Exists{ subquery, negated } => self.parse_exists_subquery(&subquery, negated, schema, ctes), @@ -2754,7 +2768,16 @@ pub fn convert_simple_data_type(sql_type: &SQLDataType) -> Result { ))) } } - SQLDataType::Decimal(precision, scale) => make_decimal_type(*precision, *scale), + SQLDataType::Decimal(exact_number_info) => { + let (precision, scale) = match *exact_number_info { + ExactNumberInfo::None => (None, None), + ExactNumberInfo::Precision(precision) => (Some(precision), None), + ExactNumberInfo::PrecisionAndScale(precision, scale) => { + (Some(precision), Some(scale)) + } + }; + make_decimal_type(precision, scale) + } SQLDataType::Bytea => Ok(DataType::Binary), // Explicitly list all other types so that if sqlparser // adds/changes the `SQLDataType` the compiler will tell us on upgrade From 51f8f7493cc73f5442befb78451ffe6218368e56 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Fri, 14 Oct 2022 17:24:52 +0300 Subject: [PATCH 07/13] give fixed commit hash as dependency --- datafusion/common/Cargo.toml | 2 +- datafusion/core/Cargo.toml | 2 +- datafusion/expr/Cargo.toml | 2 +- datafusion/sql/Cargo.toml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index f1ee244bd7b6..7b8db2474a5f 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -48,4 +48,4 @@ ordered-float = "3.0" parquet = { version = "24.0.0", default-features = false, optional = true } pyo3 = { version = "0.17.1", optional = true } # sqlparser = "0.25" -sqlparser = { git = "https://github.com/synnada-ai/sqlparser-rs", branch = "feature_window_bound_expr" } +sqlparser = { git = "https://github.com/synnada-ai/sqlparser-rs", rev = "a03b2f2" } diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 7332bd686486..56082ec3f41a 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -90,7 +90,7 @@ rand = "0.8" rayon = { version = "1.5", optional = true } smallvec = { version = "1.6", features = ["union"] } # sqlparser = "0.25" -sqlparser = { git = "https://github.com/synnada-ai/sqlparser-rs", branch = "feature_window_bound_expr" } +sqlparser = { git = "https://github.com/synnada-ai/sqlparser-rs", rev = "a03b2f2" } tempfile = "3" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } tokio-stream = "0.1" diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 5684a7e4c8f1..025cb585d4e1 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -39,4 +39,4 @@ ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] arrow = { version = "24.0.0", default-features = false } datafusion-common = { path = "../common", version = "13.0.0" } # sqlparser = "0.25" -sqlparser = { git = "https://github.com/synnada-ai/sqlparser-rs", branch = "feature_window_bound_expr" } +sqlparser = { git = "https://github.com/synnada-ai/sqlparser-rs", rev = "a03b2f2" } diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index 2a0735baec04..270b1e19a283 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -41,4 +41,4 @@ arrow = { version = "24.0.0", default-features = false } datafusion-common = { path = "../common", version = "13.0.0" } datafusion-expr = { path = "../expr", version = "13.0.0" } # sqlparser = "0.25" -sqlparser = { git = "https://github.com/synnada-ai/sqlparser-rs", branch = "feature_window_bound_expr" } +sqlparser = { git = "https://github.com/synnada-ai/sqlparser-rs", rev = "a03b2f2" } From 600d60bfdad6a9db9cc0f4f117aaabb1a11f2eb7 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Fri, 14 Oct 2022 17:40:58 +0300 Subject: [PATCH 08/13] remove locked flag --- .github/workflows/rust.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index e74af1352d2d..0ae15fbf00b5 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -63,7 +63,7 @@ jobs: cargo check --workspace --benches --features avro,jit,scheduler,json - name: Check Cargo.lock for datafusion-cli run: | - cargo check --manifest-path datafusion-cli/Cargo.toml --locked + cargo check --manifest-path datafusion-cli/Cargo.toml # test the crate linux-test: From 4f8a2ce1682f00b830910982c1feef17330dac3b Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Fri, 14 Oct 2022 11:55:16 -0400 Subject: [PATCH 09/13] Simplify some frame bound conversion functions --- datafusion/core/src/physical_plan/planner.rs | 10 +-- datafusion/expr/src/window_frame.rs | 69 +++++++++----------- 2 files changed, 36 insertions(+), 43 deletions(-) diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index f4df4306530f..fa8de33e0112 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -1397,7 +1397,7 @@ fn convert_to_column_type( } } -fn convert_range_bound_to_column_type( +fn convert_frame_bound_to_column_type( column_type: &arrow::datatypes::DataType, bound: &WindowFrameBound, ) -> Result { @@ -1512,11 +1512,11 @@ pub fn create_window_expr_with_name( })? .expr .data_type(physical_input_schema)?; - window_frame.start_bound = convert_range_bound_to_column_type( + window_frame.start_bound = convert_frame_bound_to_column_type( &column_type, &window_frame.start_bound, )?; - window_frame.end_bound = convert_range_bound_to_column_type( + window_frame.end_bound = convert_frame_bound_to_column_type( &column_type, &window_frame.end_bound, )?; @@ -1524,11 +1524,11 @@ pub fn create_window_expr_with_name( WindowFrameUnits::Rows => { // ROWS should have type usize which is Uint64 for our case let column_type = arrow::datatypes::DataType::UInt64; - window_frame.start_bound = convert_range_bound_to_column_type( + window_frame.start_bound = convert_frame_bound_to_column_type( &column_type, &window_frame.start_bound, )?; - window_frame.end_bound = convert_range_bound_to_column_type( + window_frame.end_bound = convert_frame_bound_to_column_type( &column_type, &window_frame.end_bound, )?; diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index 11e8b170d78c..da8e6d36b15d 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -96,42 +96,6 @@ impl Default for WindowFrame { } } -pub fn convert_range_bound_to_scalar_value(v: ast::Expr) -> Result { - match v { - ast::Expr::Value(ast::Value::Number(number, false)) => { - Ok(ScalarValue::Utf8(Some(number))) - } - ast::Expr::Interval { - value, - leading_field, - leading_precision: _, - last_field: _, - fractional_seconds_precision: _, - } => { - let mut result = match *value { - ast::Expr::Value(ast::Value::SingleQuotedString(item)) => item, - e => { - return Err(DataFusionError::Internal(format!( - "INTERVAL expression cannot be {:?}", - e - ))) - } - }; - if let Some(leading_field) = leading_field { - result = format!("{} {}", result, leading_field); - }; - Ok(ScalarValue::Utf8(Some(result))) - } - ast::Expr::Value(ast::Value::SingleQuotedString(literal)) => { - Ok(ScalarValue::Utf8(Some(literal))) - } - unexpected => Err(DataFusionError::Internal(format!( - "RangeBounds cannot be {:?}", - unexpected - ))), - } -} - /// There are five ways to describe starting and ending frame boundaries: /// /// 1. UNBOUNDED PRECEDING @@ -170,13 +134,13 @@ impl TryFrom for WindowFrameBound { fn try_from(value: ast::WindowFrameBound) -> Result { Ok(match value { ast::WindowFrameBound::Preceding(Some(v)) => { - Self::Preceding(convert_range_bound_to_scalar_value(*v)?) + Self::Preceding(convert_frame_bound_to_scalar_value(*v)?) } ast::WindowFrameBound::Preceding(None) => { Self::Preceding(ScalarValue::Utf8(None)) } ast::WindowFrameBound::Following(Some(v)) => { - Self::Following(convert_range_bound_to_scalar_value(*v)?) + Self::Following(convert_frame_bound_to_scalar_value(*v)?) } ast::WindowFrameBound::Following(None) => { Self::Following(ScalarValue::Utf8(None)) @@ -186,6 +150,35 @@ impl TryFrom for WindowFrameBound { } } +pub fn convert_frame_bound_to_scalar_value(v: ast::Expr) -> Result { + Ok(ScalarValue::Utf8(Some(match v { + ast::Expr::Value(ast::Value::Number(value, false)) + | ast::Expr::Value(ast::Value::SingleQuotedString(value)) => value, + ast::Expr::Interval { + value, + leading_field, + .. + } => { + let result = match *value { + ast::Expr::Value(ast::Value::SingleQuotedString(item)) => item, + e => { + let msg = format!("INTERVAL expression cannot be {:?}", e); + return Err(DataFusionError::Internal(msg)); + } + }; + if let Some(leading_field) = leading_field { + format!("{} {}", result, leading_field) + } else { + result + } + } + e => { + let msg = format!("Window frame bound cannot be {:?}", e); + return Err(DataFusionError::Internal(msg)); + } + }))) +} + impl fmt::Display for WindowFrameBound { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { From 08af94ee9f05f945e671b601bc89225d1ef3eb44 Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Sat, 15 Oct 2022 12:04:12 -0400 Subject: [PATCH 10/13] Make things work the new sqlparser --- datafusion/common/Cargo.toml | 3 +-- datafusion/core/Cargo.toml | 3 +-- datafusion/sql/Cargo.toml | 3 +-- datafusion/sql/src/planner.rs | 3 +++ 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index 7b8db2474a5f..6b151a610a01 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -47,5 +47,4 @@ object_store = { version = "0.5.0", default-features = false, optional = true } ordered-float = "3.0" parquet = { version = "24.0.0", default-features = false, optional = true } pyo3 = { version = "0.17.1", optional = true } -# sqlparser = "0.25" -sqlparser = { git = "https://github.com/synnada-ai/sqlparser-rs", rev = "a03b2f2" } +sqlparser = { git = "https://github.com/sqlparser-rs/sqlparser-rs", branch = "main" } diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 56082ec3f41a..62e0298afe25 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -89,8 +89,7 @@ pyo3 = { version = "0.17.1", optional = true } rand = "0.8" rayon = { version = "1.5", optional = true } smallvec = { version = "1.6", features = ["union"] } -# sqlparser = "0.25" -sqlparser = { git = "https://github.com/synnada-ai/sqlparser-rs", rev = "a03b2f2" } +sqlparser = { git = "https://github.com/sqlparser-rs/sqlparser-rs", branch = "main" } tempfile = "3" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } tokio-stream = "0.1" diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index 270b1e19a283..fa92415f0bb9 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -40,5 +40,4 @@ unicode_expressions = [] arrow = { version = "24.0.0", default-features = false } datafusion-common = { path = "../common", version = "13.0.0" } datafusion-expr = { path = "../expr", version = "13.0.0" } -# sqlparser = "0.25" -sqlparser = { git = "https://github.com/synnada-ai/sqlparser-rs", rev = "a03b2f2" } +sqlparser = { git = "https://github.com/sqlparser-rs/sqlparser-rs", branch = "main" } diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 6f9dbbece19f..06199e4de44e 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -252,6 +252,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if_exists, names, cascade: _, + restrict: _, purge: _, // We don't support cascade and purge for now. // nor do we support multiple object names @@ -2802,6 +2803,8 @@ pub fn convert_simple_data_type(sql_type: &SQLDataType) -> Result { | SQLDataType::Character(_) | SQLDataType::CharacterVarying(_) | SQLDataType::CharVarying(_) + | SQLDataType::CharacterLargeObject(_) + | SQLDataType::CharLargeObject(_) | SQLDataType::Clob(_) => Err(DataFusionError::NotImplemented(format!( "Unsupported SQL type {:?}", sql_type From d8630e8492708162d8d5a1e0d4d9024c6099047a Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Wed, 19 Oct 2022 19:04:37 -0400 Subject: [PATCH 11/13] Upgrade sqlparser version to 0.26 --- datafusion-cli/Cargo.lock | 146 ++++++++++++++++++++++++----------- datafusion/common/Cargo.toml | 2 +- datafusion/core/Cargo.toml | 2 +- datafusion/expr/Cargo.toml | 2 +- datafusion/sql/Cargo.toml | 2 +- 5 files changed, 106 insertions(+), 48 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 364c412fb02d..49d5012c6fa5 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -166,9 +166,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.57" +version = "0.1.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76464446b8bc32758d7e88ee1a804d9914cd9b1cb264c029899680b0be29826f" +checksum = "1e805d94e6b5001b651426cf4cd446b1ab5f319d27bab5c644f61de0a804360c" dependencies = [ "proc-macro2", "quote", @@ -271,9 +271,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.11.0" +version = "3.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1ad822118d20d2c234f427000d5acc36eabe1e29a348c89b63dd60b13f28e5d" +checksum = "572f695136211188308f16ad2ca5c851a712c464060ae6974944458eb83880ba" [[package]] name = "byteorder" @@ -398,9 +398,9 @@ dependencies = [ [[package]] name = "comfy-table" -version = "6.1.0" +version = "6.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85914173c2f558d61613bfbbf1911f14e630895087a7ed2fafc0f5319e1536e7" +checksum = "7b3d16bb3da60be2f7c7acfc438f2ae6f3496897ce68c291d0509bb67b4e248e" dependencies = [ "strum", "strum_macros", @@ -499,9 +499,9 @@ dependencies = [ [[package]] name = "cxx" -version = "1.0.78" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19f39818dcfc97d45b03953c1292efc4e80954e1583c4aa770bac1383e2310a4" +checksum = "3f83d0ebf42c6eafb8d7c52f7e5f2d3003b89c7aa4fd2b79229209459a849af8" dependencies = [ "cc", "cxxbridge-flags", @@ -511,9 +511,9 @@ dependencies = [ [[package]] name = "cxx-build" -version = "1.0.78" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e580d70777c116df50c390d1211993f62d40302881e54d4b79727acb83d0199" +checksum = "07d050484b55975889284352b0ffc2ecbda25c0c55978017c132b29ba0818a86" dependencies = [ "cc", "codespan-reporting", @@ -526,15 +526,15 @@ dependencies = [ [[package]] name = "cxxbridge-flags" -version = "1.0.78" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56a46460b88d1cec95112c8c363f0e2c39afdb237f60583b0b36343bf627ea9c" +checksum = "99d2199b00553eda8012dfec8d3b1c75fce747cf27c169a270b3b99e3448ab78" [[package]] name = "cxxbridge-macro" -version = "1.0.78" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "747b608fecf06b0d72d440f27acc99288207324b793be2c17991839f3d4995ea" +checksum = "dcb67a6de1f602736dd7eaead0080cf3435df806c61b24b13328db128c58868f" dependencies = [ "proc-macro2", "quote", @@ -567,7 +567,7 @@ dependencies = [ "log", "num_cpus", "object_store", - "ordered-float 3.2.0", + "ordered-float 3.3.0", "parking_lot", "parquet", "paste", @@ -605,8 +605,9 @@ name = "datafusion-common" version = "13.0.0" dependencies = [ "arrow", + "chrono", "object_store", - "ordered-float 3.2.0", + "ordered-float 3.3.0", "parquet", "sqlparser", ] @@ -651,7 +652,7 @@ dependencies = [ "hashbrown", "lazy_static", "md-5", - "ordered-float 3.2.0", + "ordered-float 3.3.0", "paste", "rand", "regex", @@ -819,7 +820,7 @@ checksum = "e11dcc7e4d79a8c89b9ab4c6f5c30b1fc4a83c420792da3542fd31179ed5f517" dependencies = [ "cfg-if", "rustix", - "windows-sys", + "windows-sys 0.36.1", ] [[package]] @@ -1120,9 +1121,9 @@ dependencies = [ [[package]] name = "iana-time-zone-haiku" -version = "0.1.0" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fde6edd6cef363e9359ed3c98ba64590ba9eecba2293eb5a723ab32aee8926aa" +checksum = "0703ae284fc167426161c2e3f1da3ea71d94b21bedbcc9494e92b28e334e3dca" dependencies = [ "cxx", "cxx-build", @@ -1298,9 +1299,9 @@ checksum = "292a948cd991e376cf75541fe5b97a1081d713c618b4f1b9500f8844e49eb565" [[package]] name = "libmimalloc-sys" -version = "0.1.25" +version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11ca136052550448f55df7898c6dbe651c6b574fe38a0d9ea687a9f8088a2e2c" +checksum = "8fc093ab289b0bfda3aa1bdfab9c9542be29c7ef385cfcbe77f8c9813588eb48" dependencies = [ "cc", ] @@ -1376,9 +1377,9 @@ checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" [[package]] name = "mimalloc" -version = "0.1.29" +version = "0.1.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f64ad83c969af2e732e907564deb0d0ed393cec4af80776f77dd77a1a427698" +checksum = "76ce6a4b40d3bff9eb3ce9881ca0737a85072f9f975886082640cd46a75cdb35" dependencies = [ "libmimalloc-sys", ] @@ -1407,7 +1408,7 @@ dependencies = [ "libc", "log", "wasi", - "windows-sys", + "windows-sys 0.36.1", ] [[package]] @@ -1582,9 +1583,9 @@ dependencies = [ [[package]] name = "ordered-float" -version = "3.2.0" +version = "3.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "129d36517b53c461acc6e1580aeb919c8ae6708a4b1eae61c4463a615d4f0411" +checksum = "1f74e330193f90ec45e2b257fa3ef6df087784157ac1ad2c1e71c62837b03aa7" dependencies = [ "num-traits", ] @@ -1607,15 +1608,15 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.3" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09a279cbf25cb0757810394fbc1e359949b59e348145c643a939a525692e6929" +checksum = "4dc9e0dc2adc1c69d09143aff38d3d30c5c3f0df0dad82e6d25547af174ebec0" dependencies = [ "cfg-if", "libc", "redox_syscall", "smallvec", - "windows-sys", + "windows-sys 0.42.0", ] [[package]] @@ -1712,9 +1713,9 @@ checksum = "dbf0c48bc1d91375ae5c3cd81e3722dff1abcf81a30960240640d223f59fe0e5" [[package]] name = "proc-macro2" -version = "1.0.46" +version = "1.0.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94e2ef8dbfc347b10c094890f778ee2e36ca9bb4262e86dc99cd217e35f3470b" +checksum = "5ea3d908b0e36316caf9e9e2c4625cdde190a7e6f440d794667ed17a1855e725" dependencies = [ "unicode-ident", ] @@ -1896,14 +1897,14 @@ dependencies = [ "io-lifetimes", "libc", "linux-raw-sys", - "windows-sys", + "windows-sys 0.36.1", ] [[package]] name = "rustls" -version = "0.20.6" +version = "0.20.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aab8ee6c7097ed6057f43c187a62418d0c05a4bd5f18b3571db50ee0f9ce033" +checksum = "539a2bfe908f471bfa933876bd1eb6a19cf2176d375f82ef7f99530a40e48c2c" dependencies = [ "log", "ring", @@ -2014,9 +2015,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.86" +version = "1.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41feea4228a6f1cd09ec7a3593a682276702cd67b5273544757dae23c096f074" +checksum = "6ce777b7b150d76b9cf60d28b55f5847135a003f7d7350c6be7a773508ce7d45" dependencies = [ "itoa 1.0.4", "ryu", @@ -2107,9 +2108,9 @@ checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" [[package]] name = "sqlparser" -version = "0.25.0" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0781f2b6bd03e5adf065c8e772b49eaea9f640d06a1b9130330fe8bd2563f4fd" +checksum = "86be66ea0b2b22749cfa157d16e2e84bf793e626a3375f4d378dc289fa03affb" dependencies = [ "log", ] @@ -2598,43 +2599,100 @@ version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ea04155a16a59f9eab786fe12a4a450e75cdb175f9e0d80da1e17db09f55b8d2" dependencies = [ - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_msvc", + "windows_aarch64_msvc 0.36.1", + "windows_i686_gnu 0.36.1", + "windows_i686_msvc 0.36.1", + "windows_x86_64_gnu 0.36.1", + "windows_x86_64_msvc 0.36.1", +] + +[[package]] +name = "windows-sys" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc 0.42.0", + "windows_i686_gnu 0.42.0", + "windows_i686_msvc 0.42.0", + "windows_x86_64_gnu 0.42.0", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc 0.42.0", ] +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d2aa71f6f0cbe00ae5167d90ef3cfe66527d6f613ca78ac8024c3ccab9a19e" + [[package]] name = "windows_aarch64_msvc" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9bb8c3fd39ade2d67e9874ac4f3db21f0d710bee00fe7cab16949ec184eeaa47" +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd0f252f5a35cac83d6311b2e795981f5ee6e67eb1f9a7f64eb4500fbc4dcdb4" + [[package]] name = "windows_i686_gnu" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "180e6ccf01daf4c426b846dfc66db1fc518f074baa793aa7d9b9aaeffad6a3b6" +[[package]] +name = "windows_i686_gnu" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbeae19f6716841636c28d695375df17562ca208b2b7d0dc47635a50ae6c5de7" + [[package]] name = "windows_i686_msvc" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2e7917148b2812d1eeafaeb22a97e4813dfa60a3f8f78ebe204bcc88f12f024" +[[package]] +name = "windows_i686_msvc" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84c12f65daa39dd2babe6e442988fc329d6243fdce47d7d2d155b8d874862246" + [[package]] name = "windows_x86_64_gnu" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4dcd171b8776c41b97521e5da127a2d86ad280114807d0b2ab1e462bc764d9e1" +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf7b1b21b5362cbc318f686150e5bcea75ecedc74dd157d874d754a2ca44b0ed" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09d525d2ba30eeb3297665bd434a54297e4170c7f1a44cad4ef58095b4cd2028" + [[package]] name = "windows_x86_64_msvc" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c811ca4a8c853ef420abd8592ba53ddbbac90410fab6903b3e79972a631f7680" +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40009d85759725a34da6d89a94e63d7bdc50a862acf0dbc7c8e488f1edcb6f5" + [[package]] name = "winreg" version = "0.10.1" diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index 7c95cd255319..e3da20b22dfc 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -47,4 +47,4 @@ object_store = { version = "0.5.0", default-features = false, optional = true } ordered-float = "3.0" parquet = { version = "25.0.0", default-features = false, optional = true } pyo3 = { version = "0.17.1", optional = true } -sqlparser = { git = "https://github.com/sqlparser-rs/sqlparser-rs", branch = "main" } +sqlparser = "0.26" diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 42c5b9497669..616c16b4f02e 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -89,7 +89,7 @@ pyo3 = { version = "0.17.1", optional = true } rand = "0.8" rayon = { version = "1.5", optional = true } smallvec = { version = "1.6", features = ["union"] } -sqlparser = { git = "https://github.com/sqlparser-rs/sqlparser-rs", branch = "main" } +sqlparser = "0.26" tempfile = "3" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } tokio-stream = "0.1" diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 7dac223d0e9b..afa2be239ffe 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -39,4 +39,4 @@ ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] arrow = { version = "25.0.0", default-features = false } datafusion-common = { path = "../common", version = "13.0.0" } log = "^0.4" -sqlparser = { git = "https://github.com/sqlparser-rs/sqlparser-rs", branch = "main" } +sqlparser = "0.26" diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index 04d3a2d0d21a..44f3860f676a 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -40,4 +40,4 @@ unicode_expressions = [] arrow = { version = "25.0.0", default-features = false } datafusion-common = { path = "../common", version = "13.0.0" } datafusion-expr = { path = "../expr", version = "13.0.0" } -sqlparser = { git = "https://github.com/sqlparser-rs/sqlparser-rs", branch = "main" } +sqlparser = "0.26" From 082753c2dbcd64b063b5bd229d08d319cc3a9f7e Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Mon, 24 Oct 2022 09:47:17 +0300 Subject: [PATCH 12/13] Minor changes --- datafusion/common/src/scalar.rs | 4 ++-- datafusion/expr/src/window_frame.rs | 11 ++++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 69c752049498..fe8d0019b6db 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -598,7 +598,7 @@ where }) } -// Can remove once https://github.com/apache/arrow-rs/pull/2031 is released +// Can remove once chrono:0.4.23 is released fn add_m_d_nano(prior: D, interval: i128, sign: i32) -> D where D: Datelike + Add, @@ -612,7 +612,7 @@ where b.add(Duration::nanoseconds(nanos)) } -// Can remove once https://github.com/apache/arrow-rs/pull/2031 is released +// Can remove once chrono:0.4.23 is released fn add_day_time(prior: D, interval: i64, sign: i32) -> D where D: Datelike + Add, diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index da8e6d36b15d..5bf81d165db5 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -25,6 +25,7 @@ use datafusion_common::{DataFusionError, Result, ScalarValue}; use sqlparser::ast; +use sqlparser::parser::ParserError::ParserError; use std::convert::{From, TryFrom}; use std::fmt; use std::hash::Hash; @@ -163,7 +164,7 @@ pub fn convert_frame_bound_to_scalar_value(v: ast::Expr) -> Result ast::Expr::Value(ast::Value::SingleQuotedString(item)) => item, e => { let msg = format!("INTERVAL expression cannot be {:?}", e); - return Err(DataFusionError::Internal(msg)); + return Err(DataFusionError::SQL(ParserError(msg))); } }; if let Some(leading_field) = leading_field { @@ -245,9 +246,9 @@ mod tests { start_bound: ast::WindowFrameBound::Following(None), end_bound: None, }; - let result = WindowFrame::try_from(window_frame); + let err = WindowFrame::try_from(window_frame).unwrap_err(); assert_eq!( - result.err().unwrap().to_string(), + err.to_string(), "Execution error: Invalid window frame: start bound cannot be unbounded following".to_owned() ); @@ -256,9 +257,9 @@ mod tests { start_bound: ast::WindowFrameBound::Preceding(None), end_bound: Some(ast::WindowFrameBound::Preceding(None)), }; - let result = WindowFrame::try_from(window_frame); + let err = WindowFrame::try_from(window_frame).unwrap_err(); assert_eq!( - result.err().unwrap().to_string(), + err.to_string(), "Execution error: Invalid window frame: end bound cannot be unbounded preceding".to_owned() ); From b78eb162910e31cdee6b2747f1031c0514e65b96 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Tue, 25 Oct 2022 14:01:36 +0300 Subject: [PATCH 13/13] type coercion for window frames moved to the optimizer. --- datafusion/core/src/physical_plan/planner.rs | 105 +++------------- datafusion/core/tests/sql/window.rs | 56 ++++----- datafusion/expr/src/type_coercion.rs | 10 ++ datafusion/optimizer/src/type_coercion.rs | 119 ++++++++++++++++++- datafusion/sql/src/utils.rs | 10 ++ 5 files changed, 183 insertions(+), 117 deletions(-) diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 148fb334c4e7..83ae71d66a18 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -56,9 +56,9 @@ use crate::{ physical_plan::displayable, }; use arrow::compute::SortOptions; -use arrow::datatypes::{DataType, Schema, SchemaRef}; +use arrow::datatypes::{Schema, SchemaRef}; use async_trait::async_trait; -use datafusion_common::{parse_interval, DFSchema, ScalarValue}; +use datafusion_common::{DFSchema, ScalarValue}; use datafusion_expr::expr::{Between, BinaryExpr, GetIndexedField, GroupingSet, Like}; use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::utils::{expand_wildcard, expr_to_columns}; @@ -550,6 +550,16 @@ impl DefaultPhysicalPlanner { ref order_by, .. } => generate_sort_key(partition_by, order_by), + Expr::Alias(expr, _) => { + // Convert &Box to &T + match &**expr { + Expr::WindowFunction { + ref partition_by, + ref order_by, + ..} => generate_sort_key(partition_by, order_by), + _ => unreachable!(), + } + } _ => unreachable!(), }; let sort_keys = get_sort_keys(&window_expr[0]); @@ -1367,49 +1377,6 @@ fn get_physical_expr_pair( let physical_name = physical_name(expr)?; Ok((physical_expr, physical_name)) } -/// Casts the ScalarValue `value` to column type once we have schema information -/// The resulting type is not necessarily same type with the `column_type`. For instance -/// if `column_type` is Timestamp the result is casted to Interval type. The reason is that -/// Operation between Timestamps is not meaningful, However operation between Timestamp and -/// Interval is valid. For basic types `column_type` is indeed the resulting type. -fn convert_to_column_type( - column_type: &arrow::datatypes::DataType, - value: &ScalarValue, -) -> Result { - match value { - // In here we can either get ScalarValue::Utf8(None) or - // ScalarValue::Utf8(Some(val)). The reason is that we convert the sqlparser result - // to the Utf8 for all possible cases, since we have no schema information during conversion. - // Here we have schema information, hence we can cast the appropriate ScalarValue Type. - ScalarValue::Utf8(None) => ScalarValue::try_from(column_type), - ScalarValue::Utf8(Some(val)) => { - if let DataType::Timestamp(..) = column_type { - parse_interval("millisecond", val) - } else { - ScalarValue::try_from_string(val.clone(), column_type) - } - } - s => Err(DataFusionError::Internal(format!( - "Unexpected value: {:?}", - s - ))), - } -} - -fn convert_frame_bound_to_column_type( - column_type: &arrow::datatypes::DataType, - bound: &WindowFrameBound, -) -> Result { - Ok(match bound { - WindowFrameBound::Preceding(val) => { - WindowFrameBound::Preceding(convert_to_column_type(column_type, val)?) - } - WindowFrameBound::CurrentRow => WindowFrameBound::CurrentRow, - WindowFrameBound::Following(val) => { - WindowFrameBound::Following(convert_to_column_type(column_type, val)?) - } - }) -} /// Check if window bounds are valid after schema information is available, and /// window_frame bounds are casted to the corresponding column type. @@ -1492,46 +1459,12 @@ pub fn create_window_expr_with_name( )), }) .collect::>>()?; - let mut window_frame = window_frame.clone(); - if let Some(ref mut window_frame) = window_frame { - match window_frame.units { - WindowFrameUnits::Groups => { - return Err(DataFusionError::NotImplemented( - "Window frame definitions involving GROUPS are not supported yet" + if let Some(ref window_frame) = window_frame { + if window_frame.units == WindowFrameUnits::Groups { + return Err(DataFusionError::NotImplemented( + "Window frame definitions involving GROUPS are not supported yet" .to_string(), - )); - } - WindowFrameUnits::Range => { - let column_type = order_by - .first() - .ok_or_else(|| { - DataFusionError::Internal( - "ORDER BY column cannot be empty".to_string(), - ) - })? - .expr - .data_type(physical_input_schema)?; - window_frame.start_bound = convert_frame_bound_to_column_type( - &column_type, - &window_frame.start_bound, - )?; - window_frame.end_bound = convert_frame_bound_to_column_type( - &column_type, - &window_frame.end_bound, - )?; - } - WindowFrameUnits::Rows => { - // ROWS should have type usize which is Uint64 for our case - let column_type = arrow::datatypes::DataType::UInt64; - window_frame.start_bound = convert_frame_bound_to_column_type( - &column_type, - &window_frame.start_bound, - )?; - window_frame.end_bound = convert_frame_bound_to_column_type( - &column_type, - &window_frame.end_bound, - )?; - } + )); } if !is_window_valid(window_frame) { return Err(DataFusionError::Execution(format!( @@ -1540,14 +1473,14 @@ pub fn create_window_expr_with_name( ))); } } - + let window_frame = window_frame.clone().map(Arc::new); windows::create_window_expr( fun, name, &args, &partition_by, &order_by, - window_frame.map(Arc::new), + window_frame, physical_input_schema, ) } diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index 630e0ac5ceda..d9ede9771858 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -978,22 +978,22 @@ async fn window_frame_ranges_unbounded_preceding_following() -> Result<()> { let ctx = SessionContext::new(); register_aggregate_csv(&ctx).await?; let sql = "SELECT \ - SUM(c2) OVER (ORDER BY c2 RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING), \ - COUNT(*) OVER (ORDER BY c2 RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) \ + SUM(c2) OVER (ORDER BY c2 RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) as sum1, \ + COUNT(*) OVER (ORDER BY c2 RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) as cnt1 \ FROM aggregate_test_100 \ ORDER BY c9 \ LIMIT 5"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+----------------------------+-----------------+", - "| SUM(aggregate_test_100.c2) | COUNT(UInt8(1)) |", - "+----------------------------+-----------------+", - "| 285 | 100 |", - "| 123 | 63 |", - "| 285 | 100 |", - "| 123 | 63 |", - "| 123 | 63 |", - "+----------------------------+-----------------+", + "+------+------+", + "| sum1 | cnt1 |", + "+------+------+", + "| 285 | 100 |", + "| 123 | 63 |", + "| 285 | 100 |", + "| 123 | 63 |", + "| 123 | 63 |", + "+------+------+", ]; assert_batches_eq!(expected, &actual); Ok(()) @@ -1138,9 +1138,9 @@ async fn window_frame_ranges_timestamp() -> Result<()> { .sql( "SELECT ts, - COUNT(*) OVER (ORDER BY ts RANGE BETWEEN INTERVAL '1' DAY PRECEDING AND INTERVAL '2 DAY' FOLLOWING), - COUNT(*) OVER (ORDER BY ts RANGE BETWEEN '0 DAY' PRECEDING AND '0' DAY FOLLOWING), - COUNT(*) OVER (ORDER BY ts RANGE BETWEEN '5' SECOND PRECEDING AND CURRENT ROW) + COUNT(*) OVER (ORDER BY ts RANGE BETWEEN INTERVAL '1' DAY PRECEDING AND INTERVAL '2 DAY' FOLLOWING) AS cnt1, + COUNT(*) OVER (ORDER BY ts RANGE BETWEEN '0 DAY' PRECEDING AND '0' DAY FOLLOWING) as cnt2, + COUNT(*) OVER (ORDER BY ts RANGE BETWEEN '5' SECOND PRECEDING AND CURRENT ROW) as cnt3 FROM t ORDER BY ts" ) @@ -1148,18 +1148,18 @@ async fn window_frame_ranges_timestamp() -> Result<()> { let actual = df.collect().await?; let expected = vec![ - "+---------------------+-----------------+-----------------+-----------------+", - "| ts | COUNT(UInt8(1)) | COUNT(UInt8(1)) | COUNT(UInt8(1)) |", - "+---------------------+-----------------+-----------------+-----------------+", - "| 2022-09-27 07:43:11 | 6 | 1 | 1 |", - "| 2022-09-27 07:43:12 | 6 | 2 | 3 |", - "| 2022-09-27 07:43:12 | 6 | 2 | 3 |", - "| 2022-09-27 07:43:13 | 6 | 1 | 4 |", - "| 2022-09-27 07:43:14 | 6 | 1 | 5 |", - "| 2022-09-28 11:29:54 | 2 | 1 | 1 |", - "| 2022-09-29 15:16:34 | 2 | 1 | 1 |", - "| 2022-09-30 19:03:14 | 1 | 1 | 1 |", - "+---------------------+-----------------+-----------------+-----------------+", + "+---------------------+------+------+------+", + "| ts | cnt1 | cnt2 | cnt3 |", + "+---------------------+------+------+------+", + "| 2022-09-27 07:43:11 | 6 | 1 | 1 |", + "| 2022-09-27 07:43:12 | 6 | 2 | 3 |", + "| 2022-09-27 07:43:12 | 6 | 2 | 3 |", + "| 2022-09-27 07:43:13 | 6 | 1 | 4 |", + "| 2022-09-27 07:43:14 | 6 | 1 | 5 |", + "| 2022-09-28 11:29:54 | 2 | 1 | 1 |", + "| 2022-09-29 15:16:34 | 2 | 1 | 1 |", + "| 2022-09-30 19:03:14 | 1 | 1 | 1 |", + "+---------------------+------+------+------+", ]; assert_batches_eq!(expected, &actual); Ok(()) @@ -1249,9 +1249,9 @@ async fn window_frame_creation() -> Result<()> { ) .await?; let results = df.collect().await; - assert_eq!( + assert_contains!( results.err().unwrap().to_string(), - "Arrow error: Cast error: Cannot cast string '1 DAY' to value of UInt32 type" + "Arrow error: External error: Internal error: Operator - is not implemented for types UInt32(1) and Utf8(\"1 DAY\")" ); Ok(()) diff --git a/datafusion/expr/src/type_coercion.rs b/datafusion/expr/src/type_coercion.rs index 4a006ad87a27..02702b113ba3 100644 --- a/datafusion/expr/src/type_coercion.rs +++ b/datafusion/expr/src/type_coercion.rs @@ -62,6 +62,16 @@ pub fn is_numeric(dt: &DataType) -> bool { ) } +/// Determine if a DataType is Timestamp or not +pub fn is_timestamp(dt: &DataType) -> bool { + matches!(dt, DataType::Timestamp(_, _)) +} + +/// Determine if a DataType is Date or not +pub fn is_date(dt: &DataType) -> bool { + matches!(dt, DataType::Date32 | DataType::Date64) +} + pub mod aggregates; pub mod binary; pub mod functions; diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index 2833eee048b4..ae1327ed1888 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -19,8 +19,10 @@ use crate::utils::rewrite_preserving_name; use crate::{OptimizerConfig, OptimizerRule}; -use arrow::datatypes::DataType; -use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result}; +use arrow::datatypes::{DataType, IntervalUnit}; +use datafusion_common::{ + parse_interval, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::expr::{Between, BinaryExpr, Case, Like}; use datafusion_expr::expr_rewriter::{ExprRewriter, RewriteRecursion}; use datafusion_expr::logical_plan::Subquery; @@ -29,10 +31,12 @@ use datafusion_expr::type_coercion::functions::data_types; use datafusion_expr::type_coercion::other::{ get_coerce_type_for_case_when, get_coerce_type_for_list, }; +use datafusion_expr::type_coercion::{is_date, is_numeric, is_timestamp}; use datafusion_expr::utils::from_plan; use datafusion_expr::{ aggregate_function, function, is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, type_coercion, AggregateFunction, Expr, LogicalPlan, Operator, + WindowFrame, WindowFrameBound, WindowFrameUnits, }; use datafusion_expr::{ExprSchemable, Signature}; use std::sync::Arc; @@ -72,7 +76,6 @@ fn optimize_internal( .iter() .map(|p| optimize_internal(external_schema, p, optimizer_config)) .collect::>>()?; - // get schema representing all available input fields. This is used for data type // resolution only, so order does not matter here let mut schema = new_inputs.iter().map(|input| input.schema()).fold( @@ -410,11 +413,121 @@ impl ExprRewriter for TypeCoercionRewriter { }; Ok(expr) } + Expr::WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + } => { + let window_frame = + get_coerced_window_frame(window_frame, &self.schema, &order_by)?; + let expr = Expr::WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + }; + Ok(expr) + } expr => Ok(expr), } } } +/// Casts the ScalarValue `value` to coerced type. +// When coerced type is `Interval` we use `parse_interval` since `try_from_string` not +// supports conversion from string to Interval +fn convert_to_coerced_type( + coerced_type: &DataType, + value: &ScalarValue, +) -> Result { + match value { + // In here we do casting either for ScalarValue::Utf8(None) or + // ScalarValue::Utf8(Some(val)). The other types are already casted. + // The reason is that we convert the sqlparser result + // to the Utf8 for all possible cases. Hence the types other than Utf8 + // are already casted to appropriate type. Therefore they can be returned directly. + ScalarValue::Utf8(None) => ScalarValue::try_from(coerced_type), + ScalarValue::Utf8(Some(val)) => { + // we need special handling for Interval types + if let DataType::Interval(..) = coerced_type { + parse_interval("millisecond", val) + } else { + ScalarValue::try_from_string(val.clone(), coerced_type) + } + } + s => Ok(s.clone()), + } +} + +fn coerce_frame_bound( + coerced_type: &DataType, + bound: &WindowFrameBound, +) -> Result { + Ok(match bound { + WindowFrameBound::Preceding(val) => { + WindowFrameBound::Preceding(convert_to_coerced_type(coerced_type, val)?) + } + WindowFrameBound::CurrentRow => WindowFrameBound::CurrentRow, + WindowFrameBound::Following(val) => { + WindowFrameBound::Following(convert_to_coerced_type(coerced_type, val)?) + } + }) +} + +fn get_coerced_window_frame( + window_frame: Option, + schema: &DFSchemaRef, + expressions: &[Expr], +) -> Result> { + fn get_coerced_type(column_type: &DataType) -> Result { + if is_numeric(column_type) { + Ok(column_type.clone()) + } else if is_timestamp(column_type) || is_date(column_type) { + Ok(DataType::Interval(IntervalUnit::MonthDayNano)) + } else { + Err(DataFusionError::Internal(format!( + "Cannot run range queries on datatype: {:?}", + column_type + ))) + } + } + + if let Some(window_frame) = window_frame { + let mut window_frame = window_frame; + let current_types = expressions + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + match &mut window_frame.units { + WindowFrameUnits::Range => { + let col_type = current_types.first().ok_or_else(|| { + DataFusionError::Internal( + "ORDER BY column cannot be empty".to_string(), + ) + })?; + let coerced_type = get_coerced_type(col_type)?; + window_frame.start_bound = + coerce_frame_bound(&coerced_type, &window_frame.start_bound)?; + window_frame.end_bound = + coerce_frame_bound(&coerced_type, &window_frame.end_bound)?; + } + WindowFrameUnits::Rows | WindowFrameUnits::Groups => { + let coerced_type = DataType::UInt64; + window_frame.start_bound = + coerce_frame_bound(&coerced_type, &window_frame.start_bound)?; + window_frame.end_bound = + coerce_frame_bound(&coerced_type, &window_frame.end_bound)?; + } + } + + Ok(Some(window_frame)) + } else { + Ok(None) + } +} // Support the `IsTrue` `IsNotTrue` `IsFalse` `IsNotFalse` type coercion. // The above op will be rewrite to the binary op when creating the physical op. fn get_casted_expr_for_bool_op(expr: &Expr, schema: &DFSchemaRef) -> Result { diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 9413754e9355..550c5ee42b2d 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -480,6 +480,16 @@ pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr .iter() .map(|expr| match expr { Expr::WindowFunction { partition_by, .. } => Ok(partition_by), + Expr::Alias(expr, _) => { + // convert &Box to &T + match &**expr { + Expr::WindowFunction { partition_by, .. } => Ok(partition_by), + expr => Err(DataFusionError::Execution(format!( + "Impossibly got non-window expr {:?}", + expr + ))), + } + } expr => Err(DataFusionError::Execution(format!( "Impossibly got non-window expr {:?}", expr