From 2f60903e716dd33efaef30c24e5fed6fc76b97a9 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Thu, 10 Jun 2021 11:41:17 +0800 Subject: [PATCH] add boundary check --- .../core/src/serde/logical_plan/from_proto.rs | 10 +- datafusion/src/logical_plan/window_frames.rs | 35 ++ datafusion/src/physical_plan/mod.rs | 1 - datafusion/src/physical_plan/window_frames.rs | 337 ------------------ datafusion/src/sql/planner.rs | 58 ++- 5 files changed, 95 insertions(+), 346 deletions(-) delete mode 100644 datafusion/src/physical_plan/window_frames.rs diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 86daeb063c47..894a5f0a7d98 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -927,10 +927,18 @@ impl TryInto for &protobuf::LogicalExprNode { .as_ref() .map::, _>(|e| match e { window_expr_node::WindowFrame::Frame(frame) => { - frame.clone().try_into() + let window_frame: WindowFrame = frame.clone().try_into()?; + if WindowFrameUnits::Range == window_frame.units + && order_by.len() != 1 + { + Err(proto_error("With window frame of type RANGE, the order by expression must be of length 1")) + } else { + Ok(window_frame) + } } }) .transpose()?; + match window_function { window_expr_node::WindowFunction::AggrFunction(i) => { let aggr_function = protobuf::AggregateFunction::from_i32(*i) diff --git a/datafusion/src/logical_plan/window_frames.rs b/datafusion/src/logical_plan/window_frames.rs index f0be5a221fbf..8aaebd3155c1 100644 --- a/datafusion/src/logical_plan/window_frames.rs +++ b/datafusion/src/logical_plan/window_frames.rs @@ -82,6 +82,22 @@ impl TryFrom for WindowFrame { ))) } else { let units = value.units.into(); + if units == WindowFrameUnits::Range { + for bound in &[start_bound, end_bound] { + match bound { + WindowFrameBound::Preceding(Some(v)) + | WindowFrameBound::Following(Some(v)) + if *v > 0 => + { + Err(DataFusionError::NotImplemented(format!( + "With WindowFrameUnits={}, the bound cannot be {} PRECEDING or FOLLOWING at the moment", + units, v + ))) + } + _ => Ok(()), + }?; + } + } Ok(Self { units, start_bound, @@ -270,6 +286,25 @@ mod tests { result.err().unwrap().to_string(), "Execution error: Invalid window frame: start bound (1 PRECEDING) cannot be larger than end bound (2 PRECEDING)".to_owned() ); + + let window_frame = ast::WindowFrame { + units: ast::WindowFrameUnits::Range, + start_bound: ast::WindowFrameBound::Preceding(Some(2)), + end_bound: Some(ast::WindowFrameBound::Preceding(Some(1))), + }; + let result = WindowFrame::try_from(window_frame); + assert_eq!( + result.err().unwrap().to_string(), + "This feature is not implemented: With WindowFrameUnits=RANGE, the bound cannot be 2 PRECEDING or FOLLOWING at the moment".to_owned() + ); + + let window_frame = ast::WindowFrame { + units: ast::WindowFrameUnits::Rows, + start_bound: ast::WindowFrameBound::Preceding(Some(2)), + end_bound: Some(ast::WindowFrameBound::Preceding(Some(1))), + }; + let result = WindowFrame::try_from(window_frame); + assert!(result.is_ok()); Ok(()) } diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index 490e02875c42..af6969c43cbd 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -617,6 +617,5 @@ pub mod udf; #[cfg(feature = "unicode_expressions")] pub mod unicode_expressions; pub mod union; -pub mod window_frames; pub mod window_functions; pub mod windows; diff --git a/datafusion/src/physical_plan/window_frames.rs b/datafusion/src/physical_plan/window_frames.rs deleted file mode 100644 index f0be5a221fbf..000000000000 --- a/datafusion/src/physical_plan/window_frames.rs +++ /dev/null @@ -1,337 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Window frame -//! -//! The frame-spec determines which output rows are read by an aggregate window function. The frame-spec consists of four parts: -//! - A frame type - either ROWS, RANGE or GROUPS, -//! - A starting frame boundary, -//! - An ending frame boundary, -//! - An EXCLUDE clause. - -use crate::error::{DataFusionError, Result}; -use sqlparser::ast; -use std::cmp::Ordering; -use std::convert::{From, TryFrom}; -use std::fmt; - -/// The frame-spec determines which output rows are read by an aggregate window function. -/// -/// The ending frame boundary can be omitted (if the BETWEEN and AND keywords that surround the -/// starting frame boundary are also omitted), in which case the ending frame boundary defaults to -/// CURRENT ROW. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct WindowFrame { - /// A frame type - either ROWS, RANGE or GROUPS - pub units: WindowFrameUnits, - /// A starting frame boundary - pub start_bound: WindowFrameBound, - /// An ending frame boundary - pub end_bound: WindowFrameBound, -} - -impl fmt::Display for WindowFrame { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "{} BETWEEN {} AND {}", - self.units, self.start_bound, self.end_bound - )?; - Ok(()) - } -} - -impl TryFrom for WindowFrame { - type Error = DataFusionError; - - fn try_from(value: ast::WindowFrame) -> Result { - let start_bound = value.start_bound.into(); - let end_bound = value - .end_bound - .map(WindowFrameBound::from) - .unwrap_or(WindowFrameBound::CurrentRow); - - if let WindowFrameBound::Following(None) = start_bound { - Err(DataFusionError::Execution( - "Invalid window frame: start bound cannot be unbounded following" - .to_owned(), - )) - } else if let WindowFrameBound::Preceding(None) = end_bound { - Err(DataFusionError::Execution( - "Invalid window frame: end bound cannot be unbounded preceding" - .to_owned(), - )) - } else if start_bound > end_bound { - Err(DataFusionError::Execution(format!( - "Invalid window frame: start bound ({}) cannot be larger than end bound ({})", - start_bound, end_bound - ))) - } else { - let units = value.units.into(); - Ok(Self { - units, - start_bound, - end_bound, - }) - } - } -} - -impl Default for WindowFrame { - fn default() -> Self { - WindowFrame { - units: WindowFrameUnits::Range, - start_bound: WindowFrameBound::Preceding(None), - end_bound: WindowFrameBound::CurrentRow, - } - } -} - -/// There are five ways to describe starting and ending frame boundaries: -/// -/// 1. UNBOUNDED PRECEDING -/// 2. PRECEDING -/// 3. CURRENT ROW -/// 4. FOLLOWING -/// 5. UNBOUNDED FOLLOWING -/// -/// in this implementation we'll only allow to be u64 (i.e. no dynamic boundary) -#[derive(Debug, Clone, Copy, Eq)] -pub enum WindowFrameBound { - /// 1. UNBOUNDED PRECEDING - /// The frame boundary is the first row in the partition. - /// - /// 2. PRECEDING - /// must be a non-negative constant numeric expression. The boundary is a row that - /// is "units" prior to the current row. - Preceding(Option), - /// 3. The current row. - /// - /// For RANGE and GROUPS frame types, peers of the current row are also - /// included in the frame, unless specifically excluded by the EXCLUDE clause. - /// This is true regardless of whether CURRENT ROW is used as the starting or ending frame - /// boundary. - CurrentRow, - /// 4. This is the same as " PRECEDING" except that the boundary is units after the - /// current rather than before the current row. - /// - /// 5. UNBOUNDED FOLLOWING - /// The frame boundary is the last row in the partition. - Following(Option), -} - -impl From for WindowFrameBound { - fn from(value: ast::WindowFrameBound) -> Self { - match value { - ast::WindowFrameBound::Preceding(v) => Self::Preceding(v), - ast::WindowFrameBound::Following(v) => Self::Following(v), - ast::WindowFrameBound::CurrentRow => Self::CurrentRow, - } - } -} - -impl fmt::Display for WindowFrameBound { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - WindowFrameBound::CurrentRow => f.write_str("CURRENT ROW"), - WindowFrameBound::Preceding(None) => f.write_str("UNBOUNDED PRECEDING"), - WindowFrameBound::Following(None) => f.write_str("UNBOUNDED FOLLOWING"), - WindowFrameBound::Preceding(Some(n)) => write!(f, "{} PRECEDING", n), - WindowFrameBound::Following(Some(n)) => write!(f, "{} FOLLOWING", n), - } - } -} - -impl PartialEq for WindowFrameBound { - fn eq(&self, other: &Self) -> bool { - self.cmp(other) == Ordering::Equal - } -} - -impl PartialOrd for WindowFrameBound { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for WindowFrameBound { - fn cmp(&self, other: &Self) -> Ordering { - self.get_rank().cmp(&other.get_rank()) - } -} - -impl WindowFrameBound { - /// get the rank of this window frame bound. - /// - /// the rank is a tuple of (u8, u64) because we'll firstly compare the kind and then the value - /// which requires special handling e.g. with preceding the larger the value the smaller the - /// rank and also for 0 preceding / following it is the same as current row - fn get_rank(&self) -> (u8, u64) { - match self { - WindowFrameBound::Preceding(None) => (0, 0), - WindowFrameBound::Following(None) => (4, 0), - WindowFrameBound::Preceding(Some(0)) - | WindowFrameBound::CurrentRow - | WindowFrameBound::Following(Some(0)) => (2, 0), - WindowFrameBound::Preceding(Some(v)) => (1, u64::MAX - *v), - WindowFrameBound::Following(Some(v)) => (3, *v), - } - } -} - -/// There are three frame types: ROWS, GROUPS, and RANGE. The frame type determines how the -/// starting and ending boundaries of the frame are measured. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum WindowFrameUnits { - /// The ROWS frame type means that the starting and ending boundaries for the frame are - /// determined by counting individual rows relative to the current row. - Rows, - /// The RANGE frame type requires that the ORDER BY clause of the window have exactly one - /// term. Call that term "X". With the RANGE frame type, the elements of the frame are - /// determined by computing the value of expression X for all rows in the partition and framing - /// those rows for which the value of X is within a certain range of the value of X for the - /// current row. - Range, - /// The GROUPS frame type means that the starting and ending boundaries are determine - /// by counting "groups" relative to the current group. A "group" is a set of rows that all have - /// equivalent values for all all terms of the window ORDER BY clause. - Groups, -} - -impl fmt::Display for WindowFrameUnits { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.write_str(match self { - WindowFrameUnits::Rows => "ROWS", - WindowFrameUnits::Range => "RANGE", - WindowFrameUnits::Groups => "GROUPS", - }) - } -} - -impl From for WindowFrameUnits { - fn from(value: ast::WindowFrameUnits) -> Self { - match value { - ast::WindowFrameUnits::Range => Self::Range, - ast::WindowFrameUnits::Groups => Self::Groups, - ast::WindowFrameUnits::Rows => Self::Rows, - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_window_frame_creation() -> Result<()> { - let window_frame = ast::WindowFrame { - units: ast::WindowFrameUnits::Range, - start_bound: ast::WindowFrameBound::Following(None), - end_bound: None, - }; - let result = WindowFrame::try_from(window_frame); - assert_eq!( - result.err().unwrap().to_string(), - "Execution error: Invalid window frame: start bound cannot be unbounded following".to_owned() - ); - - let window_frame = ast::WindowFrame { - units: ast::WindowFrameUnits::Range, - start_bound: ast::WindowFrameBound::Preceding(None), - end_bound: Some(ast::WindowFrameBound::Preceding(None)), - }; - let result = WindowFrame::try_from(window_frame); - assert_eq!( - result.err().unwrap().to_string(), - "Execution error: Invalid window frame: end bound cannot be unbounded preceding".to_owned() - ); - - let window_frame = ast::WindowFrame { - units: ast::WindowFrameUnits::Range, - start_bound: ast::WindowFrameBound::Preceding(Some(1)), - end_bound: Some(ast::WindowFrameBound::Preceding(Some(2))), - }; - let result = WindowFrame::try_from(window_frame); - assert_eq!( - result.err().unwrap().to_string(), - "Execution error: Invalid window frame: start bound (1 PRECEDING) cannot be larger than end bound (2 PRECEDING)".to_owned() - ); - Ok(()) - } - - #[test] - fn test_eq() { - assert_eq!( - WindowFrameBound::Preceding(Some(0)), - WindowFrameBound::CurrentRow - ); - assert_eq!( - WindowFrameBound::CurrentRow, - WindowFrameBound::Following(Some(0)) - ); - assert_eq!( - WindowFrameBound::Following(Some(2)), - WindowFrameBound::Following(Some(2)) - ); - assert_eq!( - WindowFrameBound::Following(None), - WindowFrameBound::Following(None) - ); - assert_eq!( - WindowFrameBound::Preceding(Some(2)), - WindowFrameBound::Preceding(Some(2)) - ); - assert_eq!( - WindowFrameBound::Preceding(None), - WindowFrameBound::Preceding(None) - ); - } - - #[test] - fn test_ord() { - assert!(WindowFrameBound::Preceding(Some(1)) < WindowFrameBound::CurrentRow); - // ! yes this is correct! - assert!( - WindowFrameBound::Preceding(Some(2)) < WindowFrameBound::Preceding(Some(1)) - ); - assert!( - WindowFrameBound::Preceding(Some(u64::MAX)) - < WindowFrameBound::Preceding(Some(u64::MAX - 1)) - ); - assert!( - WindowFrameBound::Preceding(None) - < WindowFrameBound::Preceding(Some(1000000)) - ); - assert!( - WindowFrameBound::Preceding(None) - < WindowFrameBound::Preceding(Some(u64::MAX)) - ); - assert!(WindowFrameBound::Preceding(None) < WindowFrameBound::Following(Some(0))); - assert!( - WindowFrameBound::Preceding(Some(1)) < WindowFrameBound::Following(Some(1)) - ); - assert!(WindowFrameBound::CurrentRow < WindowFrameBound::Following(Some(1))); - assert!( - WindowFrameBound::Following(Some(1)) < WindowFrameBound::Following(Some(2)) - ); - assert!(WindowFrameBound::Following(Some(2)) < WindowFrameBound::Following(None)); - assert!( - WindowFrameBound::Following(Some(u64::MAX)) - < WindowFrameBound::Following(None) - ); - } -} diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 53f22ecaf3f2..c128634091a0 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -19,6 +19,7 @@ use crate::catalog::TableReference; use crate::datasource::TableProvider; +use crate::logical_plan::window_frames::{WindowFrame, WindowFrameUnits}; use crate::logical_plan::Expr::Alias; use crate::logical_plan::{ and, lit, DFSchema, Expr, LogicalPlan, LogicalPlanBuilder, Operator, PlanType, @@ -1137,7 +1138,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let window_frame = window .window_frame .as_ref() - .map(|window_frame| window_frame.clone().try_into()) + .map(|window_frame| { + let window_frame: WindowFrame = window_frame.clone().try_into()?; + if WindowFrameUnits::Range == window_frame.units + && order_by.len() != 1 + { + Err(DataFusionError::Plan(format!( + "With window frame of type RANGE, the order by expression must be of length 1, got {}", order_by.len()))) + } else { + Ok(window_frame) + } + + }) .transpose()?; let fun = window_functions::WindowFunction::from_str(&name)?; match fun { @@ -2859,10 +2871,10 @@ mod tests { #[test] fn over_order_by_with_window_frame_double_end() { - let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id RANGE BETWEEN 3 PRECEDING and 3 FOLLOWING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; + let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id ROWS BETWEEN 3 PRECEDING and 3 FOLLOWING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; let expected = "\ - Projection: #order_id, #MAX(qty) RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING, #MIN(qty)\ - \n WindowAggr: windowExpr=[[MAX(#qty) RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING]]\ + Projection: #order_id, #MAX(qty) ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING, #MIN(qty)\ + \n WindowAggr: windowExpr=[[MAX(#qty) ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING]]\ \n Sort: #order_id ASC NULLS FIRST\ \n WindowAggr: windowExpr=[[MIN(#qty)]]\ \n Sort: #order_id DESC NULLS FIRST\ @@ -2872,10 +2884,10 @@ mod tests { #[test] fn over_order_by_with_window_frame_single_end() { - let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id RANGE 3 PRECEDING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; + let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id ROWS 3 PRECEDING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; let expected = "\ - Projection: #order_id, #MAX(qty) RANGE BETWEEN 3 PRECEDING AND CURRENT ROW, #MIN(qty)\ - \n WindowAggr: windowExpr=[[MAX(#qty) RANGE BETWEEN 3 PRECEDING AND CURRENT ROW]]\ + Projection: #order_id, #MAX(qty) ROWS BETWEEN 3 PRECEDING AND CURRENT ROW, #MIN(qty)\ + \n WindowAggr: windowExpr=[[MAX(#qty) ROWS BETWEEN 3 PRECEDING AND CURRENT ROW]]\ \n Sort: #order_id ASC NULLS FIRST\ \n WindowAggr: windowExpr=[[MIN(#qty)]]\ \n Sort: #order_id DESC NULLS FIRST\ @@ -2883,6 +2895,38 @@ mod tests { quick_test(sql, expected); } + #[test] + fn over_order_by_with_window_frame_range_value_check() { + let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id RANGE 3 PRECEDING) from orders"; + let err = logical_plan(sql).expect_err("query should have failed"); + assert_eq!( + "NotImplemented(\"With WindowFrameUnits=RANGE, the bound cannot be 3 PRECEDING or FOLLOWING at the moment\")", + format!("{:?}", err) + ); + } + + #[test] + fn over_order_by_with_window_frame_range_order_by_check() { + let sql = + "SELECT order_id, MAX(qty) OVER (RANGE UNBOUNDED PRECEDING) from orders"; + let err = logical_plan(sql).expect_err("query should have failed"); + assert_eq!( + "Plan(\"With window frame of type RANGE, the order by expression must be of length 1, got 0\")", + format!("{:?}", err) + ); + } + + #[test] + fn over_order_by_with_window_frame_range_order_by_check_2() { + let sql = + "SELECT order_id, MAX(qty) OVER (ORDER BY order_id, qty RANGE UNBOUNDED PRECEDING) from orders"; + let err = logical_plan(sql).expect_err("query should have failed"); + assert_eq!( + "Plan(\"With window frame of type RANGE, the order by expression must be of length 1, got 2\")", + format!("{:?}", err) + ); + } + #[test] fn over_order_by_with_window_frame_single_end_groups() { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id GROUPS 3 PRECEDING), MIN(qty) OVER (ORDER BY order_id DESC) from orders";