[SPARK-30276][SQL] Support Filter expression allows simultaneous use of DISTINCT #27428
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer | |||
|
||||
import org.apache.spark.sql.catalyst.expressions._ | ||||
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete} | ||||
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan} | ||||
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan, Project} | ||||
import org.apache.spark.sql.catalyst.rules.Rule | ||||
import org.apache.spark.sql.types.IntegerType | ||||
|
||||
|
@@ -31,10 +31,10 @@ import org.apache.spark.sql.types.IntegerType | |||
* First example: query without filter clauses (in scala): | ||||
* {{{ | ||||
* val data = Seq( | ||||
* ("a", "ca1", "cb1", 10), | ||||
* ("a", "ca1", "cb2", 5), | ||||
* ("b", "ca1", "cb1", 13)) | ||||
* .toDF("key", "cat1", "cat2", "value") | ||||
* (1, "a", "ca1", "cb1", 10), | ||||
* (2, "a", "ca1", "cb2", 5), | ||||
* (3, "b", "ca1", "cb1", 13)) | ||||
* .toDF("id", "key", "cat1", "cat2", "value") | ||||
* data.createOrReplaceTempView("data") | ||||
* | ||||
* val agg = data.groupBy($"key") | ||||
|
@@ -102,23 +102,126 @@ import org.apache.spark.sql.types.IntegerType | |||
* {{{ | ||||
* Aggregate( | ||||
* key = ['key] | ||||
* functions = [count(if (('gid = 1)) 'cat1 else null), | ||||
* count(if (('gid = 2)) 'cat2 else null), | ||||
* functions = [count(if (('gid = 1)) '_gen_attr_1 else null), | ||||
* count(if (('gid = 2)) '_gen_attr_2 else null), | ||||
* first(if (('gid = 0)) 'total else null) ignore nulls] | ||||
* output = ['key, 'cat1_cnt, 'cat2_cnt, 'total]) | ||||
* Aggregate( | ||||
* key = ['key, 'cat1, 'cat2, 'gid] | ||||
* functions = [sum('value) with FILTER('id > 1)] | ||||
* output = ['key, 'cat1, 'cat2, 'gid, 'total]) | ||||
* key = ['key, '_gen_attr_1, '_gen_attr_2, 'gid] | ||||
* functions = [sum('_gen_attr_3)] | ||||
* output = ['key, '_gen_attr_1, '_gen_attr_2, 'gid, 'total]) | ||||
* Expand( | ||||
* projections = [('key, null, null, 0, cast('value as bigint), 'id), | ||||
* projections = [('key, null, null, 0, if ('id > 1) cast('value as bigint) else null, 'id), | ||||
* ('key, 'cat1, null, 1, null, null), | ||||
* ('key, null, 'cat2, 2, null, null)] | ||||
* output = ['key, 'cat1, 'cat2, 'gid, 'value, 'id]) | ||||
* output = ['key, '_gen_attr_1, '_gen_attr_2, 'gid, '_gen_attr_3, 'id]) | ||||
* LocalTableScan [...] | ||||
* }}} | ||||
* | ||||
* Third example: a single distinct aggregate function with a filter clause and | ||||
* no other distinct aggregate functions (in sql): | ||||
* {{{ | ||||
* SELECT | ||||
* COUNT(DISTINCT cat1) FILTER (WHERE id > 1) as cat1_cnt, | ||||
* SUM(value) AS total | ||||
* FROM | ||||
* data | ||||
* GROUP BY | ||||
* key | ||||
* }}} | ||||
* | ||||
* This translates to the following (pseudo) logical plan: | ||||
* {{{ | ||||
* Aggregate( | ||||
* key = ['key] | ||||
* functions = [COUNT(DISTINCT 'cat1) with FILTER('id > 1), | ||||
* sum('value)] | ||||
* output = ['key, 'cat1_cnt, 'total]) | ||||
* LocalTableScan [...] | ||||
* }}} | ||||
* | ||||
* This rule rewrites this logical plan to the following (pseudo) logical plan: | ||||
* {{{ | ||||
* Aggregate( | ||||
* key = ['key] | ||||
* functions = [count('_gen_attr_1), | ||||
* sum('_gen_attr_2)] | ||||
* output = ['key, 'cat1_cnt, 'total]) | ||||
* Project( | ||||
* projectionList = ['key, if ('id > 1) 'cat1 else null, cast('value as bigint)] | ||||
Is this necessary? The query can work fine even if we don't add this Project in this rule, right?
This rule should be skipped if there is only one distinct. Having a filter or not shouldn't change it.
If this rule is not applied, the case with only one distinct and a filter clause can't be supported.
I mean to unify the implementations of the filter clause that are handled by this rule. This case was not handled by this rule before your PR. Sorry if I didn't make myself clear enough. |
||||
* output = ['key, '_gen_attr_1, '_gen_attr_2]) | ||||
* LocalTableScan [...] | ||||
* }}} | ||||
* | ||||
* The rule does the following things here: | ||||
* Fourth example: single distinct aggregate function with filter clauses (in sql): | ||||
OK. How about |
||||
* {{{ | ||||
* SELECT | ||||
* COUNT(DISTINCT cat1) FILTER (WHERE id > 1) as cat1_cnt, | ||||
* COUNT(DISTINCT cat2) as cat2_cnt, | ||||
* SUM(value) AS total | ||||
* FROM | ||||
* data | ||||
* GROUP BY | ||||
* key | ||||
* }}} | ||||
* | ||||
* This translates to the following (pseudo) logical plan: | ||||
* {{{ | ||||
* Aggregate( | ||||
* key = ['key] | ||||
* functions = [COUNT(DISTINCT 'cat1) with FILTER('id > 1), | ||||
* COUNT(DISTINCT 'cat2), | ||||
* sum('value)] | ||||
* output = ['key, 'cat1_cnt, 'cat2_cnt, 'total]) | ||||
* LocalTableScan [...] | ||||
* }}} | ||||
* | ||||
* This rule rewrites this logical plan to the following (pseudo) logical plan: | ||||
* {{{ | ||||
* Aggregate( | ||||
This plan rewriting LGTM. Shall we update the second example to make it consistent with this example?
This means to revert some code like
OK. I got it. |
||||
* key = ['key] | ||||
* functions = [count(if (('gid = 1)) '_gen_attr_1 else null), | ||||
* count(if (('gid = 2)) '_gen_attr_2 else null), | ||||
* first(if (('gid = 0)) 'total else null) ignore nulls] | ||||
* output = ['key, 'cat1_cnt, 'cat2_cnt, 'total]) | ||||
* Aggregate( | ||||
* key = ['key, '_gen_attr_1, '_gen_attr_2, 'gid] | ||||
* functions = [sum('_gen_attr_3)] | ||||
* output = ['key, '_gen_attr_1, '_gen_attr_2, 'gid, 'total]) | ||||
* Expand( | ||||
* projections = [('key, null, null, 0, cast('value as bigint)), | ||||
* ('key, if ('id > 1) 'cat1 else null, null, 1, null), | ||||
* ('key, null, 'cat2, 2, null)] | ||||
* output = ['key, '_gen_attr_1, '_gen_attr_2, 'gid, '_gen_attr_3]) | ||||
* LocalTableScan [...] | ||||
* }}} | ||||
* | ||||
* The rule consists of the two phases as follows: | ||||
* | ||||
* In the first phase, if the aggregate query contains filter clauses, project the output of | ||||
* the child of the aggregate query: | ||||
* 1. Project the data. There are three aggregation groups in this query: | ||||
can you update the comment?
|
||||
* i. the non-distinct group; | ||||
* ii. the distinct 'cat1 group; | ||||
* iii. the distinct 'cat2 group with filter clause. | ||||
This doesn't match the group. Maybe just make it general
OK |
||||
* Because at least one group has a filter clause (e.g. the distinct 'cat2 | ||||
* group with filter clause), the data is projected. | ||||
* 2. Avoid projections that may output the same attributes. There are three aggregation groups | ||||
* in this query: | ||||
* i. the non-distinct 'cat1 group; | ||||
* ii. the distinct 'cat1 group; | ||||
* iii. the distinct 'cat1 group with filter clause. | ||||
We don't need to repeat these 3 groups.
They are different |
||||
* The attributes referenced by different aggregate expressions are likely to overlap, | ||||
* and if no additional processing is performed, data loss will occur. If we directly output | ||||
* the attributes of the aggregate expression, we may get three attributes 'cat1. To prevent | ||||
* this, we generate new attributes (e.g. '_gen_attr_1) and replace the original ones. | ||||
* | ||||
* Why do we need the first phase? It guarantees that filter clauses are computed locally in | ||||
* the first aggregate. | ||||
* Note: after generating new attributes, the aggregate may have at least two distinct groups, | ||||
* so we need the second phase too. | ||||
* | ||||
* In the second phase, rewrite a query with two or more distinct groups: | ||||
* 1. Expand the data. There are three aggregation groups in this query: | ||||
* i. the non-distinct group; | ||||
* ii. the distinct 'cat1 group; | ||||
|
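The Expand-based plan in the fourth example above can be traced by hand. The following is a minimal sketch in plain Scala collections, not Catalyst code: `ExpandSketch`, `Row`, `Expanded`, and `Partial` are hypothetical names introduced only for illustration, and `None` stands in for SQL NULL. It applies the three projections from the rewritten plan to the sample `data` rows, then runs the two aggregation steps.

```scala
object ExpandSketch {
  // Hypothetical model of the sample table from the Scaladoc.
  final case class Row(id: Int, key: String, cat1: String, cat2: String, value: Long)
  // One row of the Expand output; None models SQL NULL.
  final case class Expanded(
      key: String, attr1: Option[String], attr2: Option[String], gid: Int, attr3: Option[Long])
  // One row of the first Aggregate's output.
  final case class Partial(
      key: String, attr1: Option[String], attr2: Option[String], gid: Int, total: Option[Long])

  val data: Seq[Row] = Seq(
    Row(1, "a", "ca1", "cb1", 10L),
    Row(2, "a", "ca1", "cb2", 5L),
    Row(3, "b", "ca1", "cb1", 13L))

  // Expand: one projection per aggregation group, tagged with a group id (gid).
  // The filter `id > 1` is folded into the gid = 1 projection as an if/else.
  def expand(rows: Seq[Row]): Seq[Expanded] = rows.flatMap { r =>
    Seq(
      Expanded(r.key, None, None, 0, Some(r.value)),
      Expanded(r.key, if (r.id > 1) Some(r.cat1) else None, None, 1, None),
      Expanded(r.key, None, Some(r.cat2), 2, None))
  }

  // First aggregate: grouping by (key, attr1, attr2, gid) de-duplicates the
  // distinct children; sum(attr3) ignores NULLs and is NULL for empty input.
  def firstAggregate(rows: Seq[Expanded]): Seq[Partial] =
    rows.groupBy(e => (e.key, e.attr1, e.attr2, e.gid)).toSeq.map {
      case ((key, a1, a2, gid), group) =>
        val values = group.flatMap(_.attr3)
        Partial(key, a1, a2, gid, if (values.isEmpty) None else Some(values.sum))
    }

  // Second aggregate: group by key, count non-null values per gid, and pick
  // the first non-null total from the gid = 0 rows.
  def secondAggregate(rows: Seq[Partial]): Map[String, (Long, Long, Option[Long])] =
    rows.groupBy(_.key).map { case (key, group) =>
      val cat1Cnt = group.count(p => p.gid == 1 && p.attr1.isDefined).toLong
      val cat2Cnt = group.count(p => p.gid == 2 && p.attr2.isDefined).toLong
      val total = group.filter(_.gid == 0).flatMap(_.total).headOption
      (key, (cat1Cnt, cat2Cnt, total))
    }

  def run(): Map[String, (Long, Long, Option[Long])] =
    secondAggregate(firstAggregate(expand(data)))
}
```

For key `a`, only the row with `id = 2` passes the filter, so `cat1_cnt` is 1 while `cat2_cnt` is 2 and `total` is 15. The sketch ignores partial aggregation and data types; it only illustrates why NULL-ing out filtered values in the Expand projection preserves the result.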
@@ -135,6 +238,9 @@ import org.apache.spark.sql.types.IntegerType | |||
* aggregation. In this step we use the group id to filter the inputs for the aggregate | ||||
* functions. The result of the non-distinct group are 'aggregated' by using the first operator, | ||||
* it might be more elegant to use the native UDAF merge mechanism for this in the future. | ||||
* 4. If the first phase inserted a project operator as the child of aggregate and the second phase | ||||
* already decided to insert an expand operator as the child of aggregate, the second phase will | ||||
* merge the project operator with expand operator. | ||||
* | ||||
* This rule duplicates the input data by two or more times (# distinct groups + an optional | ||||
* non-distinct group). This will put quite a bit of memory pressure of the used aggregate and | ||||
|
@@ -148,24 +254,77 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { | |||
val distinctAggs = exprs.flatMap { _.collect { | ||||
case ae: AggregateExpression if ae.isDistinct => ae | ||||
}} | ||||
// We need at least two distinct aggregates for this rule because aggregation | ||||
// strategy can handle a single distinct group. | ||||
// We need at least two distinct aggregates or a single distinct aggregate with a filter for | ||||
// this rule because aggregation strategy can handle a single distinct group without a filter. | ||||
// This check can produce false-positives, e.g., SUM(DISTINCT a) & COUNT(DISTINCT a). | ||||
distinctAggs.size > 1 | ||||
distinctAggs.size > 1 || distinctAggs.exists(_.filter.isDefined) | ||||
} | ||||
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { | ||||
case a: Aggregate if mayNeedtoRewrite(a.aggregateExpressions) => rewrite(a) | ||||
case a: Aggregate if mayNeedtoRewrite(a.aggregateExpressions) => | ||||
val (aggregate, projected) = projectFiltersInAggregates(a) | ||||
rewriteDistinctAggregates(aggregate, projected) | ||||
} | ||||
|
||||
def rewrite(a: Aggregate): Aggregate = { | ||||
|
||||
// Collect all aggregate expressions. | ||||
val aggExpressions = a.aggregateExpressions.flatMap { e => | ||||
e.collect { | ||||
case ae: AggregateExpression => ae | ||||
private def projectFiltersInAggregates(a: Aggregate): (Aggregate, Boolean) = { | ||||
val aggExpressions = collectAggregateExprs(a) | ||||
if (aggExpressions.exists(_.filter.isDefined)) { | ||||
// Constructs pairs between old and new expressions for aggregates. | ||||
val aggExprs = aggExpressions.filter(e => e.children.exists(!_.foldable)) | ||||
val (projections, aggPairs) = aggExprs.map { | ||||
case ae @ AggregateExpression(af, _, _, filter, _) => | ||||
// First, to reduce costs, it is better to handle the filter clause locally. | ||||
// e.g. for COUNT (DISTINCT a) FILTER (WHERE id > 1), evaluate the expression | ||||
// If(id > 1) 'a else null first, and use the result as output. | ||||
// Second, at least two DISTINCT aggregate expressions may reference the same | ||||
// attributes, so we need to generate new attributes to keep the output from being | ||||
// lost. e.g. SUM (DISTINCT a), COUNT (DISTINCT a) FILTER (WHERE id > 1) will output | ||||
// attribute '_gen_attr_1 and attribute '_gen_attr_2 instead of two 'a. | ||||
// Note: the generated attributes may result in at least two distinct groups, so we | ||||
// still need to call `rewriteDistinctAggregates`. | ||||
val unfoldableChildren = af.children.filter(!_.foldable) | ||||
// Expand projection | ||||
val projectionMap = unfoldableChildren.map { | ||||
case e if filter.isDefined => | ||||
val ife = If(filter.get, e, nullify(e)) | ||||
e -> Alias(ife, s"_gen_attr_${NamedExpression.newExprId.id}")() | ||||
// For convenience and unification, we always alias the column, even if | ||||
// there is no filter. | ||||
case e => e -> Alias(e, s"_gen_attr_${NamedExpression.newExprId.id}")() | ||||
} | ||||
val projection = projectionMap.map(_._2) | ||||
val exprAttrs = projectionMap.map { kv => | ||||
(kv._1, kv._2.toAttribute) | ||||
} | ||||
val exprAttrLookup = exprAttrs.toMap | ||||
val newChildren = af.children.map(c => exprAttrLookup.getOrElse(c, c)) | ||||
val raf = af.withNewChildren(newChildren).asInstanceOf[AggregateFunction] | ||||
val aggExpr = ae.copy(aggregateFunction = raf, filter = None) | ||||
(projection, (ae, aggExpr)) | ||||
}.unzip | ||||
// Construct the aggregate input projection. | ||||
val namedGroupingExpressions = a.groupingExpressions.map { | ||||
case ne: NamedExpression => ne | ||||
case other => Alias(other, other.toString)() | ||||
} | ||||
val rewriteAggProjection = namedGroupingExpressions ++ projections.flatten | ||||
// Construct the project operator. | ||||
val project = Project(rewriteAggProjection, a.child) | ||||
val groupByAttrs = namedGroupingExpressions.map(_.toAttribute) | ||||
val rewriteAggExprLookup = aggPairs.toMap | ||||
val patchedAggExpressions = a.aggregateExpressions.map { e => | ||||
e.transformDown { | ||||
case ae: AggregateExpression => rewriteAggExprLookup.getOrElse(ae, ae) | ||||
}.asInstanceOf[NamedExpression] | ||||
} | ||||
(Aggregate(groupByAttrs, patchedAggExpressions, project), true) | ||||
} else { | ||||
(a, false) | ||||
} | ||||
} | ||||
|
||||
private def rewriteDistinctAggregates(a: Aggregate, projected: Boolean): Aggregate = { | ||||
val aggExpressions = collectAggregateExprs(a) | ||||
|
||||
// Extract distinct aggregate expressions. | ||||
val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e => | ||||
|
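The change to the `mayNeedtoRewrite` guard in the hunk above can be modeled in isolation. This is a rough sketch, assuming a simplified stand-in for Catalyst's `AggregateExpression`: the `Agg` case class, the `RewriteGuardSketch` object, and both method names are hypothetical and exist only for this illustration.

```scala
object RewriteGuardSketch {
  // Hypothetical stand-in for AggregateExpression: only the two properties
  // the guard inspects are modeled.
  final case class Agg(name: String, isDistinct: Boolean, filter: Option[String] = None)

  // Guard before this PR: the rule applies only with two or more distinct aggregates.
  def mayNeedToRewriteOld(aggs: Seq[Agg]): Boolean =
    aggs.count(_.isDistinct) > 1

  // Guard in this PR: a single distinct aggregate with a filter also qualifies.
  def mayNeedToRewriteNew(aggs: Seq[Agg]): Boolean = {
    val distinctAggs = aggs.filter(_.isDistinct)
    distinctAggs.size > 1 || distinctAggs.exists(_.filter.isDefined)
  }
}
```

Under this model, `COUNT(DISTINCT cat1) FILTER (WHERE id > 1)` alongside `SUM(value)` is skipped by the old guard and picked up by the new one, while a filter on a non-distinct aggregate alone triggers neither.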
@@ -236,10 +395,9 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { | |||
// only expand unfoldable children | ||||
val regularAggExprs = aggExpressions | ||||
.filter(e => !e.isDistinct && e.children.exists(!_.foldable)) | ||||
val regularAggFunChildren = regularAggExprs | ||||
val regularAggChildren = regularAggExprs | ||||
.flatMap(_.aggregateFunction.children.filter(!_.foldable)) | ||||
val regularAggFilterAttrs = regularAggExprs.flatMap(_.filterAttributes) | ||||
val regularAggChildren = (regularAggFunChildren ++ regularAggFilterAttrs).distinct | ||||
.distinct | ||||
val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair) | ||||
|
||||
// Setup aggregates for 'regular' aggregate expressions. | ||||
|
@@ -248,12 +406,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { | |||
val regularAggOperatorMap = regularAggExprs.map { e => | ||||
// Perform the actual aggregation in the initial aggregate. | ||||
val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup.get) | ||||
// We changed the attributes in the [[Expand]] output using expressionAttributePair. | ||||
// So we need to replace the attributes in FILTER expression with new ones. | ||||
val filterOpt = e.filter.map(_.transform { | ||||
case a: Attribute => regularAggChildAttrLookup.getOrElse(a, a) | ||||
}) | ||||
val operator = Alias(e.copy(aggregateFunction = af, filter = filterOpt), e.sql)() | ||||
val operator = Alias(e.copy(aggregateFunction = af), e.sql)() | ||||
|
||||
// Select the result of the first aggregate in the last aggregate. | ||||
val result = AggregateExpression( | ||||
|
@@ -294,11 +447,27 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { | |||
regularAggNulls | ||||
} | ||||
|
||||
val (projections, expandChild) = if (projected) { | ||||
// If `projectFiltersInAggregates` inserts Project as child of Aggregate and | ||||
// `rewriteDistinctAggregates` will insert Expand here, merge Project with the Expand. | ||||
val projectAttributeExpressionMap = a.child.asInstanceOf[Project].projectList.map { | ||||
case ne: NamedExpression => ne.name -> ne | ||||
}.toMap | ||||
val projections = (regularAggProjection ++ distinctAggProjections).map { | ||||
case projection: Seq[Expression] => projection.map { | ||||
case ne: NamedExpression => projectAttributeExpressionMap.getOrElse(ne.name, ne) | ||||
case other => other | ||||
} | ||||
} | ||||
(projections, a.child.asInstanceOf[Project].child) | ||||
} else { | ||||
(regularAggProjection ++ distinctAggProjections, a.child) | ||||
} | ||||
// Construct the expand operator. | ||||
val expand = Expand( | ||||
regularAggProjection ++ distinctAggProjections, | ||||
projections, | ||||
groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.map(_._2), | ||||
a.child) | ||||
expandChild) | ||||
|
||||
// Construct the first aggregate operator. This de-duplicates all the children of | ||||
// distinct operators, and applies the regular aggregate operators. | ||||
|
@@ -331,6 +500,14 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { | |||
} | ||||
} | ||||
|
||||
private def collectAggregateExprs(a: Aggregate): Seq[AggregateExpression] = { | ||||
a.aggregateExpressions.flatMap { e => | ||||
e.collect { | ||||
case ae: AggregateExpression => ae | ||||
} | ||||
} | ||||
} | ||||
|
||||
private def nullify(e: Expression) = Literal.create(null, e.dataType) | ||||
|
||||
private def expressionAttributePair(e: Expression) = | ||||
|
Do we need to rewrite this query? The planner can handle single distinct agg func AFAIK.
I think we can keep the previous behavior. AggregationIterator already does this.
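For the single-distinct case discussed in this thread, the semantics of pushing the filter into a projection can be checked on plain collections. The sketch below is an illustration only, not Spark code: `FilterAsIfSketch` and its members are hypothetical names, `None` models SQL NULL, and `cat1` is assumed to be non-null in the input.

```scala
object FilterAsIfSketch {
  // Hypothetical row type carrying only the columns used here.
  final case class Row(id: Int, key: String, cat1: String)

  // Direct reading of COUNT(DISTINCT cat1) FILTER (WHERE id > 1):
  // drop rows that fail the filter, then count distinct values.
  def countDistinctWithFilter(rows: Seq[Row]): Int =
    rows.filter(_.id > 1).map(_.cat1).distinct.size

  // Rewritten reading: project `if (id > 1) cat1 else null` for every row,
  // then count distinct non-null values with no filter on the aggregate.
  def countDistinctAfterProject(rows: Seq[Row]): Int =
    rows.map(r => if (r.id > 1) Some(r.cat1) else None).flatten.distinct.size
}
```

Both functions agree because COUNT ignores NULL inputs, which is what lets the rule replace the filter with an `If` expression. Whether that rewrite is needed for a single distinct aggregate is the question the reviewers raise here; the sketch does not settle it.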