Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3256,16 +3256,14 @@ class Analyzer(
}
wsc.copy(partitionSpec = newPartitionSpec, orderSpec = newOrderSpec)

case WindowExpression(ae: AggregateExpression, _) if ae.filter.isDefined =>
throw QueryCompilationErrors.windowAggregateFunctionWithFilterNotSupportedError()

// Extract Windowed AggregateExpression
case we @ WindowExpression(
ae @ AggregateExpression(function, _, _, _, _),
ae @ AggregateExpression(function, _, _, filter, _),
spec: WindowSpecDefinition) =>
val newChildren = function.children.map(extractExpr)
val newFunction = function.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
val newAgg = ae.copy(aggregateFunction = newFunction)
val newFilter = filter.map(extractExpr)
val newAgg = ae.copy(aggregateFunction = newFunction, filter = newFilter)
seenWindowAggregates += newAgg
WindowExpression(newAgg, spec)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,19 +208,6 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase {
| RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)"
|""".stripMargin.replaceAll("\n", "")))

errorTest(
"window aggregate function with filter predicate",
testRelation2.select(
WindowExpression(
Count(UnresolvedAttribute("b"))
.toAggregateExpression(isDistinct = false, filter = Some(UnresolvedAttribute("b") > 1)),
WindowSpecDefinition(
UnresolvedAttribute("a") :: Nil,
SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
UnspecifiedFrame)).as("window")),
"window aggregate function with filter predicate is not supported" :: Nil
)

test("distinct function") {
assertAnalysisErrorCondition(
CatalystSqlParser.parsePlan("SELECT hex(DISTINCT a) FROM TaBlE"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,17 @@ private[window] object AggregateProcessor {
functions: Array[Expression],
ordinal: Int,
inputAttributes: Seq[Attribute],
newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection)
newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection,
filters: Array[Option[Expression]])
: AggregateProcessor = {
assert(filters.length == functions.length,
s"filters length (${filters.length}) must match functions length (${functions.length})")
val aggBufferAttributes = mutable.Buffer.empty[AttributeReference]
val initialValues = mutable.Buffer.empty[Expression]
val updateExpressions = mutable.Buffer.empty[Expression]
val evaluateExpressions = mutable.Buffer.fill[Expression](ordinal)(NoOp)
val imperatives = mutable.Buffer.empty[ImperativeAggregate]
val imperativeFilterExprs = mutable.Buffer.empty[Option[Expression]]

// SPARK-14244: `SizeBasedWindowFunction`s are firstly created on driver side and then
// serialized to executor side. These functions all reference a global singleton window
Expand All @@ -73,25 +77,34 @@ private[window] object AggregateProcessor {
}

// Add an AggregateFunction to the AggregateProcessor.
functions.foreach {
case agg: DeclarativeAggregate =>
functions.zip(filters).foreach {
case (agg: DeclarativeAggregate, filterOpt) =>
aggBufferAttributes ++= agg.aggBufferAttributes
initialValues ++= agg.initialValues
updateExpressions ++= agg.updateExpressions
filterOpt match {
case Some(filter) =>
updateExpressions ++= agg.updateExpressions.zip(agg.aggBufferAttributes).map {
case (updateExpr, attr) => If(filter, updateExpr, attr)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it mean the filter will be evaluated multiple times? Maybe common expression evaluation would help.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's pretty much the same as the interpreted version of HashAggregateExec: see AggregationIterator.

}
case None =>
updateExpressions ++= agg.updateExpressions
}
evaluateExpressions += agg.evaluateExpression
case agg: ImperativeAggregate =>
case (agg: ImperativeAggregate, filterOpt) =>
val offset = aggBufferAttributes.size
val imperative = BindReferences.bindReference(agg
.withNewInputAggBufferOffset(offset)
.withNewMutableAggBufferOffset(offset),
inputAttributes)
imperatives += imperative
imperativeFilterExprs += filterOpt.map(f =>
BindReferences.bindReference(f, inputAttributes))
aggBufferAttributes ++= imperative.aggBufferAttributes
val noOps = Seq.fill(imperative.aggBufferAttributes.size)(NoOp)
initialValues ++= noOps
updateExpressions ++= noOps
evaluateExpressions += imperative
case other =>
case (other, _) =>
throw SparkException.internalError(s"Unsupported aggregate function: $other")
}

Expand All @@ -108,6 +121,7 @@ private[window] object AggregateProcessor {
updateProj,
evalProj,
imperatives.toArray,
imperativeFilterExprs.toArray,
partitionSize.isDefined)
}
}
Expand All @@ -122,6 +136,7 @@ private[window] final class AggregateProcessor(
private[this] val updateProjection: MutableProjection,
private[this] val evaluateProjection: MutableProjection,
private[this] val imperatives: Array[ImperativeAggregate],
private[this] val imperativeFilters: Array[Option[Expression]],
private[this] val trackPartitionSize: Boolean) {

private[this] val join = new JoinedRow
Expand Down Expand Up @@ -152,7 +167,15 @@ private[window] final class AggregateProcessor(
updateProjection(join(buffer, input))
var i = 0
while (i < numImperatives) {
imperatives(i).update(buffer, input)
val shouldUpdate = imperativeFilters(i) match {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like there is no common expression evaluation here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

case Some(filter) =>
val result = filter.eval(input)
result != null && result.asInstanceOf[Boolean]
case None => true
}
if (shouldUpdate) {
imperatives(i).update(buffer, input)
}
i += 1
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,12 +194,17 @@ trait WindowEvaluatorFactoryBase {
def processor = if (functions.exists(_.isInstanceOf[PythonFuncExpression])) {
null
} else {
val aggFilters = expressions.map {
case WindowExpression(ae: AggregateExpression, _) => ae.filter
case _ => None
}.toArray
AggregateProcessor(
functions,
ordinal,
childOutput,
(expressions, schema) =>
MutableProjection.create(expressions, schema))
MutableProjection.create(expressions, schema),
aggFilters)
}

// Create the factory to produce WindowFunctionFrame.
Expand Down
102 changes: 97 additions & 5 deletions sql/core/src/test/resources/sql-tests/analyzer-results/window.sql.out
Original file line number Diff line number Diff line change
Expand Up @@ -688,13 +688,105 @@ Project [cate#x, sum(val) OVER (PARTITION BY cate ORDER BY val ASC NULLS FIRST R

-- !query
SELECT val, cate,
count(val) FILTER (WHERE val > 1) OVER(PARTITION BY cate)
first_value(val) FILTER (WHERE cate = 'a') OVER(ORDER BY val_long
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS first_a,
last_value(val) FILTER (WHERE cate = 'a') OVER(ORDER BY val_long
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS last_a
FROM testData ORDER BY val_long, cate
-- !query analysis
Project [val#x, cate#x, first_a#x, last_a#x]
+- Sort [val_long#xL ASC NULLS FIRST, cate#x ASC NULLS FIRST], true
+- Project [val#x, cate#x, first_a#x, last_a#x, val_long#xL]
+- Project [val#x, cate#x, _w0#x, val_long#xL, first_a#x, last_a#x, first_a#x, last_a#x]
+- Window [first_value(val#x, false) FILTER (WHERE _w0#x) windowspecdefinition(val_long#xL ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS first_a#x, last_value(val#x, false) FILTER (WHERE _w0#x) windowspecdefinition(val_long#xL ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS last_a#x], [val_long#xL ASC NULLS FIRST]
+- Project [val#x, cate#x, (cate#x = a) AS _w0#x, val_long#xL]
+- SubqueryAlias testdata
+- View (`testData`, [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x])
+- Project [cast(val#x as int) AS val#x, cast(val_long#xL as bigint) AS val_long#xL, cast(val_double#x as double) AS val_double#x, cast(val_date#x as date) AS val_date#x, cast(val_timestamp#x as timestamp) AS val_timestamp#x, cast(cate#x as string) AS cate#x]
+- Project [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x]
+- SubqueryAlias testData
+- LocalRelation [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x]


-- !query
SELECT val, cate,
sum(val) FILTER (WHERE cate = 'a') OVER(ORDER BY val_long
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS sum_a,
sum(val) FILTER (WHERE cate = 'b') OVER(ORDER BY val_long
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS sum_b,
count(val) FILTER (WHERE val > 1) OVER(ORDER BY val_long
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS cnt_gt1
FROM testData ORDER BY val_long, cate
-- !query analysis
Project [val#x, cate#x, sum_a#xL, sum_b#xL, cnt_gt1#xL]
+- Sort [val_long#xL ASC NULLS FIRST, cate#x ASC NULLS FIRST], true
+- Project [val#x, cate#x, sum_a#xL, sum_b#xL, cnt_gt1#xL, val_long#xL]
+- Project [val#x, cate#x, _w0#x, val_long#xL, _w2#x, _w3#x, sum_a#xL, sum_b#xL, cnt_gt1#xL, sum_a#xL, sum_b#xL, cnt_gt1#xL]
+- Window [sum(val#x) FILTER (WHERE _w0#x) windowspecdefinition(val_long#xL ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS sum_a#xL, sum(val#x) FILTER (WHERE _w2#x) windowspecdefinition(val_long#xL ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS sum_b#xL, count(val#x) FILTER (WHERE _w3#x) windowspecdefinition(val_long#xL ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS cnt_gt1#xL], [val_long#xL ASC NULLS FIRST]
+- Project [val#x, cate#x, (cate#x = a) AS _w0#x, val_long#xL, (cate#x = b) AS _w2#x, (val#x > 1) AS _w3#x]
+- SubqueryAlias testdata
+- View (`testData`, [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x])
+- Project [cast(val#x as int) AS val#x, cast(val_long#xL as bigint) AS val_long#xL, cast(val_double#x as double) AS val_double#x, cast(val_date#x as date) AS val_date#x, cast(val_timestamp#x as timestamp) AS val_timestamp#x, cast(cate#x as string) AS cate#x]
+- Project [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x]
+- SubqueryAlias testData
+- LocalRelation [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x]


-- !query
SELECT val, cate,
sum(val) FILTER (WHERE cate = 'a') OVER(PARTITION BY cate) AS total_sum_filtered
FROM testData ORDER BY cate, val
-- !query analysis
org.apache.spark.sql.AnalysisException
{
"errorClass" : "_LEGACY_ERROR_TEMP_1030"
}
Sort [cate#x ASC NULLS FIRST, val#x ASC NULLS FIRST], true
+- Project [val#x, cate#x, total_sum_filtered#xL]
+- Project [val#x, cate#x, _w0#x, total_sum_filtered#xL, total_sum_filtered#xL]
+- Window [sum(val#x) FILTER (WHERE _w0#x) windowspecdefinition(cate#x, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS total_sum_filtered#xL], [cate#x]
+- Project [val#x, cate#x, (cate#x = a) AS _w0#x]
+- SubqueryAlias testdata
+- View (`testData`, [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x])
+- Project [cast(val#x as int) AS val#x, cast(val_long#xL as bigint) AS val_long#xL, cast(val_double#x as double) AS val_double#x, cast(val_date#x as date) AS val_date#x, cast(val_timestamp#x as timestamp) AS val_timestamp#x, cast(cate#x as string) AS cate#x]
+- Project [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x]
+- SubqueryAlias testData
+- LocalRelation [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x]


-- !query
SELECT val, cate,
sum(val) FILTER (WHERE val > 1) OVER(ORDER BY val_long
ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS sliding_sum_filtered
FROM testData ORDER BY val_long, cate
-- !query analysis
Project [val#x, cate#x, sliding_sum_filtered#xL]
+- Sort [val_long#xL ASC NULLS FIRST, cate#x ASC NULLS FIRST], true
+- Project [val#x, cate#x, sliding_sum_filtered#xL, val_long#xL]
+- Project [val#x, cate#x, _w0#x, val_long#xL, sliding_sum_filtered#xL, sliding_sum_filtered#xL]
+- Window [sum(val#x) FILTER (WHERE _w0#x) windowspecdefinition(val_long#xL ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, 1)) AS sliding_sum_filtered#xL], [val_long#xL ASC NULLS FIRST]
+- Project [val#x, cate#x, (val#x > 1) AS _w0#x, val_long#xL]
+- SubqueryAlias testdata
+- View (`testData`, [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x])
+- Project [cast(val#x as int) AS val#x, cast(val_long#xL as bigint) AS val_long#xL, cast(val_double#x as double) AS val_double#x, cast(val_date#x as date) AS val_date#x, cast(val_timestamp#x as timestamp) AS val_timestamp#x, cast(cate#x as string) AS cate#x]
+- Project [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x]
+- SubqueryAlias testData
+- LocalRelation [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x]


-- !query
SELECT val, cate,
sum(val) FILTER (WHERE cate = 'a') OVER(ORDER BY val
RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS range_sum_filtered
FROM testData ORDER BY val, cate
-- !query analysis
Sort [val#x ASC NULLS FIRST, cate#x ASC NULLS FIRST], true
+- Project [val#x, cate#x, range_sum_filtered#xL]
+- Project [val#x, cate#x, _w0#x, range_sum_filtered#xL, range_sum_filtered#xL]
+- Window [sum(val#x) FILTER (WHERE _w0#x) windowspecdefinition(val#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS range_sum_filtered#xL], [val#x ASC NULLS FIRST]
+- Project [val#x, cate#x, (cate#x = a) AS _w0#x]
+- SubqueryAlias testdata
+- View (`testData`, [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x])
+- Project [cast(val#x as int) AS val#x, cast(val_long#xL as bigint) AS val_long#xL, cast(val_double#x as double) AS val_double#x, cast(val_date#x as date) AS val_date#x, cast(val_timestamp#x as timestamp) AS val_timestamp#x, cast(cate#x as string) AS cate#x]
+- Project [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x]
+- SubqueryAlias testData
+- LocalRelation [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x]


-- !query
Expand Down
34 changes: 32 additions & 2 deletions sql/core/src/test/resources/sql-tests/inputs/window.sql
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,41 @@ FROM testData
WHERE val is not null
WINDOW w AS (PARTITION BY cate ORDER BY val);

-- with filter predicate
-- window aggregate with filter predicate: first_value/last_value (imperative aggregate)
SELECT val, cate,
count(val) FILTER (WHERE val > 1) OVER(PARTITION BY cate)
first_value(val) FILTER (WHERE cate = 'a') OVER(ORDER BY val_long
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS first_a,
last_value(val) FILTER (WHERE cate = 'a') OVER(ORDER BY val_long
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS last_a
FROM testData ORDER BY val_long, cate;
Comment on lines +187 to +191
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests all use either UNBOUNDED PRECEDING AND CURRENT ROW (growing frame) or no-frame PARTITION BY cate (full partition). There's no test for a true sliding window like:

sum(val) FILTER (WHERE val > 1) OVER (ORDER BY val_long ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)


-- window aggregate with filter predicate: multiple aggregates with different filters
SELECT val, cate,
sum(val) FILTER (WHERE cate = 'a') OVER(ORDER BY val_long
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS sum_a,
sum(val) FILTER (WHERE cate = 'b') OVER(ORDER BY val_long
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS sum_b,
count(val) FILTER (WHERE val > 1) OVER(ORDER BY val_long
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS cnt_gt1
FROM testData ORDER BY val_long, cate;

-- window aggregate with filter predicate: entire partition frame
SELECT val, cate,
sum(val) FILTER (WHERE cate = 'a') OVER(PARTITION BY cate) AS total_sum_filtered
FROM testData ORDER BY cate, val;

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No test for RANGE frame?

All new tests use ROW frames. There's no test for:

sum(val) FILTER (WHERE cate = 'a') OVER (ORDER BY val_long RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)

-- window aggregate with filter predicate: sliding window (ROWS frame)
SELECT val, cate,
sum(val) FILTER (WHERE val > 1) OVER(ORDER BY val_long
ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS sliding_sum_filtered
FROM testData ORDER BY val_long, cate;

-- window aggregate with filter predicate: RANGE frame
SELECT val, cate,
sum(val) FILTER (WHERE cate = 'a') OVER(ORDER BY val
RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS range_sum_filtered
FROM testData ORDER BY val, cate;

-- nth_value()/first_value()/any_value() over ()
SELECT
employee_name,
Expand Down
Loading