Skip to content

Commit

Permalink
Unified implementation of filter in regular aggregates and distinct a…
Browse files Browse the repository at this point in the history
…ggregates.
  • Loading branch information
beliefer committed Jul 13, 2020
1 parent 762e839 commit 12e6fbc
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 108 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -102,19 +102,19 @@ import org.apache.spark.sql.types.IntegerType
* {{{
* Aggregate(
* key = ['key]
* functions = [count(if (('gid = 1)) 'cat1 else null),
* count(if (('gid = 2)) 'cat2 else null),
* functions = [count(if (('gid = 1)) '_gen_attr_1 else null),
* count(if (('gid = 2)) '_gen_attr_2 else null),
* first(if (('gid = 0)) 'total else null) ignore nulls]
* output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
* Aggregate(
* key = ['key, 'cat1, 'cat2, 'gid]
* functions = [sum('value) with FILTER('id > 1)]
* output = ['key, 'cat1, 'cat2, 'gid, 'total])
* key = ['key, '_gen_attr_1, '_gen_attr_2, 'gid]
* functions = [sum('_gen_attr_3)]
* output = ['key, '_gen_attr_1, '_gen_attr_2, 'gid, 'total])
* Expand(
* projections = [('key, null, null, 0, cast('value as bigint), 'id),
* projections = [('key, null, null, 0, if ('id > 1) cast('value as bigint) else null, 'id),
* ('key, 'cat1, null, 1, null, null),
* ('key, null, 'cat2, 2, null, null)]
* output = ['key, 'cat1, 'cat2, 'gid, 'value, 'id])
* output = ['key, '_gen_attr_1, '_gen_attr_2, 'gid, '_gen_attr_3, 'id])
* LocalTableScan [...]
* }}}
*
Expand Down Expand Up @@ -144,12 +144,12 @@ import org.apache.spark.sql.types.IntegerType
* {{{
* Aggregate(
* key = ['key]
* functions = [count('_gen_distinct_1),
* sum('value)]
* functions = [count('_gen_attr_1),
* sum('_gen_attr_2)]
* output = ['key, 'cat1_cnt, 'total])
* Project(
* projectionList = ['key, if ('id > 1) 'cat1 else null, cast('value as bigint)]
* output = ['key, '_gen_distinct_1, 'value])
* output = ['key, '_gen_attr_1, '_gen_attr_2])
* LocalTableScan [...]
* }}}
*
Expand Down Expand Up @@ -180,45 +180,45 @@ import org.apache.spark.sql.types.IntegerType
* {{{
* Aggregate(
* key = ['key]
* functions = [count(if (('gid = 1)) '_gen_distinct_1 else null),
* count(if (('gid = 2)) '_gen_distinct_2 else null),
* functions = [count(if (('gid = 1)) '_gen_attr_1 else null),
* count(if (('gid = 2)) '_gen_attr_2 else null),
* first(if (('gid = 0)) 'total else null) ignore nulls]
* output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
* Aggregate(
* key = ['key, '_gen_distinct_1, '_gen_distinct_2, 'gid]
* functions = [sum('value)]
* output = ['key, '_gen_distinct_1, '_gen_distinct_2, 'gid, 'total])
* key = ['key, '_gen_attr_1, '_gen_attr_2, 'gid]
* functions = [sum('_gen_attr_3)]
* output = ['key, '_gen_attr_1, '_gen_attr_2, 'gid, 'total])
* Expand(
* projections = [('key, null, null, 0, cast('value as bigint)),
* ('key, if ('id > 1) 'cat1 else null, null, 1, null),
* ('key, null, 'cat2, 2, null)]
* output = ['key, '_gen_distinct_1, '_gen_distinct_2, 'gid, 'value])
* output = ['key, '_gen_attr_1, '_gen_attr_2, 'gid, '_gen_attr_3])
* LocalTableScan [...]
* }}}
*
* The rule consists of the two phases as follows:
*
* In the first phase, if the aggregate query with distinct aggregations and
* filter clauses, project the output of the child of the aggregate query:
* In the first phase, if the aggregate query exists filter clauses, project the output of
* the child of the aggregate query:
* 1. Project the data. There are three aggregation groups in this query:
* i. the non-distinct group;
* ii. the distinct 'cat1 group;
* iii. the distinct 'cat2 group with filter clause.
* Because there is at least one distinct group with filter clause (e.g. the distinct 'cat2
* Because there is at least one group with filter clause (e.g. the distinct 'cat2
* group with filter clause), then will project the data.
* 2. Avoid projections that may output the same attributes. There are three aggregation groups
* in this query:
* i. the non-distinct group;
* i. the non-distinct 'cat1 group;
* ii. the distinct 'cat1 group;
* iii. the distinct 'cat1 group with filter clause.
* The attributes referenced by different distinct aggregate expressions are likely to overlap,
* The attributes referenced by different aggregate expressions are likely to overlap,
* and if no additional processing is performed, data loss will occur. If we directly output
* the attributes of the aggregate expression, we may get two attributes 'cat1. To prevent
* this, we generate new attributes (e.g. '_gen_distinct_1) and replace the original ones.
* the attributes of the aggregate expression, we may get three attributes 'cat1. To prevent
* this, we generate new attributes (e.g. '_gen_attr_1) and replace the original ones.
*
* Why we need the first phase? guaranteed to compute filter clauses in the first aggregate
* locally.
* Note: after generate new attributes, the aggregate may have at least two distinct aggregates,
* Note: after generate new attributes, the aggregate may have at least two distinct groups,
* so we need the second phase too.
*
* In the second phase, rewrite a query with two or more distinct groups:
Expand Down Expand Up @@ -262,64 +262,35 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {

def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case a: Aggregate if mayNeedtoRewrite(a.aggregateExpressions) =>
val (aggregate, projected) = projectFiltersInDistinctAggregates(a)
val (aggregate, projected) = projectFiltersInAggregates(a)
rewriteDistinctAggregates(aggregate, projected)
}

private def projectFiltersInDistinctAggregates(a: Aggregate): (Aggregate, Boolean) = {
private def projectFiltersInAggregates(a: Aggregate): (Aggregate, Boolean) = {
val aggExpressions = collectAggregateExprs(a)
val (distinctAggExpressions, regularAggExpressions) = aggExpressions.partition(_.isDistinct)
if (distinctAggExpressions.exists(_.filter.isDefined)) {
// Constructs pairs between old and new expressions for regular aggregates. Because we
// will construct a new `Aggregate` and the children of the distinct aggregates will be
// changed to generated ones, we need to create new references to avoid collisions between
// distinct and regular aggregate children.
val regularAggExprs = regularAggExpressions.filter(_.children.exists(!_.foldable))
val regularFunChildren = regularAggExprs
.flatMap(_.aggregateFunction.children.filter(!_.foldable))
val regularFilterAttrs = regularAggExprs.flatMap(_.filterAttributes)
val regularAggChildren = (regularFunChildren ++ regularFilterAttrs).distinct
val regularAggChildrenMap = regularAggChildren.map {
case ne: NamedExpression => ne -> ne
case other => other -> Alias(other, other.toString)()
}
val namedRegularAggChildren = regularAggChildrenMap.map(_._2)
val regularAggChildAttrLookup = regularAggChildrenMap.map { kv =>
(kv._1, kv._2.toAttribute)
}.toMap
val regularAggPairs = regularAggExprs.map {
case ae @ AggregateExpression(af, _, _, filter, _) =>
val newChildren = af.children.map(c => regularAggChildAttrLookup.getOrElse(c, c))
val raf = af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
val filterOpt = filter.map(_.transform {
case a: Attribute => regularAggChildAttrLookup.getOrElse(a, a)
})
val aggExpr = ae.copy(aggregateFunction = raf, filter = filterOpt)
(ae, aggExpr)
}

// Constructs pairs between old and new expressions for distinct aggregates, too.
val distinctAggExprs = distinctAggExpressions.filter(e => e.children.exists(!_.foldable))
val (projections, distinctAggPairs) = distinctAggExprs.map {
if (aggExpressions.exists(_.filter.isDefined)) {
// Constructs pairs between old and new expressions for aggregates.
val aggExprs = aggExpressions.filter(e => e.children.exists(!_.foldable))
val (projections, aggPairs) = aggExprs.map {
case ae @ AggregateExpression(af, _, _, filter, _) =>
// First, In order to reduce costs, it is better to handle the filter clause locally.
// e.g. COUNT (DISTINCT a) FILTER (WHERE id > 1), evaluate expression
// If(id > 1) 'a else null first, and use the result as output.
// Second, If at least two DISTINCT aggregate expression which may references the
// same attributes. We need to construct the generated attributes so as the output not
// lost. e.g. SUM (DISTINCT a), COUNT (DISTINCT a) FILTER (WHERE id > 1) will output
// attribute '_gen_distinct-1 and attribute '_gen_distinct-2 instead of two 'a.
// attribute '_gen_attr-1 and attribute '_gen_attr-2 instead of two 'a.
// Note: The illusionary mechanism may result in at least two distinct groups, so we
// still need to call `rewrite`.
// still need to call `rewriteDistinctAggregates`.
val unfoldableChildren = af.children.filter(!_.foldable)
// Expand projection
val projectionMap = unfoldableChildren.map {
case e if filter.isDefined =>
val ife = If(filter.get, e, nullify(e))
e -> Alias(ife, s"_gen_distinct_${NamedExpression.newExprId.id}")()
e -> Alias(ife, s"_gen_attr_${NamedExpression.newExprId.id}")()
// For convenience and unification, we always alias the distinct column, even if
// there is no filter.
case e => e -> Alias(e, s"_gen_distinct_${NamedExpression.newExprId.id}")()
case e => e -> Alias(e, s"_gen_attr_${NamedExpression.newExprId.id}")()
}
val projection = projectionMap.map(_._2)
val exprAttrs = projectionMap.map { kv =>
Expand All @@ -336,12 +307,11 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
case ne: NamedExpression => ne
case other => Alias(other, other.toString)()
}
val rewriteAggProjection =
namedGroupingExpressions ++ namedRegularAggChildren ++ projections.flatten
val rewriteAggProjection = namedGroupingExpressions ++ projections.flatten
// Construct the project operator.
val project = Project(rewriteAggProjection, a.child)
val groupByAttrs = namedGroupingExpressions.map(_.toAttribute)
val rewriteAggExprLookup = (distinctAggPairs ++ regularAggPairs).toMap
val rewriteAggExprLookup = aggPairs.toMap
val patchedAggExpressions = a.aggregateExpressions.map { e =>
e.transformDown {
case ae: AggregateExpression => rewriteAggExprLookup.getOrElse(ae, ae)
Expand Down Expand Up @@ -425,10 +395,9 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
// only expand unfoldable children
val regularAggExprs = aggExpressions
.filter(e => !e.isDistinct && e.children.exists(!_.foldable))
val regularAggFunChildren = regularAggExprs
val regularAggChildren = regularAggExprs
.flatMap(_.aggregateFunction.children.filter(!_.foldable))
val regularAggFilterAttrs = regularAggExprs.flatMap(_.filterAttributes)
val regularAggChildren = (regularAggFunChildren ++ regularAggFilterAttrs).distinct
.distinct
val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair)

// Setup aggregates for 'regular' aggregate expressions.
Expand All @@ -437,12 +406,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
val regularAggOperatorMap = regularAggExprs.map { e =>
// Perform the actual aggregation in the initial aggregate.
val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup.get)
// We changed the attributes in the [[Expand]] output using expressionAttributePair.
// So we need to replace the attributes in FILTER expression with new ones.
val filterOpt = e.filter.map(_.transform {
case a: Attribute => regularAggChildAttrLookup.getOrElse(a, a)
})
val operator = Alias(e.copy(aggregateFunction = af, filter = filterOpt), e.sql)()
val operator = Alias(e.copy(aggregateFunction = af), e.sql)()

// Select the result of the first aggregate in the last aggregate.
val result = AggregateExpression(
Expand Down Expand Up @@ -484,7 +448,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
}

val (projections, expandChild) = if (projected) {
// If `projectFiltersInDistinctAggregates` inserts Project as child of Aggregate and
// If `projectFiltersInAggregates` inserts Project as child of Aggregate and
// `rewriteDistinctAggregates` will insert Expand here, merge Project with the Expand.
val projectAttributeExpressionMap = a.child.asInstanceOf[Project].projectList.map {
case ne: NamedExpression => ne.name -> ne
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,44 +157,19 @@ abstract class AggregationIterator(
inputAttributes: Seq[Attribute]): (InternalRow, InternalRow) => Unit = {
val joinedRow = new JoinedRow
if (expressions.nonEmpty) {
val mergeExpressions =
functions.zip(expressions.map(ae => (ae.mode, ae.isDistinct, ae.filter))).flatMap {
case (ae: DeclarativeAggregate, (mode, isDistinct, filter)) =>
mode match {
case Partial | Complete =>
if (filter.isDefined) {
ae.updateExpressions.zip(ae.aggBufferAttributes).map {
case (updateExpr, attr) => If(filter.get, updateExpr, attr)
}
} else {
ae.updateExpressions
}
case PartialMerge | Final => ae.mergeExpressions
}
case (agg: AggregateFunction, _) => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
}
// Initialize predicates for aggregate functions if necessary
val predicateOptions = expressions.map {
case AggregateExpression(_, mode, _, Some(filter), _) =>
mode match {
case Partial | Complete =>
val predicate = Predicate.create(filter, inputAttributes)
predicate.initialize(partIndex)
Some(predicate)
case _ => None
val mergeExpressions = functions.zip(expressions).flatMap {
case (ae: DeclarativeAggregate, expression) =>
expression.mode match {
case Partial | Complete => ae.updateExpressions
case PartialMerge | Final => ae.mergeExpressions
}
case _ => None
case (agg: AggregateFunction, _) => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
}
val updateFunctions = functions.zipWithIndex.collect {
case (ae: ImperativeAggregate, i) =>
expressions(i).mode match {
case Partial | Complete =>
if (predicateOptions(i).isDefined) {
(buffer: InternalRow, row: InternalRow) =>
if (predicateOptions(i).get.eval(row)) { ae.update(buffer, row) }
} else {
(buffer: InternalRow, row: InternalRow) => ae.update(buffer, row)
}
(buffer: InternalRow, row: InternalRow) => ae.update(buffer, row)
case PartialMerge | Final =>
(buffer: InternalRow, row: InternalRow) => ae.merge(buffer, row)
}
Expand Down

0 comments on commit 12e6fbc

Please sign in to comment.