Skip to content

Commit

Permalink
Only keep necessary attribute output.
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Mar 23, 2015
1 parent bf044de commit 8e16206
Showing 1 changed file with 14 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -130,27 +130,34 @@ class Analyzer(catalog: Catalog,
* expressions which equal GroupBy expressions with Literal(null), if those expressions
* are not set for this grouping set (according to the bit mask).
*/
private[this] def expand(g: GroupingSets): Seq[GroupExpression] = {
private[this] def expand(g: GroupingSets): (Seq[GroupExpression], Seq[Attribute]) = {
val result = new scala.collection.mutable.ArrayBuffer[GroupExpression]

val allExprs = g.aggregations ++ g.groupByExprs

g.bitmasks.foreach { bitmask =>
// get the non selected grouping attributes according to the bit mask
val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, g.groupByExprs)

val substitution = (g.child.output :+ g.gid).map(expr => expr transformDown {
val substitution = (g.child.output :+ g.gid).collect {
case x: Expression if nonSelectedGroupExprSet.contains(x) =>
// if the input attribute in the Invalid Grouping Expression set of for this group
// replace it with constant null
Literal(null, expr.dataType)
Literal(null, x.dataType)
case x: Expression if allExprs.exists(_.references.contains(x)) => x
case x if x == g.gid =>
// replace the groupingId with concrete value (the bit mask)
Literal(bitmask, IntegerType)
})
}

result += GroupExpression(substitution)
}

result.toSeq
val output = g.child.output.collect {
case x: Expression if allExprs.exists(_.references.contains(x)) => x
}

(result.toSeq, output)
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
Expand All @@ -159,10 +166,11 @@ class Analyzer(catalog: Catalog,
case a: Rollup if a.resolved =>
GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations, a.gid)
case x: GroupingSets if x.resolved =>
val expanded = expand(x)
Aggregate(
x.groupByExprs :+ x.gid,
x.aggregations,
Expand(expand(x), x.child.output :+ x.gid, x.child))
Expand(expanded._1, expanded._2 :+ x.gid, x.child))
}
}

Expand Down

0 comments on commit 8e16206

Please sign in to comment.