-
Notifications
You must be signed in to change notification settings - Fork 29.1k
[SPARK-55702][SQL] Support filter predicate in window aggregate functions #54501
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -46,13 +46,17 @@ private[window] object AggregateProcessor { | |
| functions: Array[Expression], | ||
| ordinal: Int, | ||
| inputAttributes: Seq[Attribute], | ||
| newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection) | ||
| newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection, | ||
| filters: Array[Option[Expression]]) | ||
| : AggregateProcessor = { | ||
| assert(filters.length == functions.length, | ||
| s"filters length (${filters.length}) must match functions length (${functions.length})") | ||
| val aggBufferAttributes = mutable.Buffer.empty[AttributeReference] | ||
| val initialValues = mutable.Buffer.empty[Expression] | ||
| val updateExpressions = mutable.Buffer.empty[Expression] | ||
| val evaluateExpressions = mutable.Buffer.fill[Expression](ordinal)(NoOp) | ||
| val imperatives = mutable.Buffer.empty[ImperativeAggregate] | ||
| val imperativeFilterExprs = mutable.Buffer.empty[Option[Expression]] | ||
|
|
||
| // SPARK-14244: `SizeBasedWindowFunction`s are firstly created on driver side and then | ||
| // serialized to executor side. These functions all reference a global singleton window | ||
|
|
@@ -73,25 +77,34 @@ private[window] object AggregateProcessor { | |
| } | ||
|
|
||
| // Add an AggregateFunction to the AggregateProcessor. | ||
| functions.foreach { | ||
| case agg: DeclarativeAggregate => | ||
| functions.zip(filters).foreach { | ||
| case (agg: DeclarativeAggregate, filterOpt) => | ||
| aggBufferAttributes ++= agg.aggBufferAttributes | ||
| initialValues ++= agg.initialValues | ||
| updateExpressions ++= agg.updateExpressions | ||
| filterOpt match { | ||
| case Some(filter) => | ||
| updateExpressions ++= agg.updateExpressions.zip(agg.aggBufferAttributes).map { | ||
| case (updateExpr, attr) => If(filter, updateExpr, attr) | ||
| } | ||
| case None => | ||
| updateExpressions ++= agg.updateExpressions | ||
| } | ||
| evaluateExpressions += agg.evaluateExpression | ||
| case agg: ImperativeAggregate => | ||
| case (agg: ImperativeAggregate, filterOpt) => | ||
| val offset = aggBufferAttributes.size | ||
| val imperative = BindReferences.bindReference(agg | ||
| .withNewInputAggBufferOffset(offset) | ||
| .withNewMutableAggBufferOffset(offset), | ||
| inputAttributes) | ||
| imperatives += imperative | ||
| imperativeFilterExprs += filterOpt.map(f => | ||
| BindReferences.bindReference(f, inputAttributes)) | ||
| aggBufferAttributes ++= imperative.aggBufferAttributes | ||
| val noOps = Seq.fill(imperative.aggBufferAttributes.size)(NoOp) | ||
| initialValues ++= noOps | ||
| updateExpressions ++= noOps | ||
| evaluateExpressions += imperative | ||
| case other => | ||
| case (other, _) => | ||
| throw SparkException.internalError(s"Unsupported aggregate function: $other") | ||
| } | ||
|
|
||
|
|
@@ -108,6 +121,7 @@ private[window] object AggregateProcessor { | |
| updateProj, | ||
| evalProj, | ||
| imperatives.toArray, | ||
| imperativeFilterExprs.toArray, | ||
| partitionSize.isDefined) | ||
| } | ||
| } | ||
|
|
@@ -122,6 +136,7 @@ private[window] final class AggregateProcessor( | |
| private[this] val updateProjection: MutableProjection, | ||
| private[this] val evaluateProjection: MutableProjection, | ||
| private[this] val imperatives: Array[ImperativeAggregate], | ||
| private[this] val imperativeFilters: Array[Option[Expression]], | ||
| private[this] val trackPartitionSize: Boolean) { | ||
|
|
||
| private[this] val join = new JoinedRow | ||
|
|
@@ -152,7 +167,15 @@ private[window] final class AggregateProcessor( | |
| updateProjection(join(buffer, input)) | ||
| var i = 0 | ||
| while (i < numImperatives) { | ||
| imperatives(i).update(buffer, input) | ||
| val shouldUpdate = imperativeFilters(i) match { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It looks like there is no common expression evaluation here?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| case Some(filter) => | ||
| val result = filter.eval(input) | ||
| result != null && result.asInstanceOf[Boolean] | ||
| case None => true | ||
| } | ||
| if (shouldUpdate) { | ||
| imperatives(i).update(buffer, input) | ||
| } | ||
| i += 1 | ||
| } | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -182,11 +182,41 @@ FROM testData | |
| WHERE val is not null | ||
| WINDOW w AS (PARTITION BY cate ORDER BY val); | ||
|
|
||
| -- with filter predicate | ||
| -- window aggregate with filter predicate: first_value/last_value (imperative aggregate) | ||
| SELECT val, cate, | ||
| count(val) FILTER (WHERE val > 1) OVER(PARTITION BY cate) | ||
| first_value(val) FILTER (WHERE cate = 'a') OVER(ORDER BY val_long | ||
| ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS first_a, | ||
| last_value(val) FILTER (WHERE cate = 'a') OVER(ORDER BY val_long | ||
| ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS last_a | ||
| FROM testData ORDER BY val_long, cate; | ||
|
Comment on lines
+187
to
+191
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The tests all use either `UNBOUNDED PRECEDING AND CURRENT ROW` (a growing frame) or a frameless `PARTITION BY cate` (the full partition). There's no test for a true sliding window, such as: `sum(val) FILTER (WHERE val > 1) OVER (ORDER BY val_long ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)` |
||
|
|
||
| -- window aggregate with filter predicate: multiple aggregates with different filters | ||
| SELECT val, cate, | ||
| sum(val) FILTER (WHERE cate = 'a') OVER(ORDER BY val_long | ||
| ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS sum_a, | ||
| sum(val) FILTER (WHERE cate = 'b') OVER(ORDER BY val_long | ||
| ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS sum_b, | ||
| count(val) FILTER (WHERE val > 1) OVER(ORDER BY val_long | ||
| ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS cnt_gt1 | ||
| FROM testData ORDER BY val_long, cate; | ||
|
|
||
| -- window aggregate with filter predicate: entire partition frame | ||
| SELECT val, cate, | ||
| sum(val) FILTER (WHERE cate = 'a') OVER(PARTITION BY cate) AS total_sum_filtered | ||
| FROM testData ORDER BY cate, val; | ||
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. No test for a RANGE frame? All the new tests use ROWS frames. There's no test for: `sum(val) FILTER (WHERE cate = 'a') OVER (ORDER BY val_long RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)` |
||
| -- window aggregate with filter predicate: sliding window (ROWS frame) | ||
| SELECT val, cate, | ||
| sum(val) FILTER (WHERE val > 1) OVER(ORDER BY val_long | ||
| ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS sliding_sum_filtered | ||
| FROM testData ORDER BY val_long, cate; | ||
|
|
||
| -- window aggregate with filter predicate: RANGE frame | ||
| SELECT val, cate, | ||
| sum(val) FILTER (WHERE cate = 'a') OVER(ORDER BY val | ||
| RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS range_sum_filtered | ||
| FROM testData ORDER BY val, cate; | ||
|
|
||
| -- nth_value()/first_value()/any_value() over () | ||
| SELECT | ||
| employee_name, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does it mean `filter` will be evaluated multiple times? Maybe common expression evaluation helps.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's pretty much the same as the interpreted version of `HashAggregateExec`: `AggregationIterator`.