From b38a21ef6146784e4b93ef4ce8c899f1eee14572 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 16 Nov 2015 18:30:26 -0800 Subject: [PATCH 01/15] SPARK-11633 --- .../spark/sql/catalyst/analysis/Analyzer.scala | 3 ++- .../spark/sql/hive/execution/SQLQuerySuite.scala | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2f4670b55bdba..5a5b71e52dd79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -425,7 +425,8 @@ class Analyzer( */ j case Some((oldRelation, newRelation)) => - val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) + val attributeRewrites = + AttributeMap(oldRelation.output.zip(newRelation.output).filter(x => x._1 != x._2)) val newRight = right transformUp { case r if r == oldRelation => newRelation } transformUp { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 3427152b2da02..5e00546a74c00 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -51,6 +51,8 @@ case class Order( state: String, month: Int) +case class Individual(F1: Integer, F2: Integer) + case class WindowData( month: Int, area: String, @@ -1479,4 +1481,18 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { |FROM (SELECT '{"f1": "value1", "f2": 12}' json, 'hello' as str) test """.stripMargin), Row("value1", "12", 3.14, "hello")) } + + test ("SPARK-11633: HiveContext throws TreeNode Exception : Failed to Copy Node") { + val rdd1 = sparkContext.parallelize(Seq( Individual(1,3), Individual(2,1))) + val df = hiveContext.createDataFrame(rdd1) + df.registerTempTable("foo") + val df2 = sql("select f1, F2 as F2 from foo") + df2.registerTempTable("foo2") + df2.registerTempTable("foo3") + + checkAnswer(sql( + """ + SELECT a.F1 FROM foo2 a INNER JOIN foo3 b ON a.F2=b.F2 + """.stripMargin), Row(2) :: Row(1) :: Nil) + } } From 0546772f151f83d6d3cf4d000cbe341f52545007 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 20 Nov 2015 10:56:45 -0800 Subject: [PATCH 02/15] converge --- .../spark/sql/catalyst/analysis/Analyzer.scala | 3 +-- .../spark/sql/hive/execution/SQLQuerySuite.scala | 15 --------------- 2 files changed, 1 insertion(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7c9512fbd00aa..47962ebe6ef82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -417,8 +417,7 @@ class Analyzer( */ j case Some((oldRelation, newRelation)) => - val attributeRewrites = - AttributeMap(oldRelation.output.zip(newRelation.output).filter(x => x._1 != x._2)) + val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) val newRight = right transformUp { case r if r == oldRelation => newRelation } transformUp { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 5e00546a74c00..61d9dcd37572b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -51,8 +51,6 @@ case class Order( state: String, month: Int) -case class Individual(F1: Integer, F2: Integer) - case class WindowData( month: Int, area: String, @@ -1481,18 +1479,5 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { |FROM (SELECT '{"f1": "value1", "f2": 12}' json, 'hello' as str) test """.stripMargin), Row("value1", "12", 3.14, "hello")) } - - test ("SPARK-11633: HiveContext throws TreeNode Exception : Failed to Copy Node") { - val rdd1 = sparkContext.parallelize(Seq( Individual(1,3), Individual(2,1))) - val df = hiveContext.createDataFrame(rdd1) - df.registerTempTable("foo") - val df2 = sql("select f1, F2 as F2 from foo") - df2.registerTempTable("foo2") - df2.registerTempTable("foo3") - - checkAnswer(sql( - """ - SELECT a.F1 FROM foo2 a INNER JOIN foo3 b ON a.F2=b.F2 - """.stripMargin), Row(2) :: Row(1) :: Nil) } } From b37a64f13956b6ddd0e38ddfd9fe1caee611f1a8 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 20 Nov 2015 10:58:37 -0800 Subject: [PATCH 03/15] converge --- .../org/apache/spark/sql/hive/execution/SQLQuerySuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 61d9dcd37572b..3427152b2da02 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1479,5 +1479,4 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { |FROM (SELECT '{"f1": "value1", "f2": 12}' json, 'hello' as str) test """.stripMargin), Row("value1", "12", 3.14, "hello")) } - } } From 25f6ff6bbb1614661cd5cfe050536980bce585b5 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 7 Mar 2016 14:30:06 -0800 Subject: [PATCH 04/15] remove projectList from Window --- .../apache/spark/sql/catalyst/analysis/Analyzer.scala | 11 +++-------- .../spark/sql/catalyst/optimizer/Optimizer.scala | 6 +----- .../sql/catalyst/plans/logical/basicOperators.scala | 3 +-- .../apache/spark/sql/execution/SparkStrategies.scala | 5 ++--- .../scala/org/apache/spark/sql/execution/Window.scala | 6 +++--- 5 files changed, 10 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b5fa372643bd4..8c3bc3533429d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -424,7 +424,7 @@ class Analyzer( val newOutput = oldVersion.generatorOutput.map(_.newInstance()) (oldVersion, oldVersion.copy(generatorOutput = newOutput)) - case oldVersion @ Window(_, windowExpressions, _, _, child) + case oldVersion @ Window(windowExpressions, _, _, child) if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes) .nonEmpty => (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))) @@ -661,10 +661,6 @@ class Analyzer( case p: Project => val missing = missingAttrs -- p.child.outputSet Project(p.projectList ++ missingAttrs, addMissingAttr(p.child, missing)) - case w: Window => - val missing = missingAttrs -- w.child.outputSet - w.copy(projectList = w.projectList ++ missingAttrs, - child = addMissingAttr(w.child, missing)) case a: Aggregate => // all the missing attributes should be grouping expressions // TODO: push down AggregateExpression @@ -1169,7 +1165,6 @@ class Analyzer( // Set currentChild to the newly created Window operator. currentChild = Window( - currentChild.output, windowExpressions, partitionSpec, orderSpec, @@ -1439,10 +1434,10 @@ object CleanupAliases extends Rule[LogicalPlan] { val cleanedAggs = aggs.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) Aggregate(grouping.map(trimAliases), cleanedAggs, child) - case w @ Window(projectList, windowExprs, partitionSpec, orderSpec, child) => + case w @ Window(windowExprs, partitionSpec, orderSpec, child) => val cleanedWindowExprs = windowExprs.map(e => trimNonTopLevelAliases(e).asInstanceOf[NamedExpression]) - Window(projectList, cleanedWindowExprs, partitionSpec.map(trimAliases), + Window(cleanedWindowExprs, partitionSpec.map(trimAliases), orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child) // Operators that operate on objects should only have expressions from encoders, which should diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index deea7238f564c..986f2bf07f5f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -326,9 +326,7 @@ object ColumnPruning extends Rule[LogicalPlan] { p.copy( child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains))) case p @ Project(_, w: Window) if (w.outputSet -- p.references).nonEmpty => - p.copy(child = w.copy( - projectList = w.projectList.filter(p.references.contains), - windowExpressions = w.windowExpressions.filter(p.references.contains))) + p.copy(child = w.copy(windowExpressions = w.windowExpressions.filter(p.references.contains))) case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty => val newOutput = e.output.filter(a.references.contains(_)) val newProjects = e.projections.map { proj => @@ -345,8 +343,6 @@ object ColumnPruning extends Rule[LogicalPlan] { // Prunes the unused columns from child of Aggregate/Window/Expand/Generate case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => a.copy(child = prunedChild(child, a.references)) - case w @ Window(_, _, _, _, child) if (child.outputSet -- w.references).nonEmpty => - w.copy(child = prunedChild(child, w.references)) case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty => e.copy(child = prunedChild(child, e.references)) case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 411594c95166c..c8df131d8f0b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -434,14 +434,13 @@ case class Aggregate( } case class Window( - projectList: Seq[Attribute], windowExpressions: Seq[NamedExpression], partitionSpec: Seq[Expression], orderSpec: Seq[SortOrder], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = - projectList ++ windowExpressions.map(_.toAttribute) + child.output ++ windowExpressions.map(_.toAttribute) } private[sql] object Expand { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index debd04aa95b9c..bae0750788088 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -344,9 +344,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Filter(condition, planLater(child)) :: Nil case e @ logical.Expand(_, _, child) => execution.Expand(e.projections, e.output, planLater(child)) :: Nil - case logical.Window(projectList, windowExprs, partitionSpec, orderSpec, child) => - execution.Window( - projectList, windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil + case logical.Window(windowExprs, partitionSpec, orderSpec, child) => + execution.Window(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil case logical.Sample(lb, ub, withReplacement, seed, child) => execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 84154a47de393..a4c0e1c9fba41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -81,14 +81,14 @@ import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, Unsaf * of specialized classes: [[RowBoundOrdering]] & [[RangeBoundOrdering]]. */ case class Window( - projectList: Seq[Attribute], windowExpression: Seq[NamedExpression], partitionSpec: Seq[Expression], orderSpec: Seq[SortOrder], child: SparkPlan) extends UnaryNode { - override def output: Seq[Attribute] = projectList ++ windowExpression.map(_.toAttribute) + override def output: Seq[Attribute] = + child.output ++ windowExpression.map(_.toAttribute) override def requiredChildDistribution: Seq[Distribution] = { if (partitionSpec.isEmpty) { @@ -275,7 +275,7 @@ case class Window( val unboundToRefMap = expressions.zip(references).toMap val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) UnsafeProjection.create( - projectList ++ patchedWindowExpression, + child.output ++ patchedWindowExpression, child.output) } From 467b095d89ce641f568aade09d710fb9ea573273 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 7 Mar 2016 17:46:42 -0800 Subject: [PATCH 05/15] clean the comment. --- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 986f2bf07f5f0..671ae4434df70 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -340,7 +340,7 @@ object ColumnPruning extends Rule[LogicalPlan] { case mp @ MapPartitions(_, _, _, child) if (child.outputSet -- mp.references).nonEmpty => mp.copy(child = prunedChild(child, mp.references)) - // Prunes the unused columns from child of Aggregate/Window/Expand/Generate + // Prunes the unused columns from child of Aggregate/Expand/Generate case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => a.copy(child = prunedChild(child, a.references)) case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty => From b16923628e5212c122a8ba284b55750bb9784523 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 9 Mar 2016 16:44:25 -0800 Subject: [PATCH 06/15] changed the column pruning rule for Window. --- .../spark/sql/catalyst/dsl/package.scala | 6 ++ .../sql/catalyst/optimizer/Optimizer.scala | 12 +++- .../optimizer/ColumnPruningSuite.scala | 72 ++++++++++++++++++- 3 files changed, 86 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index a12f7396fe819..70c3177683b91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -268,6 +268,12 @@ package object dsl { Aggregate(groupingExprs, aliasedExprs, logicalPlan) } + def window( + windowExpressions: Seq[NamedExpression], + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder]): LogicalPlan = + Window(windowExpressions, partitionSpec, orderSpec, logicalPlan) + def subquery(alias: Symbol): LogicalPlan = SubqueryAlias(alias.name, logicalPlan) def except(otherPlan: LogicalPlan): LogicalPlan = Except(logicalPlan, otherPlan) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 671ae4434df70..7e5cdd48b05cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -319,14 +319,12 @@ object ColumnPruning extends Rule[LogicalPlan] { output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2)) def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // Prunes the unused columns from project list of Project/Aggregate/Window/Expand + // Prunes the unused columns from project list of Project/Aggregate/Expand case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty => p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains))) case p @ Project(_, a: Aggregate) if (a.outputSet -- p.references).nonEmpty => p.copy( child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains))) - case p @ Project(_, w: Window) if (w.outputSet -- p.references).nonEmpty => - p.copy(child = w.copy(windowExpressions = w.windowExpressions.filter(p.references.contains))) case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty => val newOutput = e.output.filter(a.references.contains(_)) val newProjects = e.projections.map { proj => @@ -382,6 +380,14 @@ object ColumnPruning extends Rule[LogicalPlan] { // Eliminate no-op Projects case p @ Project(projectList, child) if sameOutput(child.output, p.output) => child + // Prune windowExpressions and child of Window + case p @ Project(_, w: Window) + if (w.outputSet -- p.references).nonEmpty || + (w.child.inputSet -- (w.references ++ p.references)).nonEmpty => + val windowExpressions = w.windowExpressions.filter(p.references.contains) + val newChild = prunedChild(w.child, allReferences = w.references ++ p.references) + p.copy(child = w.copy(windowExpressions = windowExpressions, child = newChild)) + // for all other logical plans that inherits the output from it's children case p @ Project(_, child) => val required = child.references ++ p.references diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index d09601e0343d7..7e5dc47a9776c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -23,7 +23,8 @@ import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Ascending, Explode, Literal, SortOrder} +import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Count, AggregateExpression} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -242,6 +243,75 @@ class ColumnPruningSuite extends PlanTest { comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1)) } + test("Column pruning on Window with useless aggregate functions") { + val input = LocalRelation('a.int, 'b.string, 'c.double, 'd.int) + + val originalQuery = + input.groupBy('a)('a, 'b, 'c, 'd, + WindowExpression( + AggregateExpression(Count('b), Complete, isDistinct = false), + WindowSpecDefinition( 'a :: Nil, + SortOrder('b, Ascending) :: Nil, + UnspecifiedFrame)).as('window)).select('a, 'c) + + val correctAnswer = + input.select('a, 'b, 'c).groupBy('a)('a, 'b, 'c) + .window(Seq.empty[NamedExpression], 'a :: Nil, 'b.asc :: Nil) + .select('a, 'c).analyze + + val optimized = Optimize.execute(originalQuery.analyze) + + comparePlans(optimized, correctAnswer) + } + + test("Column pruning on Window with selected agg expressions") { + val input = LocalRelation('a.int, 'b.string, 'c.double, 'd.int) + + val originalQuery = + input.select('a, 'b, 'c, 'd, + WindowExpression( + AggregateExpression(Count('b), Complete, isDistinct = false), + WindowSpecDefinition( 'a :: Nil, + SortOrder('b, Ascending) :: Nil, + UnspecifiedFrame)).as('window)).where('window > 1).select('a, 'c) + + val correctAnswer = + input.select('a, 'b, 'c) + .window(WindowExpression( + AggregateExpression(Count('b), Complete, isDistinct = false), + WindowSpecDefinition( 'a :: Nil, + SortOrder('b, Ascending) :: Nil, + UnspecifiedFrame)).as('window) :: Nil, + 'a :: Nil, 'b.asc :: Nil) + .select('a, 'c, 'window).select('a, 'c, 'window, 'window) + .select('a, 'c, 'window).where('window > 1).select('a, 'c).analyze + + val optimized = Optimize.execute(originalQuery.analyze) + + comparePlans(optimized, correctAnswer) + } + + test("Column pruning on Window in select") { + val input = LocalRelation('a.int, 'b.string, 'c.double, 'd.int) + + val originalQuery = + input.select('a, 'b, 'c, 'd, + WindowExpression( + AggregateExpression(Count('b), Complete, isDistinct = false), + WindowSpecDefinition( 'a :: Nil, + SortOrder('b, Ascending) :: Nil, + UnspecifiedFrame)).as('window)).select('a, 'c) + + val correctAnswer = + input.select('a, 'b, 'c) + .window(Seq.empty[NamedExpression], 'a :: Nil, 'b.asc :: Nil) + .select('a, 'c).analyze + + val optimized = Optimize.execute(originalQuery.analyze) + + comparePlans(optimized, correctAnswer) + } + test("Column pruning on Union") { val input1 = LocalRelation('a.int, 'b.string, 'c.double) val input2 = LocalRelation('c.int, 'd.string, 'e.double) From b229ea2cfb1dfacb78e1a19fbf3897fdb890c977 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 9 Mar 2016 16:52:46 -0800 Subject: [PATCH 07/15] style fix. --- .../spark/sql/catalyst/optimizer/ColumnPruningSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 7e5dc47a9776c..21693762e7992 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -23,8 +23,8 @@ import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Count, AggregateExpression} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count} import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor From f0fbe78f2449f23d818ad2119049080b21c9b1b6 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 9 Mar 2016 18:17:08 -0800 Subject: [PATCH 08/15] address comments. --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 4e479c5ea1b46..ec68fddfb1e96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -320,12 +320,16 @@ object ColumnPruning extends Rule[LogicalPlan] { output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2)) def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // Prunes the unused columns from project list of Project/Aggregate/Expand + // Prunes the unused columns from project list of Project/Aggregate/Window/Expand case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty => p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains))) case p @ Project(_, a: Aggregate) if (a.outputSet -- p.references).nonEmpty => p.copy( child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains))) + case p @ Project(_, w: Window) + if (AttributeSet(w.windowExpressions.map(_.toAttribute)) -- p.references).nonEmpty => + p.copy(child = w.copy( + windowExpressions = w.windowExpressions.filter(p.references.contains))) case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty => val newOutput = e.output.filter(a.references.contains(_)) val newProjects = e.projections.map { proj => @@ -381,14 +385,6 @@ object ColumnPruning extends Rule[LogicalPlan] { // Can't prune the columns on LeafNode case p @ Project(_, l: LeafNode) => p - // Prune windowExpressions and child of Window - case p @ Project(_, w: Window) - if (w.outputSet -- p.references).nonEmpty || - (w.child.inputSet -- (w.references ++ p.references)).nonEmpty => - val windowExpressions = w.windowExpressions.filter(p.references.contains) - val newChild = prunedChild(w.child, allReferences = w.references ++ p.references) - p.copy(child = w.copy(windowExpressions = windowExpressions, child = newChild)) - // for all other logical plans that inherits the output from it's children case p @ Project(_, child) => val required = child.references ++ p.references From f8fd37f21f428200f35fd7df0de8955b5a984bbc Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 9 Mar 2016 20:52:01 -0800 Subject: [PATCH 09/15] address comments. --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index ec68fddfb1e96..3fd65399bbeab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -320,16 +320,12 @@ object ColumnPruning extends Rule[LogicalPlan] { output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2)) def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // Prunes the unused columns from project list of Project/Aggregate/Window/Expand + // Prunes the unused columns from project list of Project/Aggregate/Expand case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty => p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains))) case p @ Project(_, a: Aggregate) if (a.outputSet -- p.references).nonEmpty => p.copy( child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains))) - case p @ Project(_, w: Window) - if (AttributeSet(w.windowExpressions.map(_.toAttribute)) -- p.references).nonEmpty => - p.copy(child = w.copy( - windowExpressions = w.windowExpressions.filter(p.references.contains))) case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty => val newOutput = e.output.filter(a.references.contains(_)) val newProjects = e.projections.map { proj => @@ -385,6 +381,13 @@ object ColumnPruning extends Rule[LogicalPlan] { // Can't prune the columns on LeafNode case p @ Project(_, l: LeafNode) => p + case p @ Project(_, w: Window) if (w.outputSet -- p.references).nonEmpty => + val newWindowExprs = w.windowExpressions.filter(p.references.contains) + val newGrandChild = prunedChild(w.child, w.references ++ p.references) + p.copy(child = w.copy( + windowExpressions = newWindowExprs, + child = newGrandChild)) + // for all other logical plans that inherits the output from it's children case p @ Project(_, child) => val required = child.references ++ p.references From 4dd3e66bf5cb098ee2ea3d47f1cb8d59fddf39f6 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 9 Mar 2016 20:54:11 -0800 Subject: [PATCH 10/15] added a comment --- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 3fd65399bbeab..5b9112e9b17e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -381,6 +381,7 @@ object ColumnPruning extends Rule[LogicalPlan] { // Can't prune the columns on LeafNode case p @ Project(_, l: LeafNode) => p + // Prune windowExpressions and child of Window case p @ Project(_, w: Window) if (w.outputSet -- p.references).nonEmpty => val newWindowExprs = w.windowExpressions.filter(p.references.contains) val newGrandChild = prunedChild(w.child, w.references ++ p.references) From 6cf6f444697c0bc32dbc09906fe144563a7d66df Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 9 Mar 2016 23:10:08 -0800 Subject: [PATCH 11/15] eliminate useless Aggregate and Window --- .../sql/catalyst/optimizer/Optimizer.scala | 26 +++++++++++++++---- .../optimizer/ColumnPruningSuite.scala | 15 ++++------- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 5b9112e9b17e4..4d1443596bcbe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -315,10 +315,6 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { * - LeftSemiJoin */ object ColumnPruning extends Rule[LogicalPlan] { - def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean = - output1.size == output2.size && - output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2)) - def apply(plan: LogicalPlan): LogicalPlan = plan transform { // Prunes the unused columns from project list of Project/Aggregate/Expand case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty => @@ -378,13 +374,21 @@ object ColumnPruning extends Rule[LogicalPlan] { // Eliminate no-op Projects case p @ Project(projectList, child) if sameOutput(child.output, p.output) => child + // Eliminate no-op Window + case w: Window if sameOutput(w.child.output, w.output) => w.child + + // Convert Aggregate to Project if no aggregate function exists + case a: Aggregate if !containsAggregates(a.expressions) => + Project(a.aggregateExpressions, a.child) + // Can't prune the columns on LeafNode case p @ Project(_, l: LeafNode) => p // Prune windowExpressions and child of Window case p @ Project(_, w: Window) if (w.outputSet -- p.references).nonEmpty => val newWindowExprs = w.windowExpressions.filter(p.references.contains) - val newGrandChild = prunedChild(w.child, w.references ++ p.references) + val newGrandChild = + prunedChild(w.child, p.references ++ AttributeSet(newWindowExprs.flatMap(_.references))) p.copy(child = w.copy( windowExpressions = newWindowExprs, child = newGrandChild)) @@ -407,6 +411,18 @@ object ColumnPruning extends Rule[LogicalPlan] { } else { c } + + private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean = + output1.size == output2.size && + output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2)) + + private def isAggregateExpression(e: Expression): Boolean = { + e.isInstanceOf[AggregateExpression] || e.isInstanceOf[Grouping] || e.isInstanceOf[GroupingID] + } + + private def containsAggregates(exprs: Seq[Expression]): Boolean = { + exprs.exists(_.find(isAggregateExpression).nonEmpty) + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 898623cc4ca22..37b76df9fc10f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -184,8 +184,7 @@ class ColumnPruningSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select('a) - .groupBy('a)('a).analyze + .select('a).analyze comparePlans(optimized, correctAnswer) } @@ -202,7 +201,7 @@ class ColumnPruningSuite extends PlanTest { val correctAnswer = testRelation .select('a) - .groupBy('a)('a as 'c).analyze + .select('a as 'c).analyze comparePlans(optimized, correctAnswer) } @@ -263,7 +262,7 @@ class ColumnPruningSuite extends PlanTest { val input = LocalRelation('a.int, 'b.string, 'c.double, 'd.int) val originalQuery = - input.groupBy('a)('a, 'b, 'c, 'd, + input.groupBy('a, 'c, 'd)('a, 'c, 'd, WindowExpression( AggregateExpression(Count('b), Complete, isDistinct = false), WindowSpecDefinition( 'a :: Nil, @@ -271,9 +270,7 @@ class ColumnPruningSuite extends PlanTest { UnspecifiedFrame)).as('window)).select('a, 'c) val correctAnswer = - input.select('a, 'b, 'c).groupBy('a)('a, 'b, 'c) - .window(Seq.empty[NamedExpression], 'a :: Nil, 'b.asc :: Nil) - .select('a, 'c).analyze + input.select('a, 'c).analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -319,9 +316,7 @@ class ColumnPruningSuite extends PlanTest { UnspecifiedFrame)).as('window)).select('a, 'c) val correctAnswer = - input.select('a, 'b, 'c) - .window(Seq.empty[NamedExpression], 'a :: Nil, 'b.asc :: Nil) - .select('a, 'c).analyze + input.select('a, 'c).analyze val optimized = Optimize.execute(originalQuery.analyze) From 44326f11dab33d686885b7f038a1034c1723338e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 9 Mar 2016 23:40:41 -0800 Subject: [PATCH 12/15] remove aggregate replacement. --- .../sql/catalyst/optimizer/Optimizer.scala | 22 +++++-------------- .../optimizer/ColumnPruningSuite.scala | 7 +++--- 2 files changed, 9 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 4d1443596bcbe..0e82aa0b39ee2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -315,6 +315,10 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { * - LeftSemiJoin */ object ColumnPruning extends Rule[LogicalPlan] { + private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean = + output1.size == output2.size && + output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2)) + def apply(plan: LogicalPlan): LogicalPlan = plan transform { // Prunes the unused columns from project list of Project/Aggregate/Expand case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty => @@ -375,11 +379,7 @@ object ColumnPruning extends Rule[LogicalPlan] { case p @ Project(projectList, child) if sameOutput(child.output, p.output) => child // Eliminate no-op Window - case w: Window if sameOutput(w.child.output, w.output) => w.child - - // Convert Aggregate to Project if no aggregate function exists - case a: Aggregate if !containsAggregates(a.expressions) => - Project(a.aggregateExpressions, a.child) + case w: Window if w.windowExpressions.isEmpty => w.child // Can't prune the columns on LeafNode case p @ Project(_, l: LeafNode) => p @@ -411,18 +411,6 @@ object ColumnPruning extends Rule[LogicalPlan] { } else { c } - - private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean = - output1.size == output2.size && - output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2)) - - private def isAggregateExpression(e: Expression): Boolean = { - e.isInstanceOf[AggregateExpression] || e.isInstanceOf[Grouping] || e.isInstanceOf[GroupingID] - } - - private def containsAggregates(exprs: Seq[Expression]): Boolean = { - exprs.exists(_.find(isAggregateExpression).nonEmpty) - } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 37b76df9fc10f..2020c48effbc4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -184,7 +184,8 @@ class ColumnPruningSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select('a).analyze + .select('a) + .groupBy('a)('a).analyze comparePlans(optimized, correctAnswer) } @@ -201,7 +202,7 @@ class ColumnPruningSuite extends PlanTest { val correctAnswer = testRelation .select('a) - .select('a as 'c).analyze + .groupBy('a)('a as 'c).analyze comparePlans(optimized, correctAnswer) } @@ -270,7 +271,7 @@ class ColumnPruningSuite extends PlanTest { UnspecifiedFrame)).as('window)).select('a, 'c) val correctAnswer = - input.select('a, 'c).analyze + input.select('a, 'c, 'd).groupBy('a, 'c, 'd)('a, 'c).analyze val optimized = Optimize.execute(originalQuery.analyze) From fc96d84ac1ef6bec70e1af911c2cef0447f1514f Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 10 Mar 2016 08:33:15 -0800 Subject: [PATCH 13/15] address comments. --- .../sql/catalyst/optimizer/ColumnPruningSuite.scala | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 2020c48effbc4..dd7d65ddc9e96 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -34,7 +34,8 @@ class ColumnPruningSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Column pruning", FixedPoint(100), - ColumnPruning) :: Nil + ColumnPruning, + CollapseProject) :: Nil } test("Column pruning for Generate when Generate.join = false") { @@ -270,8 +271,7 @@ class ColumnPruningSuite extends PlanTest { SortOrder('b, Ascending) :: Nil, UnspecifiedFrame)).as('window)).select('a, 'c) - val correctAnswer = - input.select('a, 'c, 'd).groupBy('a, 'c, 'd)('a, 'c).analyze + val correctAnswer = input.select('a, 'c, 'd).groupBy('a, 'c, 'd)('a, 'c).analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -297,7 +297,6 @@ class ColumnPruningSuite extends PlanTest { SortOrder('b, Ascending) :: Nil, UnspecifiedFrame)).as('window) :: Nil, 'a :: Nil, 'b.asc :: Nil) - .select('a, 'c, 'window).select('a, 'c, 'window, 'window) .select('a, 'c, 'window).where('window > 1).select('a, 'c).analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -316,8 +315,7 @@ class ColumnPruningSuite extends PlanTest { SortOrder('b, Ascending) :: Nil, UnspecifiedFrame)).as('window)).select('a, 'c) - val correctAnswer = - input.select('a, 'c).analyze + val correctAnswer = input.select('a, 'c).analyze val optimized = Optimize.execute(originalQuery.analyze) From 6a59b4240eaf1c63c6ad2dcb3ce6876ef9f2d189 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 10 Mar 2016 17:58:31 -0800 Subject: [PATCH 14/15] address comments. --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 14 ++++---------- .../catalyst/plans/logical/basicOperators.scala | 2 ++ 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0e82aa0b39ee2..634c8dcff138b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -320,12 +320,15 @@ object ColumnPruning extends Rule[LogicalPlan] { output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2)) def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // Prunes the unused columns from project list of Project/Aggregate/Expand + // Prunes the unused columns from project list of Project/Aggregate/Window/Expand case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty => p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains))) case p @ Project(_, a: Aggregate) if (a.outputSet -- p.references).nonEmpty => p.copy( child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains))) + case p @ Project(_, w: Window) if (w.windowOutputSet -- p.references).nonEmpty => + p.copy(child = w.copy( + windowExpressions = w.windowExpressions.filter(p.references.contains))) case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty => val newOutput = e.output.filter(a.references.contains(_)) val newProjects = e.projections.map { proj => @@ -384,15 +387,6 @@ object ColumnPruning extends Rule[LogicalPlan] { // Can't prune the columns on LeafNode case p @ Project(_, l: LeafNode) => p - // Prune windowExpressions and child of Window - case p @ Project(_, w: Window) if (w.outputSet -- p.references).nonEmpty => - val newWindowExprs = w.windowExpressions.filter(p.references.contains) - val newGrandChild = - prunedChild(w.child, p.references ++ AttributeSet(newWindowExprs.flatMap(_.references))) - p.copy(child = w.copy( - windowExpressions = newWindowExprs, - child = newGrandChild)) - // for all other logical plans that inherits the output from it's children case p @ Project(_, child) => val required = child.references ++ p.references diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 86383538f6f07..09ea3fea6a694 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -441,6 +441,8 @@ case class Window( override def output: Seq[Attribute] = child.output ++ windowExpressions.map(_.toAttribute) + + def windowOutputSet: AttributeSet = AttributeSet(windowExpressions.map(_.toAttribute)) } private[sql] object Expand { From bd35ee7144a9b88281f189748fd044df7971dbf3 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 10 Mar 2016 18:25:17 -0800 Subject: [PATCH 15/15] address comments. --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 634c8dcff138b..85776670e5c4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -320,15 +320,12 @@ object ColumnPruning extends Rule[LogicalPlan] { output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2)) def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // Prunes the unused columns from project list of Project/Aggregate/Window/Expand + // Prunes the unused columns from project list of Project/Aggregate/Expand case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty => p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains))) case p @ Project(_, a: Aggregate) if (a.outputSet -- p.references).nonEmpty => p.copy( child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains))) - case p @ Project(_, w: Window) if (w.windowOutputSet -- p.references).nonEmpty => - p.copy(child = w.copy( - windowExpressions = w.windowExpressions.filter(p.references.contains))) case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty => val newOutput = e.output.filter(a.references.contains(_)) val newProjects = e.projections.map { proj => @@ -378,12 +375,17 @@ object ColumnPruning extends Rule[LogicalPlan] { p } - // Eliminate no-op Projects - case p @ Project(projectList, child) if sameOutput(child.output, p.output) => child + // Prune unnecessary window expressions + case p @ Project(_, w: Window) if (w.windowOutputSet -- p.references).nonEmpty => + p.copy(child = w.copy( + windowExpressions = w.windowExpressions.filter(p.references.contains))) // Eliminate no-op Window case w: Window if w.windowExpressions.isEmpty => w.child + // Eliminate no-op Projects + case p @ Project(projectList, child) if sameOutput(child.output, p.output) => child + // Can't prune the columns on LeafNode case p @ Project(_, l: LeafNode) => p