diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 2aeb9575f1dd2..55adc06320a58 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -331,7 +331,10 @@ object ColumnPruning extends Rule[LogicalPlan] {
         }.unzip._1
       }
       a.copy(child = Expand(newProjects, newOutput, grandChild))
-    // TODO: support some logical plan for Dataset
+
+    // Prunes the unused columns from child of MapPartitions
+    case mp @ MapPartitions(_, _, _, child) if (child.outputSet -- mp.references).nonEmpty =>
+      mp.copy(child = prunedChild(child, mp.references))
 
     // Prunes the unused columns from child of Aggregate/Window/Expand/Generate
     case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index 715d01a3cd876..5cab1fc95a364 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -17,9 +17,12 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
+import scala.reflect.runtime.universe.TypeTag
+
 import org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Explode, Literal, SortOrder}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -249,5 +252,16 @@ class ColumnPruningSuite extends PlanTest {
     comparePlans(Optimize.execute(query), expected)
   }
 
+  implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]()
+  private val func = identity[Iterator[OtherTuple]] _
+
+  test("Column pruning on MapPartitions") {
+    val input = LocalRelation('_1.int, '_2.int, 'c.int)
+    val plan1 = MapPartitions(func, input)
+    val correctAnswer1 =
+      MapPartitions(func, Project(Seq('_1, '_2), input)).analyze
+    comparePlans(Optimize.execute(plan1.analyze), correctAnswer1)
+  }
+
   // todo: add more tests for column pruning
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 33df6375e3aad..79e10215f4d3f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -113,7 +113,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       ("a", 2), ("b", 3), ("c", 4))
   }
 
-  test("map with type change") {
+  test("map with type change with a matching number of attributes") {
     val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
 
     checkAnswer(
@@ -123,6 +123,15 @@
       OtherTuple("a", 1), OtherTuple("b", 2), OtherTuple("c", 3))
   }
 
+  test("map with type change with fewer attributes") {
+    val ds = Seq(("a", 1, 3), ("b", 2, 4), ("c", 3, 5)).toDS()
+
+    checkAnswer(
+      ds.as[OtherTuple]
+        .map(identity[OtherTuple]),
+      OtherTuple("a", 1), OtherTuple("b", 2), OtherTuple("c", 3))
+  }
+
   test("map and group by with class data") {
     // We inject a group by here to make sure this test case is future proof
     // when we implement better pipelining and local execution mode.