Skip to content

Commit

Permalink
Optimize code
Browse files Browse the repository at this point in the history
  • Loading branch information
beliefer committed Feb 2, 2020
1 parent 6a32d83 commit 5c38bbe
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 45 deletions.
Expand Up @@ -170,6 +170,8 @@ package object dsl {
def count(e: Expression): Expression = Count(e).toAggregateExpression()
def countDistinct(e: Expression*): Expression =
Count(e).toAggregateExpression(isDistinct = true)
def countDistinct(filter: Option[Expression], e: Expression*): Expression =
Count(e).toAggregateExpression(isDistinct = true, filter = filter)
def approxCountDistinct(e: Expression, rsd: Double = 0.05): Expression =
HyperLogLogPlusPlus(e, rsd).toAggregateExpression()
def avg(e: Expression): Expression = Average(e).toAggregateExpression()
Expand Down
Expand Up @@ -216,15 +216,21 @@ abstract class AggregateFunction extends Expression {
def toAggregateExpression(): AggregateExpression = toAggregateExpression(isDistinct = false)

/**
* Wraps this [[AggregateFunction]] in an [[AggregateExpression]] and sets `isDistinct`
* flag of the [[AggregateExpression]] to the given value because
* Wraps this [[AggregateFunction]] in an [[AggregateExpression]] with `isDistinct`
* flag and `filter` option of the [[AggregateExpression]] to the given value because
* [[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode,
* and the flag indicating if this aggregation is distinct aggregation or not.
* the flag indicating if this aggregation is distinct aggregation or not and filter option.
* An [[AggregateFunction]] should not be used without being wrapped in
* an [[AggregateExpression]].
*/
def toAggregateExpression(isDistinct: Boolean): AggregateExpression = {
AggregateExpression(aggregateFunction = this, mode = Complete, isDistinct = isDistinct)
def toAggregateExpression(
isDistinct: Boolean,
filter: Option[Expression] = None): AggregateExpression = {
AggregateExpression(
aggregateFunction = this,
mode = Complete,
isDistinct = isDistinct,
filter = filter)
}

def sql(isDistinct: Boolean): String = {
Expand Down
Expand Up @@ -121,8 +121,8 @@ import org.apache.spark.sql.types.IntegerType
* Third example: single distinct aggregate function with filter clauses (in sql):
* {{{
* SELECT
* COUNT(DISTINCT cat1) FILTER (WHERE id > 1) as cat1_cnt1,
* COUNT(DISTINCT cat1) as cat1_cnt2,
* COUNT(DISTINCT cat1) FILTER (WHERE id > 1) as cat1_cnt,
* COUNT(DISTINCT cat2) as cat2_cnt,
* SUM(value) AS total
* FROM
* data
Expand All @@ -135,9 +135,9 @@ import org.apache.spark.sql.types.IntegerType
* Aggregate(
* key = ['key]
* functions = [COUNT(DISTINCT 'cat1) with FILTER('id > 1),
* COUNT(DISTINCT 'cat1),
* COUNT(DISTINCT 'cat2),
* sum('value)]
* output = ['key, 'cat1_cnt1, 'cat1_cnt2, 'total])
* output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
* LocalTableScan [...]
* }}}
*
Expand All @@ -148,36 +148,33 @@ import org.apache.spark.sql.types.IntegerType
* functions = [count(if (('gid = 1)) '_gen_distinct_1 else null),
* count(if (('gid = 2)) '_gen_distinct_2 else null),
* first(if (('gid = 0)) 'total else null) ignore nulls]
* output = ['key, 'cat1_cnt1, 'cat1_cnt2, 'total])
* output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
* Aggregate(
* key = ['key, '_gen_distinct_1, '_gen_distinct_2, 'gid]
* functions = [sum('value)]
* output = ['key, '_gen_distinct_1, '_gen_distinct_2, 'gid, 'total])
* Expand(
* projections = [('key, null, null, 0, 'value),
* ('key, '_gen_distinct_1, null, 1, null),
* ('key, null, '_gen_distinct_2, 2, null)]
* ('key, '_gen_distinct_1, null, 1, null),
* ('key, null, '_gen_distinct_2, 2, null)]
* output = ['key, '_gen_distinct_1, '_gen_distinct_2, 'gid, 'value])
* Expand(
* projections = [('key, if ('id > 1) 'cat1 else null, 'cat1, cast('value as bigint))]
* projections = [('key, if ('id > 1) 'cat1 else null, 'cat2, cast('value as bigint))]
* output = ['key, '_gen_distinct_1, '_gen_distinct_2, 'value])
* LocalTableScan [...]
* }}}
*
* The rule serves two purposes:
* 1. Expand distinct aggregates which exists filter clause.
* 2. Rewrite when aggregate exists at least two distinct aggregates.
* The rule consists of the two phases as follows:
*
* The first child rule does the following things here:
* 1. Guaranteed to compute filter clause locally.
* In the first phase, expands data for the distinct aggregates where filter clauses exist:
* 1. Guaranteed to compute filter clauses in the first aggregate locally.
* 2. The attributes referenced by different distinct aggregate expressions are likely to overlap,
* and if no additional processing is performed, data loss will occur. To prevent this, we
* generate new attributes and replace the original ones.
* 3. If we apply the first rule to distinct aggregate expressions which exists filter
* clause, the aggregate after expand may have at least two distinct aggregates, so we need to
* apply the second rule too.
* 3. After generate new attributes, the aggregate may have at least two distinct aggregates,
* so we need the second phase too.
*
* The second child rule does the following things here:
* In the second phase, rewrite a query with two or more distinct groups:
* 1. Expand the data. There are three aggregation groups in this query:
* i. the non-distinct group;
* ii. the distinct 'cat1 group;
Expand Down Expand Up @@ -207,38 +204,34 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
val distinctAggs = exprs.flatMap { _.collect {
case ae: AggregateExpression if ae.isDistinct => ae
}}
// This rule serves two purposes:
// One is to rewrite when there exists at least two distinct aggregates. We need at least
// two distinct aggregates for this rule because aggregation strategy can handle a single
// distinct group.
// Another is to expand distinct aggregates which exists filter clause so that we can
// evaluate the filter locally.
// We need at least two distinct aggregates or a single distinct aggregate with a filter for
// this rule because aggregation strategy can handle a single distinct group without a filter.
// This check can produce false-positives, e.g., SUM(DISTINCT a) & COUNT(DISTINCT a).
distinctAggs.size >= 1 || distinctAggs.exists(_.filter.isDefined)
distinctAggs.size > 1 || distinctAggs.exists(_.filter.isDefined)
}

def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case a: Aggregate if mayNeedtoRewrite(a.aggregateExpressions) =>
val expandAggregate = extractFiltersInDistinctAggregate(a)
rewriteDistinctAggregate(expandAggregate)
val expandAggregate = extractFiltersInDistinctAggregates(a)
rewriteDistinctAggregates(expandAggregate)
}

private def extractFiltersInDistinctAggregate(a: Aggregate): Aggregate = {
private def extractFiltersInDistinctAggregates(a: Aggregate): Aggregate = {
val aggExpressions = collectAggregateExprs(a)
val (distinctAggExpressions, regularAggExpressions) = aggExpressions.partition(_.isDistinct)
if (distinctAggExpressions.exists(_.filter.isDefined)) {
// Setup expand for the 'regular' aggregate expressions. Because we will construct a new
// aggregate, the children of the distinct aggregates will be changed to the generate
// ones, so we need creates new references to avoid collisions between distinct and
// regular aggregate children.
// Constructs pairs between old and new expressions for regular aggregates. Because we
// will construct a new `Aggregate` and the children of the distinct aggregates will be
// changed to generated ones, we need to create new references to avoid collisions between
// distinct and regular aggregate children.
val regularAggExprs = regularAggExpressions.filter(_.children.exists(!_.foldable))
val regularFunChildren = regularAggExprs
.flatMap(_.aggregateFunction.children.filter(!_.foldable))
val regularFilterAttrs = regularAggExprs.flatMap(_.filterAttributes)
val regularAggChildren = (regularFunChildren ++ regularFilterAttrs).distinct
val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair)
val regularAggChildAttrLookup = regularAggChildAttrMap.toMap
val regularAggMap = regularAggExprs.map {
val regularAggPairs = regularAggExprs.map {
case ae @ AggregateExpression(af, _, _, filter, _) =>
val newChildren = af.children.map(c => regularAggChildAttrLookup.getOrElse(c, c))
val raf = af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
Expand All @@ -249,9 +242,9 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
(ae, aggExpr)
}

// Setup expand for the 'distinct' aggregate expressions.
// Constructs pairs between old and new expressions for distinct aggregates, too.
val distinctAggExprs = distinctAggExpressions.filter(e => e.children.exists(!_.foldable))
val (projections, expressionAttrs, aggExprPairs) = distinctAggExprs.map {
val (projections, expressionAttrs, distinctAggPairs) = distinctAggExprs.map {
case ae @ AggregateExpression(af, _, _, filter, _) =>
// Why do we need to construct the `exprId` ?
// First, In order to reduce costs, it is better to handle the filter clause locally.
Expand All @@ -261,9 +254,8 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
// same attributes. We need to construct the generated attributes so as the output not
// lost. e.g. SUM (DISTINCT a), COUNT (DISTINCT a) FILTER (WHERE id > 1) will output
// attribute '_gen_distinct-1 and attribute '_gen_distinct-2 instead of two 'a.
// Note: We just need to illusion the expression with filter clause.
// The illusionary mechanism may result in multiple distinct aggregations uses
// different column, so we still need to call `rewrite`.
// Note: The illusionary mechanism may result in at least two distinct groups, so we
// still need to call `rewrite`.
val exprId = NamedExpression.newExprId.id
val unfoldableChildren = af.children.filter(!_.foldable)
val exprAttrs = unfoldableChildren.map { e =>
Expand Down Expand Up @@ -292,7 +284,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
val groupByAttrs = groupByMap.map(_._2)
// Construct the expand operator.
val expand = Expand(rewriteAggProjections, groupByAttrs ++ allAggAttrs, a.child)
val rewriteAggExprLookup = (aggExprPairs ++ regularAggMap).toMap
val rewriteAggExprLookup = (distinctAggPairs ++ regularAggPairs).toMap
val patchedAggExpressions = a.aggregateExpressions.map { e =>
e.transformDown {
case ae: AggregateExpression => rewriteAggExprLookup.getOrElse(ae, ae)
Expand All @@ -305,7 +297,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
}
}

private def rewriteDistinctAggregate(a: Aggregate): Aggregate = {
private def rewriteDistinctAggregates(a: Aggregate): Aggregate = {
val aggExpressions = collectAggregateExprs(a)

// Extract distinct aggregate expressions.
Expand Down
Expand Up @@ -20,7 +20,7 @@ import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
Expand All @@ -42,6 +42,16 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
case _ => fail(s"Plan is not rewritten:\n$rewrite")
}

private def checkGenerate(generate: LogicalPlan): Unit = generate match {
case Aggregate(_, _, _: Expand) =>
case _ => fail(s"Plan is not generated:\n$generate")
}

private def checkGenerateAndRewrite(rewrite: LogicalPlan): Unit = rewrite match {
case Aggregate(_, _, Aggregate(_, _, Expand(_, _, _: Expand))) =>
case _ => fail(s"Plan is not rewritten:\n$rewrite")
}

test("single distinct group") {
val input = testRelation
.groupBy('a)(countDistinct('e))
Expand All @@ -50,6 +60,13 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
comparePlans(input, rewrite)
}

test("single distinct group with filter") {
val input = testRelation
.groupBy('a)(countDistinct(Some(EqualTo('d, Literal(""))), 'e))
.analyze
checkGenerate(RewriteDistinctAggregates(input))
}

test("single distinct group with partial aggregates") {
val input = testRelation
.groupBy('a, 'd)(
Expand All @@ -67,6 +84,13 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
checkRewrite(RewriteDistinctAggregates(input))
}

test("multiple distinct groups with filter") {
val input = testRelation
.groupBy('a)(countDistinct(Some(EqualTo('d, Literal(""))), 'b, 'c), countDistinct('d))
.analyze
checkGenerateAndRewrite(RewriteDistinctAggregates(input))
}

test("multiple distinct groups with partial aggregates") {
val input = testRelation
.groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))
Expand Down

0 comments on commit 5c38bbe

Please sign in to comment.