From 6e4eba6c9c637f9af11dae17a4ee2f1b39ee00be Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 6 Mar 2017 08:56:38 +0000 Subject: [PATCH 1/7] Rewrite physical Project operator's output partitioning and ordering to ensure no unnecessary shuffle/sort in Datasets. --- .../catalyst/expressions/Canonicalize.scala | 8 +++- .../execution/basicPhysicalOperators.scala | 37 ++++++++++++++++- .../columnar/InMemoryTableScanExec.scala | 30 ++++++++++++-- .../org/apache/spark/sql/DatasetSuite.scala | 41 +++++++++++++++++++ 4 files changed, 110 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala index 65e497afc12cd..ea24c2fe9b12e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala @@ -33,7 +33,7 @@ package org.apache.spark.sql.catalyst.expressions */ object Canonicalize extends { def execute(e: Expression): Expression = { - expressionReorder(ignoreNamesTypes(e)) + expressionReorder(ignoreParameters(ignoreNamesTypes(e))) } /** Remove names and nullability from types. */ @@ -43,6 +43,12 @@ object Canonicalize extends { case _ => e } + /** Remove some unnecessary parameters. */ + private[expressions] def ignoreParameters(e: Expression): Expression = e match { + case GetStructField(child, ordinal, _) => GetStructField(child, ordinal, None) + case _ => e + } + /** Collects adjacent commutative operations. */ private def gatherCommutative( e: Expression, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 87e90ed685cca..328d725c5f5a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -78,9 +78,42 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) } } - override def outputOrdering: Seq[SortOrder] = child.outputOrdering + /** + * If there are any `CreateNamedStructLike` expressions in the projection, keep value expressions + * and corresponding extractors. We will use those extractors to rewrite `outputPartitioning` and + * `outputOrdering`. 
+ */ + private lazy val valExprsAndExtractors = projectList.collect { + case a @ Alias(ns: CreateNamedStructLike, _) => (a.toAttribute, ns) + }.flatMap { case (attr, ns) => + ns.valExprs.zipWithIndex.map { case (valExpr, ordinal) => + (valExpr, GetStructField(attr, ordinal, Some(ns.names(ordinal).toString))) + } + } - override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = { + child.outputOrdering.map { sortOrder => + val newSortExpr = sortOrder.child.transform { + case required: Expression => + val found = valExprsAndExtractors.find(_._1.semanticEquals(required)) + found.map(_._2).getOrElse(required) + } + SortOrder(newSortExpr, sortOrder.direction, sortOrder.nullOrdering) + } + } + + override def outputPartitioning: Partitioning = { + child.outputPartitioning match { + case HashPartitioning(requiredClustering, numPartitions) => + val newRequiredClustering = requiredClustering.map(_.transform { + case required: Expression => + val found = valExprsAndExtractors.find(_._1.semanticEquals(required)) + found.map(_._2).getOrElse(required) + }) + HashPartitioning(newRequiredClustering, numPartitions) + case _ => child.outputPartitioning + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 9028caa446e8c..f405539046150 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.UserDefinedType @@ -42,10 +42,34 @@ case class InMemoryTableScanExec( override def output: Seq[Attribute] = attributes // The cached version does not change the outputPartitioning of the original SparkPlan. - override def outputPartitioning: Partitioning = relation.child.outputPartitioning + // But the cached version could alias output, so we need to replace output. + override def outputPartitioning: Partitioning = { + val attrMap = AttributeMap( + relation.child.output.zip(output) + ) + relation.child.outputPartitioning match { + case HashPartitioning(expressions, numPartitions) => + val newExprs = expressions.map(_.transform { + case attr: Attribute if attrMap.contains(attr) => attrMap.get(attr).get + }) + HashPartitioning(newExprs, numPartitions) + case _ => relation.child.outputPartitioning + } + } // The cached version does not change the outputOrdering of the original SparkPlan. - override def outputOrdering: Seq[SortOrder] = relation.child.outputOrdering + // But the cached version could alias output, so we need to replace output. 
+ override def outputOrdering: Seq[SortOrder] = { + val attrMap = AttributeMap( + relation.child.output.zip(output) + ) + relation.child.outputOrdering.map { sortOrder => + val newSortExpr = sortOrder.child.transform { + case attr: Attribute if attrMap.contains(attr) => attrMap.get(attr).get + } + SortOrder(newSortExpr, sortOrder.direction, sortOrder.nullOrdering) + } + } private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index b37bf131e8dce..76b9df9bf90cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -23,11 +23,13 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder} import org.apache.spark.sql.catalyst.util.sideBySide import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SortExec} +import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ +import org.apache.spark.storage.StorageLevel case class TestDataPoint(x: Int, y: Double, s: String, t: TestDataPoint2) case class TestDataPoint2(x: Int, s: String) @@ -1136,6 +1138,45 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(spark.range(1).map { x => new java.sql.Timestamp(100000) }.head == new java.sql.Timestamp(100000)) } + + test("No unnecessary shuffles and sort when datasets are well partitioned and sorted") { + withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") { + val ds1 = Seq((0, 0), (1, 1)).toDS + .repartition(col("_1")).sortWithinPartitions(col("_1")).persist(StorageLevel.DISK_ONLY) + val ds2 = Seq((0, 0), (1, 1)).toDS + .repartition(col("_1")).sortWithinPartitions(col("_1")).persist(StorageLevel.DISK_ONLY) + val joined = ds1.joinWith(ds2, ds1("_1") === ds2("_1")) + + checkAnswer( + joined.toDF(), + Row(Row(0, 0), Row(0, 0)) :: Row(Row(1, 1), Row(1, 1)) :: Nil) + + val shuffles = joined.queryExecution.executedPlan.collect { + case s: ShuffleExchange => s + } + val sorts = joined.queryExecution.executedPlan.collect { + case s: SortExec => s + } + assert(shuffles.length == 0) + assert(sorts.length == 0) + + val memoryRelations = joined.queryExecution.executedPlan.collect { + case mem: InMemoryTableScanExec => mem.relation.child + } + val shufflesInCached = memoryRelations.flatMap { plan => + plan.collect { + case s: ShuffleExchange => s + } + } + val sortsInCached = memoryRelations.flatMap { plan => + plan.collect { + case s: SortExec => s + } + } + assert(shufflesInCached.length == 2) + assert(sortsInCached.length == 2) + } + } } case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String]) From c83919e2e4d94f78862a91916cfadf6f5ce6f575 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 13 Mar 2017 03:32:56 +0000 Subject: [PATCH 2/7] Keep the change of InMemoryTableScanExec. 
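
Revert the Canonicalize and ProjectExec changes from the previous commit and keep
only the InMemoryTableScanExec fix; the Dataset-level test moves into
InMemoryColumnarQuerySuite. The kept idea, shown here as a minimal standalone
sketch (the method name `rebaseExpr` is illustrative only and is not code in this
series): any expression describing the cached child plan's partitioning or
ordering must be rewritten against the scan's own, possibly re-aliased, output
attributes before being reported downstream.

    import org.apache.spark.sql.catalyst.expressions._

    // Rewrite an expression phrased in terms of the cached child plan's output
    // so that it refers to this scan's output attributes instead.
    def rebaseExpr(
        expr: Expression,
        childOutput: Seq[Attribute],
        scanOutput: Seq[Attribute]): Expression = {
      val attrMap = AttributeMap(childOutput.zip(scanOutput))
      expr.transform {
        case attr: Attribute if attrMap.contains(attr) => attrMap(attr)
      }
    }

Applied to relation.child.outputPartitioning and relation.child.outputOrdering,
this lets downstream operators recognize the cached plan's existing partitioning
and ordering and skip redundant ShuffleExchange and SortExec nodes.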
--- .../catalyst/expressions/Canonicalize.scala | 8 +--- .../execution/basicPhysicalOperators.scala | 37 +---------------- .../org/apache/spark/sql/DatasetSuite.scala | 41 ------------------- .../columnar/InMemoryColumnarQuerySuite.scala | 22 ++++++++++ 4 files changed, 25 insertions(+), 83 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala index ea24c2fe9b12e..65e497afc12cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala @@ -33,7 +33,7 @@ package org.apache.spark.sql.catalyst.expressions */ object Canonicalize extends { def execute(e: Expression): Expression = { - expressionReorder(ignoreParameters(ignoreNamesTypes(e))) + expressionReorder(ignoreNamesTypes(e)) } /** Remove names and nullability from types. */ @@ -43,12 +43,6 @@ object Canonicalize extends { case _ => e } - /** Remove some unnecessary parameters. */ - private[expressions] def ignoreParameters(e: Expression): Expression = e match { - case GetStructField(child, ordinal, _) => GetStructField(child, ordinal, None) - case _ => e - } - /** Collects adjacent commutative operations. */ private def gatherCommutative( e: Expression, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 328d725c5f5a2..87e90ed685cca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -78,42 +78,9 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) } } - /** - * If there are any `CreateNamedStructLike` expressions in the projection, keep value expressions - * and corresponding extractors. We will use those extractors to rewrite `outputPartitioning` and - * `outputOrdering`. 
-   */
-  private lazy val valExprsAndExtractors = projectList.collect {
-    case a @ Alias(ns: CreateNamedStructLike, _) => (a.toAttribute, ns)
-  }.flatMap { case (attr, ns) =>
-    ns.valExprs.zipWithIndex.map { case (valExpr, ordinal) =>
-      (valExpr, GetStructField(attr, ordinal, Some(ns.names(ordinal).toString)))
-    }
-  }
-
-  override def outputOrdering: Seq[SortOrder] = {
-    child.outputOrdering.map { sortOrder =>
-      val newSortExpr = sortOrder.child.transform {
-        case required: Expression =>
-          val found = valExprsAndExtractors.find(_._1.semanticEquals(required))
-          found.map(_._2).getOrElse(required)
-      }
-      SortOrder(newSortExpr, sortOrder.direction, sortOrder.nullOrdering)
-    }
-  }
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 
-  override def outputPartitioning: Partitioning = {
-    child.outputPartitioning match {
-      case HashPartitioning(requiredClustering, numPartitions) =>
-        val newRequiredClustering = requiredClustering.map(_.transform {
-          case required: Expression =>
-            val found = valExprsAndExtractors.find(_._1.semanticEquals(required))
-            found.map(_._2).getOrElse(required)
-        })
-        HashPartitioning(newRequiredClustering, numPartitions)
-      case _ => child.outputPartitioning
-    }
-  }
+  override def outputPartitioning: Partitioning = child.outputPartitioning
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 76b9df9bf90cc..b37bf131e8dce 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -23,13 +23,11 @@ import java.sql.{Date, Timestamp}
 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
 import org.apache.spark.sql.catalyst.util.sideBySide
 import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SortExec}
-import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange}
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
-import org.apache.spark.storage.StorageLevel
 
 case class TestDataPoint(x: Int, y: Double, s: String, t: TestDataPoint2)
 case class TestDataPoint2(x: Int, s: String)
@@ -1138,45 +1136,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     assert(spark.range(1).map { x => new java.sql.Timestamp(100000) }.head ==
       new java.sql.Timestamp(100000))
   }
-
-  test("No unnecessary shuffles and sort when datasets are well partitioned and sorted") {
-    withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") {
-      val ds1 = Seq((0, 0), (1, 1)).toDS
-        .repartition(col("_1")).sortWithinPartitions(col("_1")).persist(StorageLevel.DISK_ONLY)
-      val ds2 = Seq((0, 0), (1, 1)).toDS
-        .repartition(col("_1")).sortWithinPartitions(col("_1")).persist(StorageLevel.DISK_ONLY)
-      val joined = ds1.joinWith(ds2, ds1("_1") === ds2("_1"))
-
-      checkAnswer(
-        joined.toDF(),
-        Row(Row(0, 0), Row(0, 0)) :: Row(Row(1, 1), Row(1, 1)) :: Nil)
-
-      val shuffles = joined.queryExecution.executedPlan.collect {
-        case s: ShuffleExchange => s
-      }
-      val sorts = joined.queryExecution.executedPlan.collect {
-        case s: SortExec => s
-      }
-      assert(shuffles.length == 0)
-      assert(sorts.length == 0)
-
-      val memoryRelations = joined.queryExecution.executedPlan.collect {
-        case mem: InMemoryTableScanExec => mem.relation.child
-      }
-      val shufflesInCached = memoryRelations.flatMap { plan =>
-        plan.collect {
-          case s: ShuffleExchange => s
-        }
-      }
-      val sortsInCached = memoryRelations.flatMap { plan =>
-        plan.collect {
-          case s: SortExec => s
-        }
-      }
-      assert(shufflesInCached.length == 2)
-      assert(sortsInCached.length == 2)
-    }
-  }
 }
 
 case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String])
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
index f355a5200ce2f..e1450720e5efd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
@@ -21,6 +21,9 @@ import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
 
 import org.apache.spark.sql.{DataFrame, QueryTest, Row}
+import org.apache.spark.sql.catalyst.expressions.AttributeSet
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData._
@@ -390,4 +393,23 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("InMemoryTableScanExec should return currect output ordering and partitioning") {
+    val ds1 = Seq((0, 0), (1, 1)).toDS
+      .repartition(col("_1")).sortWithinPartitions(col("_1")).persist
+    val ds2 = Seq((0, 0), (1, 1)).toDS
+      .repartition(col("_1")).sortWithinPartitions(col("_1")).persist
+    val joined = ds1.joinWith(ds2, ds1("_1") === ds2("_1"))
+
+    val inMemoryScans = joined.queryExecution.executedPlan.collect {
+      case m: InMemoryTableScanExec => m
+    }
+    inMemoryScans.foreach { inMemoryScan =>
+      val sortedAttrs = AttributeSet(inMemoryScan.outputOrdering.flatMap(_.references))
+      assert(sortedAttrs.subsetOf(inMemoryScan.outputSet))
+
+      val partitionedAttrs =
+        inMemoryScan.outputPartitioning.asInstanceOf[HashPartitioning].references
+      assert(partitionedAttrs.subsetOf(inMemoryScan.outputSet))
+    }
+  }
 }

From 454515d4531d649f65bbce04b0141fa578a89bc3 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Mon, 13 Mar 2017 06:13:58 +0000
Subject: [PATCH 3/7] Refactor the change.

---
 .../columnar/InMemoryTableScanExec.scala      | 18 +++++++-----------
 1 file changed, 7 insertions(+), 11 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index f405539046150..09f6db2db2a22 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -41,6 +41,11 @@ case class InMemoryTableScanExec(
 
   override def output: Seq[Attribute] = attributes
 
+  private def updateAttribute(expr: Expression, attrMap: AttributeMap[Attribute]): Expression =
+    expr.transform {
+      case attr: Attribute if attrMap.contains(attr) => attrMap.get(attr).get
+    }
+
   // The cached version does not change the outputPartitioning of the original SparkPlan.
   // But the cached version could alias output, so we need to replace output.
override def outputPartitioning: Partitioning = { @@ -48,11 +53,7 @@ case class InMemoryTableScanExec( relation.child.output.zip(output) ) relation.child.outputPartitioning match { - case HashPartitioning(expressions, numPartitions) => - val newExprs = expressions.map(_.transform { - case attr: Attribute if attrMap.contains(attr) => attrMap.get(attr).get - }) - HashPartitioning(newExprs, numPartitions) + case h: HashPartitioning => updateAttribute(h, attrMap).asInstanceOf[HashPartitioning] case _ => relation.child.outputPartitioning } } @@ -63,12 +64,7 @@ case class InMemoryTableScanExec( val attrMap = AttributeMap( relation.child.output.zip(output) ) - relation.child.outputOrdering.map { sortOrder => - val newSortExpr = sortOrder.child.transform { - case attr: Attribute if attrMap.contains(attr) => attrMap.get(attr).get - } - SortOrder(newSortExpr, sortOrder.direction, sortOrder.nullOrdering) - } + relation.child.outputOrdering.map(updateAttribute(_, attrMap).asInstanceOf[SortOrder]) } private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a) From 1fc023cd081409994ad1dc9d32677f4d5c1191f1 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 14 Mar 2017 04:30:26 +0000 Subject: [PATCH 4/7] Move attrMap into updateAttribute. --- .../columnar/InMemoryTableScanExec.scala | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 09f6db2db2a22..578ba31b9ab99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -41,31 +41,28 @@ case class InMemoryTableScanExec( override def output: Seq[Attribute] = attributes - private def updateAttribute(expr: Expression, attrMap: AttributeMap[Attribute]): Expression = + private def updateAttribute(expr: Expression): Expression = { + val attrMap = AttributeMap( + relation.child.output.zip(output) + ) expr.transform { case attr: Attribute if attrMap.contains(attr) => attrMap.get(attr).get } + } // The cached version does not change the outputPartitioning of the original SparkPlan. // But the cached version could alias output, so we need to replace output. override def outputPartitioning: Partitioning = { - val attrMap = AttributeMap( - relation.child.output.zip(output) - ) relation.child.outputPartitioning match { - case h: HashPartitioning => updateAttribute(h, attrMap).asInstanceOf[HashPartitioning] + case h: HashPartitioning => updateAttribute(h).asInstanceOf[HashPartitioning] case _ => relation.child.outputPartitioning } } // The cached version does not change the outputOrdering of the original SparkPlan. // But the cached version could alias output, so we need to replace output. - override def outputOrdering: Seq[SortOrder] = { - val attrMap = AttributeMap( - relation.child.output.zip(output) - ) - relation.child.outputOrdering.map(updateAttribute(_, attrMap).asInstanceOf[SortOrder]) - } + override def outputOrdering: Seq[SortOrder] = + relation.child.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder]) private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a) From ef918da32c67d1230944ec6059fec44dddfb2afe Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 14 Mar 2017 13:05:04 +0000 Subject: [PATCH 5/7] Address comment. 
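
Per review, collapse the guarded lookup

    expr.transform {
      case attr: Attribute if attrMap.contains(attr) => attrMap.get(attr).get
    }

into a single total lookup with a fallback

    expr.transform {
      case attr: Attribute => attrMap.getOrElse(attr, attr)
    }

The two are equivalent for attributes present in the map; getOrElse additionally
leaves unmapped attributes untouched without needing the pattern guard or the
unchecked `.get`. A rough self-contained illustration of the same contrast in
plain Scala (example values only, not code from this series):

    val m = Map(1 -> "a")
    val before = if (m.contains(2)) m.get(2).get else "fallback"
    val after = m.getOrElse(2, "fallback")
    assert(before == after)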
---
 .../spark/sql/execution/columnar/InMemoryTableScanExec.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 578ba31b9ab99..2bc12940182f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -46,7 +46,7 @@ case class InMemoryTableScanExec(
       relation.child.output.zip(output)
     )
     expr.transform {
-      case attr: Attribute if attrMap.contains(attr) => attrMap.get(attr).get
+      case attr: Attribute => attrMap.getOrElse(attr, attr)
     }
   }
 

From b4d5d0f19db08447dfd33c0189dd87a7c609a6da Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Wed, 15 Mar 2017 07:05:06 +0000
Subject: [PATCH 6/7] For comment.

---
 .../spark/sql/execution/columnar/InMemoryTableScanExec.scala | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 2bc12940182f8..214e8d309de11 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -42,9 +42,7 @@ case class InMemoryTableScanExec(
   override def output: Seq[Attribute] = attributes
 
   private def updateAttribute(expr: Expression): Expression = {
-    val attrMap = AttributeMap(
-      relation.child.output.zip(output)
-    )
+    val attrMap = AttributeMap(relation.child.output.zip(output))
     expr.transform {
       case attr: Attribute => attrMap.getOrElse(attr, attr)
     }

From b25156f7255f5ef814fc20761a52c46426174648 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Wed, 15 Mar 2017 07:35:44 +0000
Subject: [PATCH 7/7] Use DataFrame and add comment to test.

---
 .../columnar/InMemoryColumnarQuerySuite.scala | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
index 68afe3e14e61c..1e6a6a8ba3362 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
@@ -391,12 +391,16 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
     }
   }
 
-  test("InMemoryTableScanExec should return currect output ordering and partitioning") {
-    val ds1 = Seq((0, 0), (1, 1)).toDS
+  test("InMemoryTableScanExec should return correct output ordering and partitioning") {
+    val df1 = Seq((0, 0), (1, 1)).toDF
       .repartition(col("_1")).sortWithinPartitions(col("_1")).persist
-    val ds2 = Seq((0, 0), (1, 1)).toDS
+    val df2 = Seq((0, 0), (1, 1)).toDF
       .repartition(col("_1")).sortWithinPartitions(col("_1")).persist
-    val joined = ds1.joinWith(ds2, ds1("_1") === ds2("_1"))
+
+    // Because the two cached DataFrames share the same logical plan, this is really a self-join.
+    // That forces one of the in-memory relations to alias its output, so we can check that both
+    // the original and the aliased in-memory relations report correct ordering and partitioning.
+ val joined = df1.joinWith(df2, df1("_1") === df2("_1")) val inMemoryScans = joined.queryExecution.executedPlan.collect { case m: InMemoryTableScanExec => m