Skip to content

Commit

Permalink
Support using var/var_pop/stddev/stddev_pop in window expressions wit…
Browse files Browse the repository at this point in the history
…h custom frames (#4848)

* Wire up retract_batch for Stddev/StddevPop/Variance/VariancePop to

* Add test for Stddev/StddevPop/Variance/VariancePop with window frame
  • Loading branch information
jonmmease committed Jan 10, 2023
1 parent 13fb42e commit 292eb95
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 0 deletions.
28 changes: 28 additions & 0 deletions datafusion/core/tests/sql/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,34 @@ async fn window_frame_rows_preceding() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn window_frame_rows_preceding_stddev_variance() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
let sql = "SELECT \
VAR(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\
VAR_POP(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\
STDDEV(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\
STDDEV_POP(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)\
FROM aggregate_test_100 \
ORDER BY c9 \
LIMIT 5";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+---------------------------------+------------------------------------+-------------------------------+----------------------------------+",
"| VARIANCE(aggregate_test_100.c4) | VARIANCEPOP(aggregate_test_100.c4) | STDDEV(aggregate_test_100.c4) | STDDEVPOP(aggregate_test_100.c4) |",
"+---------------------------------+------------------------------------+-------------------------------+----------------------------------+",
"| 46721.33333333174 | 31147.555555554496 | 216.15118166073427 | 176.4867007894773 |",
"| 2639429.333333332 | 1759619.5555555548 | 1624.6320609089714 | 1326.5065229977404 |",
"| 746202.3333333324 | 497468.2222222216 | 863.8300372951455 | 705.3142719541563 |",
"| 768422.9999999981 | 512281.9999999988 | 876.5973990378925 | 715.7387791645767 |",
"| 66526.3333333288 | 44350.88888888587 | 257.9269922542594 | 210.5965073045749 |",
"+---------------------------------+------------------------------------+-------------------------------+----------------------------------+",
];
assert_batches_eq!(expected, &actual);
Ok(())
}

#[tokio::test]
async fn window_frame_rows_preceding_with_partition_unique_order_by() -> Result<()> {
let ctx = SessionContext::new();
Expand Down
12 changes: 12 additions & 0 deletions datafusion/physical-expr/src/aggregate/stddev.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ impl AggregateExpr for Stddev {
Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?))
}

fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?))
}

fn state_fields(&self) -> Result<Vec<Field>> {
Ok(vec![
Field::new(
Expand Down Expand Up @@ -128,6 +132,10 @@ impl AggregateExpr for StddevPop {
Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?))
}

fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?))
}

fn state_fields(&self) -> Result<Vec<Field>> {
Ok(vec![
Field::new(
Expand Down Expand Up @@ -184,6 +192,10 @@ impl Accumulator for StddevAccumulator {
self.variance.update_batch(values)
}

fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
self.variance.retract_batch(values)
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
self.variance.merge_batch(states)
}
Expand Down
10 changes: 10 additions & 0 deletions datafusion/physical-expr/src/aggregate/variance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ impl AggregateExpr for Variance {
Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?))
}

fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?))
}

fn state_fields(&self) -> Result<Vec<Field>> {
Ok(vec![
Field::new(
Expand Down Expand Up @@ -136,6 +140,12 @@ impl AggregateExpr for VariancePop {
)?))
}

fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(VarianceAccumulator::try_new(
StatsType::Population,
)?))
}

fn state_fields(&self) -> Result<Vec<Field>> {
Ok(vec![
Field::new(
Expand Down

0 comments on commit 292eb95

Please sign in to comment.