From 6cf6f444697c0bc32dbc09906fe144563a7d66df Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 9 Mar 2016 23:10:08 -0800 Subject: [PATCH] eliminate useless Aggregate and Window --- .../sql/catalyst/optimizer/Optimizer.scala | 26 +++++++++++++++---- .../optimizer/ColumnPruningSuite.scala | 15 ++++------- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 5b9112e9b17e4..4d1443596bcbe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -315,10 +315,6 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { * - LeftSemiJoin */ object ColumnPruning extends Rule[LogicalPlan] { - def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean = - output1.size == output2.size && - output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2)) - def apply(plan: LogicalPlan): LogicalPlan = plan transform { // Prunes the unused columns from project list of Project/Aggregate/Expand case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty => @@ -378,13 +374,21 @@ object ColumnPruning extends Rule[LogicalPlan] { // Eliminate no-op Projects case p @ Project(projectList, child) if sameOutput(child.output, p.output) => child + // Eliminate no-op Window + case w: Window if sameOutput(w.child.output, w.output) => w.child + + // Convert Aggregate to Project if no aggregate function exists + case a: Aggregate if !containsAggregates(a.expressions) => + Project(a.aggregateExpressions, a.child) + // Can't prune the columns on LeafNode case p @ Project(_, l: LeafNode) => p // Prune windowExpressions and child of Window case p @ Project(_, w: Window) if (w.outputSet -- p.references).nonEmpty => val newWindowExprs = w.windowExpressions.filter(p.references.contains) - val newGrandChild = prunedChild(w.child, w.references ++ p.references) + val newGrandChild = + prunedChild(w.child, p.references ++ AttributeSet(newWindowExprs.flatMap(_.references))) p.copy(child = w.copy( windowExpressions = newWindowExprs, child = newGrandChild)) @@ -407,6 +411,18 @@ object ColumnPruning extends Rule[LogicalPlan] { } else { c } + + private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean = + output1.size == output2.size && + output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2)) + + private def isAggregateExpression(e: Expression): Boolean = { + e.isInstanceOf[AggregateExpression] || e.isInstanceOf[Grouping] || e.isInstanceOf[GroupingID] + } + + private def containsAggregates(exprs: Seq[Expression]): Boolean = { + exprs.exists(_.find(isAggregateExpression).nonEmpty) + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 898623cc4ca22..37b76df9fc10f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -184,8 +184,7 @@ class ColumnPruningSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select('a) - .groupBy('a)('a).analyze + .select('a).analyze comparePlans(optimized, correctAnswer) } @@ -202,7 +201,7 @@ class ColumnPruningSuite extends PlanTest { val correctAnswer = testRelation .select('a) - .groupBy('a)('a as 'c).analyze + .select('a as 'c).analyze comparePlans(optimized, correctAnswer) } @@ -263,7 +262,7 @@ class ColumnPruningSuite extends PlanTest { val input = LocalRelation('a.int, 'b.string, 'c.double, 'd.int) val originalQuery = - input.groupBy('a)('a, 'b, 'c, 'd, + input.groupBy('a, 'c, 'd)('a, 'c, 'd, WindowExpression( AggregateExpression(Count('b), Complete, isDistinct = false), WindowSpecDefinition( 'a :: Nil, @@ -271,9 +270,7 @@ class ColumnPruningSuite extends PlanTest { UnspecifiedFrame)).as('window)).select('a, 'c) val correctAnswer = - input.select('a, 'b, 'c).groupBy('a)('a, 'b, 'c) - .window(Seq.empty[NamedExpression], 'a :: Nil, 'b.asc :: Nil) - .select('a, 'c).analyze + input.select('a, 'c).analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -319,9 +316,7 @@ class ColumnPruningSuite extends PlanTest { UnspecifiedFrame)).as('window)).select('a, 'c) val correctAnswer = - input.select('a, 'b, 'c) - .window(Seq.empty[NamedExpression], 'a :: Nil, 'b.asc :: Nil) - .select('a, 'c).analyze + input.select('a, 'c).analyze val optimized = Optimize.execute(originalQuery.analyze)