diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 4d1443596bcbe..0e82aa0b39ee2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -315,6 +315,10 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { * - LeftSemiJoin */ object ColumnPruning extends Rule[LogicalPlan] { + private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean = + output1.size == output2.size && + output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2)) + def apply(plan: LogicalPlan): LogicalPlan = plan transform { // Prunes the unused columns from project list of Project/Aggregate/Expand case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty => @@ -375,11 +379,7 @@ object ColumnPruning extends Rule[LogicalPlan] { case p @ Project(projectList, child) if sameOutput(child.output, p.output) => child // Eliminate no-op Window - case w: Window if sameOutput(w.child.output, w.output) => w.child - - // Convert Aggregate to Project if no aggregate function exists - case a: Aggregate if !containsAggregates(a.expressions) => - Project(a.aggregateExpressions, a.child) + case w: Window if w.windowExpressions.isEmpty => w.child // Can't prune the columns on LeafNode case p @ Project(_, l: LeafNode) => p @@ -411,18 +411,6 @@ object ColumnPruning extends Rule[LogicalPlan] { } else { c } - - private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean = - output1.size == output2.size && - output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2)) - - private def isAggregateExpression(e: Expression): Boolean = { - e.isInstanceOf[AggregateExpression] || e.isInstanceOf[Grouping] || e.isInstanceOf[GroupingID] - } - - private def containsAggregates(exprs: Seq[Expression]): Boolean = { - exprs.exists(_.find(isAggregateExpression).nonEmpty) - } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 37b76df9fc10f..2020c48effbc4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -184,7 +184,8 @@ class ColumnPruningSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select('a).analyze + .select('a) + .groupBy('a)('a).analyze comparePlans(optimized, correctAnswer) } @@ -201,7 +202,7 @@ class ColumnPruningSuite extends PlanTest { val correctAnswer = testRelation .select('a) - .select('a as 'c).analyze + .groupBy('a)('a as 'c).analyze comparePlans(optimized, correctAnswer) } @@ -270,7 +271,7 @@ class ColumnPruningSuite extends PlanTest { UnspecifiedFrame)).as('window)).select('a, 'c) val correctAnswer = - input.select('a, 'c).analyze + input.select('a, 'c, 'd).groupBy('a, 'c, 'd)('a, 'c).analyze val optimized = Optimize.execute(originalQuery.analyze)