diff --git a/datafusion/core/src/physical_plan/windows/mod.rs b/datafusion/core/src/physical_plan/windows/mod.rs index 95582b2119de..bece2a50ce47 100644 --- a/datafusion/core/src/physical_plan/windows/mod.rs +++ b/datafusion/core/src/physical_plan/windows/mod.rs @@ -65,6 +65,7 @@ pub fn create_window_expr( create_built_in_window_expr(fun, args, input_schema, name)?, partition_by, order_by, + window_frame, )), }) } diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index 48f5a08dd55c..0535c09ac5cb 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -1276,3 +1276,185 @@ async fn window_frame_creation() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_window_row_number_aggregate() -> Result<()> { + let config = SessionConfig::new(); + let ctx = SessionContext::with_config(config); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT + c8, + ROW_NUMBER() OVER(ORDER BY c9) AS rn1, + ROW_NUMBER() OVER(ORDER BY c9 ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as rn2 + FROM aggregate_test_100 + ORDER BY c8 + LIMIT 5"; + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-----+-----+-----+", + "| c8 | rn1 | rn2 |", + "+-----+-----+-----+", + "| 102 | 73 | 73 |", + "| 299 | 1 | 1 |", + "| 363 | 41 | 41 |", + "| 417 | 14 | 14 |", + "| 794 | 95 | 95 |", + "+-----+-----+-----+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn test_window_cume_dist() -> Result<()> { + let config = SessionConfig::new(); + let ctx = SessionContext::with_config(config); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT + c8, + CUME_DIST() OVER(ORDER BY c9) as cd1, + CUME_DIST() OVER(ORDER BY c9 ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as cd2 + FROM aggregate_test_100 + ORDER BY c8 + LIMIT 5"; + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-----+------+------+", + "| c8 | cd1 | cd2 |", + 
"+-----+------+------+", + "| 102 | 0.73 | 0.73 |", + "| 299 | 0.01 | 0.01 |", + "| 363 | 0.41 | 0.41 |", + "| 417 | 0.14 | 0.14 |", + "| 794 | 0.95 | 0.95 |", + "+-----+------+------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn test_window_rank() -> Result<()> { + let config = SessionConfig::new(); + let ctx = SessionContext::with_config(config); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT + c9, + RANK() OVER(ORDER BY c1) AS rank1, + RANK() OVER(ORDER BY c1 ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as rank2, + DENSE_RANK() OVER(ORDER BY c1) as dense_rank1, + DENSE_RANK() OVER(ORDER BY c1 ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as dense_rank2, + PERCENT_RANK() OVER(ORDER BY c1) as percent_rank1, + PERCENT_RANK() OVER(ORDER BY c1 ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as percent_rank2 + FROM aggregate_test_100 + ORDER BY c9 + LIMIT 5"; + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-----------+-------+-------+-------------+-------------+---------------------+---------------------+", + "| c9 | rank1 | rank2 | dense_rank1 | dense_rank2 | percent_rank1 | percent_rank2 |", + "+-----------+-------+-------+-------------+-------------+---------------------+---------------------+", + "| 28774375 | 80 | 80 | 5 | 5 | 0.797979797979798 | 0.797979797979798 |", + "| 63044568 | 62 | 62 | 4 | 4 | 0.6161616161616161 | 0.6161616161616161 |", + "| 141047417 | 1 | 1 | 1 | 1 | 0 | 0 |", + "| 141680161 | 41 | 41 | 3 | 3 | 0.40404040404040403 | 0.40404040404040403 |", + "| 145294611 | 1 | 1 | 1 | 1 | 0 | 0 |", + "+-----------+-------+-------+-------------+-------------+---------------------+---------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn test_lag_lead() -> Result<()> { + let config = SessionConfig::new(); + let ctx = SessionContext::with_config(config); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT + c9, + LAG(c9, 
2, 10101) OVER(ORDER BY c9) as lag1, + LAG(c9, 2, 10101) OVER(ORDER BY c9 ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as lag2, + LEAD(c9, 2, 10101) OVER(ORDER BY c9) as lead1, + LEAD(c9, 2, 10101) OVER(ORDER BY c9 ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as lead2 + FROM aggregate_test_100 + ORDER BY c9 + LIMIT 5"; + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-----------+-----------+-----------+-----------+-----------+", + "| c9 | lag1 | lag2 | lead1 | lead2 |", + "+-----------+-----------+-----------+-----------+-----------+", + "| 28774375 | 10101 | 10101 | 141047417 | 141047417 |", + "| 63044568 | 10101 | 10101 | 141680161 | 141680161 |", + "| 141047417 | 28774375 | 28774375 | 145294611 | 145294611 |", + "| 141680161 | 63044568 | 63044568 | 225513085 | 225513085 |", + "| 145294611 | 141047417 | 141047417 | 243203849 | 243203849 |", + "+-----------+-----------+-----------+-----------+-----------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn test_window_frame_first_value_last_value_aggregate() -> Result<()> { + let config = SessionConfig::new(); + let ctx = SessionContext::with_config(config); + register_aggregate_csv(&ctx).await?; + + let sql = "SELECT + FIRST_VALUE(c4) OVER(ORDER BY c9 ASC ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING) as first_value1, + FIRST_VALUE(c4) OVER(ORDER BY c9 ASC ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING) as first_value2, + LAST_VALUE(c4) OVER(ORDER BY c9 ASC ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING) as last_value1, + LAST_VALUE(c4) OVER(ORDER BY c9 ASC ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING) as last_value2 + FROM aggregate_test_100 + ORDER BY c9 + LIMIT 5"; + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+--------------+--------------+-------------+-------------+", + "| first_value1 | first_value2 | last_value1 | last_value2 |", + "+--------------+--------------+-------------+-------------+", + "| -16110 | -16110 | 3917 | 
-1114 |", + "| -16110 | -16110 | -16974 | 15673 |", + "| -16110 | -16110 | -1114 | 13630 |", + "| -16110 | 3917 | 15673 | -13217 |", + "| -16110 | -16974 | 13630 | 20690 |", + "+--------------+--------------+-------------+-------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn test_window_frame_nth_value_aggregate() -> Result<()> { + let config = SessionConfig::new(); + let ctx = SessionContext::with_config(config); + register_aggregate_csv(&ctx).await?; + + let sql = "SELECT + NTH_VALUE(c4, 3) OVER(ORDER BY c9 ASC ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING) as nth_value1, + NTH_VALUE(c4, 2) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND 3 FOLLOWING) as nth_value2 + FROM aggregate_test_100 + ORDER BY c9 + LIMIT 5"; + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+------------+------------+", + "| nth_value1 | nth_value2 |", + "+------------+------------+", + "| | 3917 |", + "| -16974 | 3917 |", + "| -16974 | -16974 |", + "| -1114 | -1114 |", + "| 15673 | 15673 |", + "+------------+------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index 80cb4d10ce1a..e6c754387f1c 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -18,21 +18,17 @@ //! Physical exec for aggregate window function expressions. 
use std::any::Any; -use std::cmp::min; use std::iter::IntoIterator; -use std::ops::Range; use std::sync::Arc; use arrow::array::Array; -use arrow::compute::{concat, SortOptions}; +use arrow::compute::SortOptions; use arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; -use datafusion_common::bisect::bisect; use datafusion_common::Result; -use datafusion_common::{DataFusionError, ScalarValue}; -use datafusion_expr::{Accumulator, WindowFrameBound}; -use datafusion_expr::{WindowFrame, WindowFrameUnits}; +use datafusion_common::ScalarValue; +use datafusion_expr::WindowFrame; use crate::{expressions::PhysicalSortExpr, PhysicalExpr}; use crate::{window::WindowExpr, AggregateExpr}; @@ -61,23 +57,6 @@ impl AggregateWindowExpr { window_frame, } } - - /// create a new accumulator based on the underlying aggregation function - fn create_accumulator(&self) -> Result { - let accumulator = self.aggregate.create_accumulator()?; - let window_frame = self.window_frame.clone(); - let partition_by = self.partition_by().to_vec(); - let order_by = self.order_by.to_vec(); - let field = self.aggregate.field()?; - - Ok(AggregateWindowAccumulator { - accumulator, - window_frame, - partition_by, - order_by, - field, - }) - } } /// peer based evaluation based on the fact that batch is pre-sorted given the sort columns @@ -103,368 +82,86 @@ impl WindowExpr for AggregateWindowExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result { - let num_rows = batch.num_rows(); + let partition_columns = self.partition_columns(batch)?; let partition_points = - self.evaluate_partition_points(num_rows, &self.partition_columns(batch)?)?; + self.evaluate_partition_points(batch.num_rows(), &partition_columns)?; let values = self.evaluate_args(batch)?; + let sort_options: Vec = + self.order_by.iter().map(|o| o.options).collect(); let columns = self.sort_columns(batch)?; - let array_refs: Vec<&ArrayRef> = columns.iter().map(|s| &s.values).collect(); + let order_columns: 
Vec<&ArrayRef> = columns.iter().map(|s| &s.values).collect(); // Sort values, this will make the same partitions consecutive. Also, within the partition // range, values will be sorted. - let results = partition_points - .iter() - .map(|partition_range| { - let mut window_accumulators = self.create_accumulator()?; - Ok(vec![window_accumulators.scan( - &values, - &array_refs, - partition_range, - )?]) - }) - .collect::>>>()? - .into_iter() - .flatten() - .collect::>(); - let results = results.iter().map(|i| i.as_ref()).collect::>(); - concat(&results).map_err(DataFusionError::ArrowError) - } - - fn partition_by(&self) -> &[Arc] { - &self.partition_by - } - - fn order_by(&self) -> &[PhysicalSortExpr] { - &self.order_by - } -} - -fn calculate_index_of_row( - range_columns: &[ArrayRef], - sort_options: &[SortOptions], - idx: usize, - delta: Option<&ScalarValue>, -) -> Result { - let current_row_values = range_columns - .iter() - .map(|col| ScalarValue::try_from_array(col, idx)) - .collect::>>()?; - let end_range = if let Some(delta) = delta { - let is_descending: bool = sort_options - .first() - .ok_or_else(|| DataFusionError::Internal("Array is empty".to_string()))? - .descending; - - current_row_values - .iter() - .map(|value| { - if value.is_null() { - return Ok(value.clone()); - } - if SEARCH_SIDE == is_descending { - // TODO: Handle positive overflows - value.add(delta) - } else if value.is_unsigned() && value < delta { - // NOTE: This gets a polymorphic zero without having long coercion code for ScalarValue. - // If we decide to implement a "default" construction mechanism for ScalarValue, - // change the following statement to use that. - value.sub(value) + let order_bys = &order_columns[self.partition_by.len()..]; + let window_frame = if !order_bys.is_empty() && self.window_frame.is_none() { + // OVER (ORDER BY a) case + // We create an implicit window for ORDER BY. 
+ Some(Arc::new(WindowFrame::default())) + } else { + self.window_frame.clone() + }; + let mut row_wise_results: Vec = vec![]; + for partition_range in &partition_points { + let mut accumulator = self.aggregate.create_accumulator()?; + let length = partition_range.end - partition_range.start; + let slice_order_bys = order_bys + .iter() + .map(|v| v.slice(partition_range.start, length)) + .collect::>(); + let value_slice = values + .iter() + .map(|v| v.slice(partition_range.start, length)) + .collect::>(); + + let mut last_range: (usize, usize) = (0, 0); + + // We iterate on each row to perform a running calculation. + // First, cur_range is calculated, then it is compared with last_range. + for i in 0..length { + let cur_range = self.calculate_range( + &window_frame, + &slice_order_bys, + &sort_options, + length, + i, + )?; + let value = if cur_range.0 == cur_range.1 { + // We produce None if the window is empty. + ScalarValue::try_from(self.aggregate.field()?.data_type())? } else { - // TODO: Handle negative overflows - value.sub(delta) - } - }) - .collect::>>()? - } else { - current_row_values - }; - // `BISECT_SIDE` true means bisect_left, false means bisect_right - bisect::(range_columns, &end_range, sort_options) -} - -/// We use start and end bounds to calculate current row's starting and ending range. -/// This function supports different modes, but we currently do not support window calculation for GROUPS inside window frames. 
-fn calculate_current_window( - window_frame: &WindowFrame, - range_columns: &[ArrayRef], - sort_options: &[SortOptions], - length: usize, - idx: usize, -) -> Result<(usize, usize)> { - match window_frame.units { - WindowFrameUnits::Range => { - let start = match &window_frame.start_bound { - WindowFrameBound::Preceding(n) => { - if n.is_null() { - // UNBOUNDED PRECEDING - Ok(0) - } else { - calculate_index_of_row::( - range_columns, - sort_options, - idx, - Some(n), - ) + // Accumulate any new rows that have entered the window: + let update_bound = cur_range.1 - last_range.1; + if update_bound > 0 { + let update: Vec = value_slice + .iter() + .map(|v| v.slice(last_range.1, update_bound)) + .collect(); + accumulator.update_batch(&update)? } - } - WindowFrameBound::CurrentRow => calculate_index_of_row::( - range_columns, - sort_options, - idx, - None, - ), - WindowFrameBound::Following(n) => calculate_index_of_row::( - range_columns, - sort_options, - idx, - Some(n), - ), - }; - let end = match &window_frame.end_bound { - WindowFrameBound::Preceding(n) => calculate_index_of_row::( - range_columns, - sort_options, - idx, - Some(n), - ), - WindowFrameBound::CurrentRow => calculate_index_of_row::( - range_columns, - sort_options, - idx, - None, - ), - WindowFrameBound::Following(n) => { - if n.is_null() { - // UNBOUNDED FOLLOWING - Ok(length) - } else { - calculate_index_of_row::( - range_columns, - sort_options, - idx, - Some(n), - ) + // Remove rows that have now left the window: + let retract_bound = cur_range.0 - last_range.0; + if retract_bound > 0 { + let retract: Vec = value_slice + .iter() + .map(|v| v.slice(last_range.0, retract_bound)) + .collect(); + accumulator.retract_batch(&retract)? 
} - } - }; - Ok((start?, end?)) - } - WindowFrameUnits::Rows => { - let start = match window_frame.start_bound { - // UNBOUNDED PRECEDING - WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => Ok(0), - WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => { - if idx >= n as usize { - Ok(idx - n as usize) - } else { - Ok(0) - } - } - WindowFrameBound::Preceding(_) => { - Err(DataFusionError::Internal("Rows should be Uint".to_string())) - } - WindowFrameBound::CurrentRow => Ok(idx), - // UNBOUNDED FOLLOWING - WindowFrameBound::Following(ScalarValue::UInt64(None)) => { - Err(DataFusionError::Internal(format!( - "Frame start cannot be UNBOUNDED FOLLOWING '{:?}'", - window_frame - ))) - } - WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => { - Ok(min(idx + n as usize, length)) - } - WindowFrameBound::Following(_) => { - Err(DataFusionError::Internal("Rows should be Uint".to_string())) - } - }; - let end = match window_frame.end_bound { - // UNBOUNDED PRECEDING - WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => { - Err(DataFusionError::Internal(format!( - "Frame end cannot be UNBOUNDED PRECEDING '{:?}'", - window_frame - ))) - } - WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => { - if idx >= n as usize { - Ok(idx - n as usize + 1) - } else { - Ok(0) - } - } - WindowFrameBound::Preceding(_) => { - Err(DataFusionError::Internal("Rows should be Uint".to_string())) - } - WindowFrameBound::CurrentRow => Ok(idx + 1), - // UNBOUNDED FOLLOWING - WindowFrameBound::Following(ScalarValue::UInt64(None)) => Ok(length), - WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => { - Ok(min(idx + n as usize + 1, length)) - } - WindowFrameBound::Following(_) => { - Err(DataFusionError::Internal("Rows should be Uint".to_string())) - } - }; - Ok((start?, end?)) - } - WindowFrameUnits::Groups => Err(DataFusionError::NotImplemented( - "Window frame for groups is not implemented".to_string(), - )), - } -} - -/// Aggregate window accumulator 
utilizes the accumulator from aggregation and do a accumulative sum -/// across evaluation arguments based on peer equivalences. It uses many information to calculate -/// correct running window. -#[derive(Debug)] -struct AggregateWindowAccumulator { - accumulator: Box, - window_frame: Option>, - partition_by: Vec>, - order_by: Vec, - field: Field, -} - -impl AggregateWindowAccumulator { - /// This function calculates the aggregation on all rows in `value_slice`. - /// Returns an array of size `length`. - fn calculate_whole_table( - &mut self, - value_slice: &[ArrayRef], - length: usize, - ) -> Result { - self.accumulator.update_batch(value_slice)?; - let value = self.accumulator.evaluate()?; - Ok(value.to_array_of_size(length)) - } - - /// This function calculates the running window logic for the rows in `value_range` of `value_slice`. - /// We maintain the accumulator state via `update_batch` and `retract_batch` functions. - /// Note that not all aggregators implement `retract_batch` just yet. - fn calculate_running_window( - &mut self, - value_slice: &[ArrayRef], - order_bys: &[&ArrayRef], - value_range: &Range, - ) -> Result { - // We iterate on each row to perform a running calculation. - // First, cur_range is calculated, then it is compared with last_range. 
- let length = value_range.end - value_range.start; - let slice_order_columns = order_bys - .iter() - .map(|v| v.slice(value_range.start, length)) - .collect::>(); - let sort_options: Vec = - self.order_by.iter().map(|o| o.options).collect(); - - let updated_zero_offset_value_range = Range { - start: 0, - end: length, - }; - let mut row_wise_results: Vec = vec![]; - let mut last_range: (usize, usize) = ( - updated_zero_offset_value_range.start, - updated_zero_offset_value_range.start, - ); - - for i in 0..length { - let window_frame = self.window_frame.as_ref().ok_or_else(|| { - DataFusionError::Internal( - "Window frame cannot be empty to calculate window ranges".to_string(), - ) - })?; - let cur_range = calculate_current_window( - window_frame, - &slice_order_columns, - &sort_options, - length, - i, - )?; - - if cur_range.0 == cur_range.1 { - // We produce None if the window is empty. - row_wise_results.push(ScalarValue::try_from(self.field.data_type())?) - } else { - // Accumulate any new rows that have entered the window: - let update_bound = cur_range.1 - last_range.1; - if update_bound > 0 { - let update: Vec = value_slice - .iter() - .map(|v| v.slice(last_range.1, update_bound)) - .collect(); - self.accumulator.update_batch(&update)? - } - // Remove rows that have now left the window: - let retract_bound = cur_range.0 - last_range.0; - if retract_bound > 0 { - let retract: Vec = value_slice - .iter() - .map(|v| v.slice(last_range.0, retract_bound)) - .collect(); - self.accumulator.retract_batch(&retract)? - } - row_wise_results.push(self.accumulator.evaluate()?); + accumulator.evaluate()? 
+ }; + row_wise_results.push(value); + last_range = cur_range; } - last_range = cur_range; } ScalarValue::iter_to_array(row_wise_results.into_iter()) } - fn scan( - &mut self, - values: &[ArrayRef], - order_bys: &[&ArrayRef], - value_range: &Range, - ) -> Result { - if value_range.is_empty() { - return Err(DataFusionError::Internal( - "Value range cannot be empty".to_owned(), - )); - } - let length = value_range.end - value_range.start; - let value_slice = values - .iter() - .map(|v| v.slice(value_range.start, length)) - .collect::>(); - let order_columns = &order_bys[self.partition_by.len()..order_bys.len()].to_vec(); - match (&order_columns[..], &self.window_frame) { - ([], None) => { - // OVER () case - self.calculate_whole_table(&value_slice, length) - } - ([column, ..], None) => { - // OVER (ORDER BY a) case - // We create an implicit window for ORDER BY. - let empty_bound = ScalarValue::try_from(column.data_type())?; - self.window_frame = Some(Arc::new(WindowFrame { - units: WindowFrameUnits::Range, - start_bound: WindowFrameBound::Preceding(empty_bound), - end_bound: WindowFrameBound::CurrentRow, - })); - self.calculate_running_window(&value_slice, order_columns, value_range) - } - ([], Some(frame)) => { - match frame.units { - WindowFrameUnits::Range => { - // OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) case - self.calculate_whole_table(&value_slice, length) - } - WindowFrameUnits::Rows => { - // OVER (ROWS BETWEEN X PRECEDING AND Y FOLLOWING) case - self.calculate_running_window( - &value_slice, - order_bys, - value_range, - ) - } - WindowFrameUnits::Groups => Err(DataFusionError::NotImplemented( - "Window frame for groups is not implemented".to_string(), - )), - } - } - // OVER (ORDER BY a ROWS/RANGE BETWEEN X PRECEDING AND Y FOLLOWING) case - _ => self.calculate_running_window(&value_slice, order_columns, value_range), - } + fn partition_by(&self) -> &[Arc] { + &self.partition_by + } + + fn order_by(&self) -> &[PhysicalSortExpr] { + 
&self.order_by } } diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index 2fa1f808fda8..e4e377175653 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -20,12 +20,15 @@ use super::BuiltInWindowFunctionExpr; use super::WindowExpr; use crate::{expressions::PhysicalSortExpr, PhysicalExpr}; -use arrow::compute::concat; +use arrow::array::Array; +use arrow::compute::{concat, SortOptions}; use arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; use datafusion_common::DataFusionError; use datafusion_common::Result; +use datafusion_expr::WindowFrame; use std::any::Any; +use std::ops::Range; use std::sync::Arc; /// A window expr that takes the form of a built in window function @@ -34,6 +37,7 @@ pub struct BuiltInWindowExpr { expr: Arc<dyn BuiltInWindowFunctionExpr>, partition_by: Vec<Arc<dyn PhysicalExpr>>, order_by: Vec<PhysicalSortExpr>, + window_frame: Option<Arc<WindowFrame>>, } impl BuiltInWindowExpr { @@ -42,11 +46,13 @@ impl BuiltInWindowExpr { expr: Arc<dyn BuiltInWindowFunctionExpr>, partition_by: &[Arc<dyn PhysicalExpr>], order_by: &[PhysicalSortExpr], + window_frame: Option<Arc<WindowFrame>>, ) -> Self { Self { expr, partition_by: partition_by.to_vec(), order_by: order_by.to_vec(), + window_frame, } } } @@ -80,11 +86,56 @@ impl WindowExpr for BuiltInWindowExpr { fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> { let evaluator = self.expr.create_evaluator(batch)?; let num_rows = batch.num_rows(); + let partition_columns = self.partition_columns(batch)?; let partition_points = - self.evaluate_partition_points(num_rows, &self.partition_columns(batch)?)?; - let results = if evaluator.include_rank() { + self.evaluate_partition_points(num_rows, &partition_columns)?; + + let results = if evaluator.uses_window_frame() { + let sort_options: Vec<SortOptions> = + self.order_by.iter().map(|o| o.options).collect(); + let columns = self.sort_columns(batch)?; + let order_columns: Vec<&ArrayRef> = + columns.iter().map(|s| &s.values).collect(); + // Sort values, this will make the same 
partitions consecutive. Also, within the partition + // range, values will be sorted. + let order_bys = &order_columns[self.partition_by.len()..]; + let window_frame = if !order_bys.is_empty() && self.window_frame.is_none() { + // OVER (ORDER BY a) case + // We create an implicit window for ORDER BY. + Some(Arc::new(WindowFrame::default())) + } else { + self.window_frame.clone() + }; + let mut row_wise_results = vec![]; + for partition_range in &partition_points { + let length = partition_range.end - partition_range.start; + let slice_order_bys = order_bys + .iter() + .map(|v| v.slice(partition_range.start, length)) + .collect::<Vec<_>>(); + // We iterate on each row to calculate the window frame range and the window function result + for idx in 0..length { + // Frame bounds are partition-relative, so clamp to `length`, not `num_rows`. + let range = self.calculate_range( + &window_frame, + &slice_order_bys, + &sort_options, + length, + idx, + )?; + let range = Range { + start: partition_range.start + range.0, + end: partition_range.start + range.1, + }; + let value = evaluator.evaluate_inside_range(range)?; + row_wise_results.push(value.to_array()); + } + } + row_wise_results + } else if evaluator.include_rank() { + let columns = self.sort_columns(batch)?; let sort_partition_points = - self.evaluate_partition_points(num_rows, &self.sort_columns(batch)?)?; + self.evaluate_partition_points(num_rows, &columns)?; evaluator.evaluate_with_rank(partition_points, sort_partition_points)? } else { evaluator.evaluate(partition_points)? 
diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index e0a6b2bd7a7c..14ce53621bde 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -21,14 +21,12 @@ use crate::window::partition_evaluator::PartitionEvaluator; use crate::window::BuiltInWindowFunctionExpr; use crate::PhysicalExpr; -use arrow::array::{new_null_array, ArrayRef}; -use arrow::compute::kernels::window::shift; +use arrow::array::{Array, ArrayRef}; use arrow::datatypes::{DataType, Field}; use arrow::record_batch::RecordBatch; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; use std::any::Any; -use std::iter; use std::ops::Range; use std::sync::Arc; @@ -142,7 +140,7 @@ pub(crate) struct NthValueEvaluator { } impl PartitionEvaluator for NthValueEvaluator { - fn include_rank(&self) -> bool { + fn uses_window_frame(&self) -> bool { true } @@ -150,45 +148,25 @@ impl PartitionEvaluator for NthValueEvaluator { - unreachable!("first, last, and nth_value evaluation must be called with evaluate_partition_with_rank") + unreachable!("first, last, and nth_value evaluation must be called with evaluate_inside_range") } - fn evaluate_partition_with_rank( - &self, - partition: Range<usize>, - ranks_in_partition: &[Range<usize>], - ) -> Result<ArrayRef> { + /// Returns the first, last or n-th value of the rows inside `range`. + /// Yields a typed NULL when the frame is empty or has fewer than `n` rows. + fn evaluate_inside_range(&self, range: Range<usize>) -> Result<ScalarValue> { let arr = &self.values[0]; - let num_rows = partition.end - partition.start; + let n_range = range.end - range.start; + if n_range == 0 { + // An empty frame has no rows to pick from; `range.end - 1` would underflow. 
+ return ScalarValue::try_from(arr.data_type()); + } match self.kind { - NthValueKind::First => { - let value = ScalarValue::try_from_array(arr, partition.start)?; - Ok(value.to_array_of_size(num_rows)) - } - NthValueKind::Last => { - // because the default window frame is between unbounded preceding and current - // row with peer evaluation, hence the last rows expands until the end of the peers - let values = ranks_in_partition - .iter() - .map(|range| { - let len = range.end - range.start; - let value = ScalarValue::try_from_array(arr, range.end - 1)?; - Ok(iter::repeat(value).take(len)) - }) - 
.collect::<Result<Vec<_>>>()? - .into_iter() - .flatten(); - ScalarValue::iter_to_array(values) - } + NthValueKind::First => ScalarValue::try_from_array(arr, range.start), + NthValueKind::Last => ScalarValue::try_from_array(arr, range.end - 1), NthValueKind::Nth(n) => { + // We are certain that n > 0. let index = (n as usize) - 1; - if index >= num_rows { - Ok(new_null_array(arr.data_type(), num_rows)) + if index >= n_range { + ScalarValue::try_from(arr.data_type()) } else { - let value = - ScalarValue::try_from_array(arr, partition.start + index)?; - let arr = value.to_array_of_size(num_rows); - // because the default window frame is between unbounded preceding and current - // row, hence the shift because for values with indices < index they should be - // null. This changes when window frames other than default is implemented - shift(arr.as_ref(), index as i64).map_err(DataFusionError::ArrowError) + ScalarValue::try_from_array(arr, range.start + index) } } } @@ -208,11 +186,21 @@ mod tests { let values = vec![arr]; let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]); let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?; - let result = expr - .create_evaluator(&batch)? 
- .evaluate_with_rank(vec![0..8], vec![0..8])?; - assert_eq!(1, result.len()); - let result = result[0].as_any().downcast_ref::<Int32Array>().unwrap(); + let mut ranges: Vec<Range<usize>> = vec![]; + for i in 0..8 { + ranges.push(Range { + start: 0, + end: i + 1, + }) + } + let evaluator = expr.create_evaluator(&batch)?; + let result = ranges + .into_iter() + .map(|range| evaluator.evaluate_inside_range(range)) + .into_iter() + .collect::<Result<Vec<ScalarValue>>>()?; + let result = ScalarValue::iter_to_array(result.into_iter())?; + let result = result.as_any().downcast_ref::<Int32Array>().unwrap(); assert_eq!(expected, *result); Ok(()) } @@ -235,7 +223,19 @@ mod tests { Arc::new(Column::new("arr", 0)), DataType::Int32, ); - test_i32_result(last_value, Int32Array::from_iter_values(vec![8; 8]))?; + test_i32_result( + last_value, + Int32Array::from(vec![ + Some(1), + Some(-2), + Some(3), + Some(-4), + Some(5), + Some(-6), + Some(7), + Some(8), + ]), + )?; Ok(()) } diff --git a/datafusion/physical-expr/src/window/partition_evaluator.rs b/datafusion/physical-expr/src/window/partition_evaluator.rs index c3a88367a2c2..4ecfd87a9df0 100644 --- a/datafusion/physical-expr/src/window/partition_evaluator.rs +++ b/datafusion/physical-expr/src/window/partition_evaluator.rs @@ -18,8 +18,8 @@ //! 
partition evaluation module use arrow::array::ArrayRef; -use datafusion_common::DataFusionError; use datafusion_common::Result; +use datafusion_common::{DataFusionError, ScalarValue}; use std::ops::Range; /// Given a partition range, and the full list of sort partition points, given that the sort @@ -46,6 +46,10 @@ pub trait PartitionEvaluator { false } + fn uses_window_frame(&self) -> bool { + false + } + /// evaluate the partition evaluator against the partitions fn evaluate(&self, partition_points: Vec>) -> Result> { partition_points @@ -83,4 +87,11 @@ pub trait PartitionEvaluator { "evaluate_partition_with_rank is not implemented by default".into(), )) } + + /// evaluate window function result inside given range + fn evaluate_inside_range(&self, _range: Range) -> Result { + Err(DataFusionError::NotImplemented( + "evaluate_inside_range is not implemented by default".into(), + )) + } } diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 67caba51dcab..9c4b1b17970d 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -20,12 +20,17 @@ use arrow::compute::kernels::partition::lexicographical_partition_ranges; use arrow::compute::kernels::sort::{SortColumn, SortOptions}; use arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::bisect::bisect; +use datafusion_common::{DataFusionError, Result, ScalarValue}; use std::any::Any; +use std::cmp::min; use std::fmt::Debug; use std::ops::Range; use std::sync::Arc; +use datafusion_expr::WindowFrameBound; +use datafusion_expr::{WindowFrame, WindowFrameUnits}; + /// A window expression that: /// * knows its resulting field pub trait WindowExpr: Send + Sync + Debug { @@ -110,4 +115,208 @@ pub trait WindowExpr: Send + Sync + Debug { sort_columns.extend(order_by_columns); Ok(sort_columns) } + 
/// Computes the half-open row interval `[start, end)` forming the window frame of the row at
/// position `idx`.
///
/// Supported frame units are `RANGE` and `ROWS`; `GROUPS` is rejected as not implemented.
///
/// # Arguments
/// * `window_frame` - frame specification; `None` yields the whole input, i.e. `(0, length)`.
/// * `range_columns` - column values searched in `RANGE` mode (presumably the ORDER BY
///   columns; confirm against the caller). Not consulted in `ROWS` mode.
/// * `sort_options` - sort options forwarded to the `RANGE` search together with
///   `range_columns`.
/// * `length` - number of rows available; used as the upper clamp for frame ends.
/// * `idx` - index of the current row.
///
/// # Errors
/// * `DataFusionError::Internal` if a `ROWS` offset is not a `UInt64`, if the frame start is
///   `UNBOUNDED FOLLOWING`, or if the frame end is `UNBOUNDED PRECEDING`.
/// * `DataFusionError::NotImplemented` for `GROUPS` frames.
/// * Any error produced while locating a `RANGE` boundary.
fn calculate_range(
    &self,
    window_frame: &Option<Arc<WindowFrame>>,
    range_columns: &[ArrayRef],
    sort_options: &[SortOptions],
    length: usize,
    idx: usize,
) -> Result<(usize, usize)> {
    // Without an explicit frame the whole input is the frame.
    let frame = match window_frame {
        Some(frame) => frame,
        None => return Ok((0, length)),
    };
    match frame.units {
        WindowFrameUnits::Range => {
            let start = match &frame.start_bound {
                // A null offset encodes UNBOUNDED PRECEDING.
                WindowFrameBound::Preceding(n) if n.is_null() => 0,
                WindowFrameBound::Preceding(n) => calculate_index_of_row::<true, true>(
                    range_columns,
                    sort_options,
                    idx,
                    Some(n),
                )?,
                // With no range columns there is nothing to search: start at row 0.
                WindowFrameBound::CurrentRow if range_columns.is_empty() => 0,
                WindowFrameBound::CurrentRow => calculate_index_of_row::<true, true>(
                    range_columns,
                    sort_options,
                    idx,
                    None,
                )?,
                WindowFrameBound::Following(n) => calculate_index_of_row::<true, false>(
                    range_columns,
                    sort_options,
                    idx,
                    Some(n),
                )?,
            };
            let end = match &frame.end_bound {
                WindowFrameBound::Preceding(n) => calculate_index_of_row::<false, true>(
                    range_columns,
                    sort_options,
                    idx,
                    Some(n),
                )?,
                // With no range columns there is nothing to search: end at `length`.
                WindowFrameBound::CurrentRow if range_columns.is_empty() => length,
                WindowFrameBound::CurrentRow => calculate_index_of_row::<false, false>(
                    range_columns,
                    sort_options,
                    idx,
                    None,
                )?,
                // A null offset encodes UNBOUNDED FOLLOWING.
                WindowFrameBound::Following(n) if n.is_null() => length,
                WindowFrameBound::Following(n) => calculate_index_of_row::<false, false>(
                    range_columns,
                    sort_options,
                    idx,
                    Some(n),
                )?,
            };
            Ok((start, end))
        }
        WindowFrameUnits::Rows => {
            let start = match frame.start_bound {
                // UNBOUNDED PRECEDING
                WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => 0,
                // Clamp at the first row when fewer than `n` rows precede `idx`.
                WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => {
                    idx.saturating_sub(n as usize)
                }
                WindowFrameBound::CurrentRow => idx,
                // UNBOUNDED FOLLOWING
                WindowFrameBound::Following(ScalarValue::UInt64(None)) => {
                    return Err(DataFusionError::Internal(format!(
                        "Frame start cannot be UNBOUNDED FOLLOWING '{:?}'",
                        frame
                    )))
                }
                // Saturating addition: a huge offset must clamp to `length`, not overflow.
                WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => {
                    min(idx.saturating_add(n as usize), length)
                }
                WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => {
                    return Err(DataFusionError::Internal(
                        "Rows should be Uint".to_string(),
                    ))
                }
            };
            let end = match frame.end_bound {
                // UNBOUNDED PRECEDING
                WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => {
                    return Err(DataFusionError::Internal(format!(
                        "Frame end cannot be UNBOUNDED PRECEDING '{:?}'",
                        frame
                    )))
                }
                // The end is exclusive, hence the `+ 1`; a frame ending before the first
                // row collapses to 0.
                WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => {
                    if idx >= n as usize {
                        idx - n as usize + 1
                    } else {
                        0
                    }
                }
                WindowFrameBound::CurrentRow => idx + 1,
                // UNBOUNDED FOLLOWING
                WindowFrameBound::Following(ScalarValue::UInt64(None)) => length,
                // Saturating addition: a huge offset must clamp to `length`, not overflow.
                WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => min(
                    idx.saturating_add(n as usize).saturating_add(1),
                    length,
                ),
                WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => {
                    return Err(DataFusionError::Internal(
                        "Rows should be Uint".to_string(),
                    ))
                }
            };
            Ok((start, end))
        }
        WindowFrameUnits::Groups => Err(DataFusionError::NotImplemented(
            "Window frame for groups is not implemented".to_string(),
        )),
    }
}
// Closes the enclosing trait body that was opened above this method.
}

/// Locates a `RANGE` frame boundary for the row at `idx` by bisecting `range_columns`.
///
/// The values of every range column at `idx` are read first. When `delta` is given, each
/// non-null value is shifted by it before the search; null values are searched for as-is.
/// When `delta` is `None` the current row's values themselves are the search target.
///
/// # Const parameters
/// * `BISECT_SIDE` - `true` means bisect-left, `false` means bisect-right.
/// * `SEARCH_SIDE` - selects the shift direction: the offset is added when `SEARCH_SIDE`
///   equals the `descending` flag of the first sort option, and subtracted otherwise.
///
/// # Errors
/// Returns `DataFusionError::Internal` when `delta` is given but `sort_options` is empty, and
/// propagates any error from reading the row values, shifting them, or bisecting.
fn calculate_index_of_row<const BISECT_SIDE: bool, const SEARCH_SIDE: bool>(
    range_columns: &[ArrayRef],
    sort_options: &[SortOptions],
    idx: usize,
    delta: Option<&ScalarValue>,
) -> Result<usize> {
    let current_row_values = range_columns
        .iter()
        .map(|col| ScalarValue::try_from_array(col, idx))
        .collect::<Result<Vec<ScalarValue>>>()?;
    let search_target = match delta {
        None => current_row_values,
        Some(delta) => {
            // Only the first sort option decides the shift direction.
            let is_descending: bool = sort_options
                .first()
                .ok_or_else(|| DataFusionError::Internal("Array is empty".to_string()))?
                .descending;

            current_row_values
                .iter()
                .map(|value| {
                    if value.is_null() {
                        return Ok(value.clone());
                    }
                    if SEARCH_SIDE == is_descending {
                        // TODO: Handle positive overflows
                        value.add(delta)
                    } else if value.is_unsigned() && value < delta {
                        // NOTE: This gets a polymorphic zero without having long coercion
                        //       code for ScalarValue. If we decide to implement a "default"
                        //       construction mechanism for ScalarValue, change the following
                        //       statement to use that.
                        value.sub(value)
                    } else {
                        // TODO: Handle negative overflows
                        value.sub(delta)
                    }
                })
                .collect::<Result<Vec<ScalarValue>>>()?
        }
    };
    // `BISECT_SIDE` true means bisect_left, false means bisect_right
    bisect::<BISECT_SIDE>(range_columns, &search_target, sort_options)
}