From e1f28e2c6c9ba99c92fea339946323c9490062d4 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Thu, 15 Mar 2018 06:16:22 +0000
Subject: [PATCH 1/4] Fix incorrect reuse exchange when caching is used.

---
 .../sql/execution/columnar/InMemoryRelation.scala      | 10 ++++++++++
 .../sql/execution/columnar/InMemoryTableScanExec.scala |  7 ++++++-
 .../test/scala/org/apache/spark/sql/DatasetSuite.scala |  9 +++++++++
 .../org/apache/spark/sql/execution/ExchangeSuite.scala |  7 +++++++
 4 files changed, 32 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 22e16913d4da9..4729daf26404d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, Statistics}
 import org.apache.spark.sql.execution.SparkPlan
@@ -68,6 +69,15 @@ case class InMemoryRelation(
 
   override protected def innerChildren: Seq[SparkPlan] = Seq(child)
 
+  override def doCanonicalize(): logical.LogicalPlan =
+    copy(output = output.map(QueryPlan.normalizeExprId(_, child.output)),
+      storageLevel = new StorageLevel(),
+      child = child.canonicalized,
+      tableName = None)(
+        _cachedColumnBuffers,
+        sizeInBytesStats,
+        statsOfPlanToCache)
+
   override def producedAttributes: AttributeSet = outputSet
 
   @transient val partitionStatistics = new PartitionStatistics(output)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index a93e8a1ad954d..62cc47c121139 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
-import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, SparkPlan, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.vectorized._
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
@@ -38,6 +38,11 @@ case class InMemoryTableScanExec(
 
   override protected def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren
 
+  override def doCanonicalize(): SparkPlan =
+    copy(attributes = attributes.map(QueryPlan.normalizeExprId(_, relation.output)),
+      predicates = predicates.map(QueryPlan.normalizeExprId(_, relation.output)),
+      relation = relation.canonicalized.asInstanceOf[InMemoryRelation])
+
   override def vectorTypes: Option[Seq[String]] =
     Option(Seq.fill(attributes.length)(
       if (!conf.offHeapColumnVectorEnabled) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 49c59cf695dc1..9b745befcb611 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1446,8 +1446,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val data = Seq(("a", null))
     checkDataset(data.toDS(), data: _*)
   }
+
+  test("SPARK-23614: Union produces incorrect results when caching is used") {
+    val cached = spark.createDataset(Seq(TestDataUnion(1, 2, 3), TestDataUnion(4, 5, 6))).cache()
+    val group1 = cached.groupBy("x").agg(min(col("y")) as "value")
+    val group2 = cached.groupBy("x").agg(min(col("z")) as "value")
+    checkAnswer(group1.union(group2), Row(4, 5) :: Row(1, 2) :: Row(4, 6) :: Row(1, 3) :: Nil)
+  }
 }
 
+case class TestDataUnion(x: Int, y: Int, z: Int)
+
 case class SingleData(id: Int)
 case class DoubleData(id: Int, val1: String)
 case class TripleData(id: Int, val1: String, val2: Long)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
index 697d7e6520713..bde2de5b39fd7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
@@ -125,4 +125,11 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
       assertConsistency(spark.range(10000).map(i => Random.nextInt(1000).toLong))
     }
   }
+
+  test("SPARK-23614: Fix incorrect reuse exchange when caching is used") {
+    val cached = spark.createDataset(Seq((1, 2, 3), (4, 5, 6))).cache()
+    val projection1 = cached.select("_1", "_2").queryExecution.executedPlan
+    val projection2 = cached.select("_1", "_3").queryExecution.executedPlan
+    assert(!projection1.sameResult(projection2))
+  }
 }

From a3f86b23c7e75428446ac95d2e527bcdd5562a5f Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Thu, 15 Mar 2018 09:07:54 +0000
Subject: [PATCH 2/4] Address comment.

---
 .../apache/spark/sql/execution/columnar/InMemoryRelation.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 4729daf26404d..2579046e30708 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -71,7 +71,7 @@ case class InMemoryRelation(
 
   override def doCanonicalize(): logical.LogicalPlan =
     copy(output = output.map(QueryPlan.normalizeExprId(_, child.output)),
-      storageLevel = new StorageLevel(),
+      storageLevel = StorageLevel.NONE,
       child = child.canonicalized,
       tableName = None)(
         _cachedColumnBuffers,
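Note: the two doCanonicalize overrides in PATCH 1 (polished by PATCH 2) are the heart of the fix. Previously, two InMemoryTableScanExec nodes reading different columns of the same cached relation could canonicalize to identical trees, so the ReuseExchange rule treated their exchanges as interchangeable. The new ExchangeSuite test can also be run interactively; the following is a minimal sketch, assuming a Spark 2.3-era spark-shell session (which supplies the `spark` SparkSession and its implicits), and is not part of the patch itself:

    import spark.implicits._

    val cached = spark.createDataset(Seq((1, 2, 3), (4, 5, 6))).cache()
    val plan1 = cached.select("_1", "_2").queryExecution.executedPlan
    val plan2 = cached.select("_1", "_3").queryExecution.executedPlan

    // The two projections read different columns, so their physical plans must
    // not compare equal. Before PATCH 1, both scans canonicalized to the same
    // tree and sameResult returned true, which enabled the bad exchange reuse.
    assert(!plan1.sameResult(plan2))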
From fb0f949ec5e144ad38ddc780ae1bb228c59bce55 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Thu, 15 Mar 2018 14:00:40 +0000
Subject: [PATCH 3/4] Fix error.

---
 .../spark/sql/execution/columnar/InMemoryTableScanExec.scala | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 62cc47c121139..5576835768d3d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -174,7 +174,10 @@ case class InMemoryTableScanExec(
   override def outputOrdering: Seq[SortOrder] =
     relation.child.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder])
 
-  private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a)
+  // When we canonicalize the plan, normalized attributes cannot be found in this map,
+  // so we return a fresh `ColumnStatisticsSchema` for such an attribute.
+  private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute.get(a).
+    getOrElse(new ColumnStatisticsSchema(a))
 
   // Returned filter predicate should return false iff it is impossible for the input expression
   // to evaluate to `true' based on statistics collected about this partition batch.

From a592882d0c61eacdae7dfd9309635468535b416c Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Wed, 21 Mar 2018 23:54:07 +0000
Subject: [PATCH 4/4] Make a few variables lazy.

---
 .../columnar/InMemoryTableScanExec.scala | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 5576835768d3d..e73e1378d52e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -174,14 +174,13 @@ case class InMemoryTableScanExec(
   override def outputOrdering: Seq[SortOrder] =
     relation.child.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder])
 
-  // When we canonicalize the plan, normalized attributes cannot be found in this map,
-  // so we return a fresh `ColumnStatisticsSchema` for such an attribute.
-  private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute.get(a).
-    getOrElse(new ColumnStatisticsSchema(a))
+  // Keep the relation's partition statistics here because we don't serialize the relation.
+  private val stats = relation.partitionStatistics
+  private def statsFor(a: Attribute) = stats.forAttribute(a)
 
   // Returned filter predicate should return false iff it is impossible for the input expression
   // to evaluate to `true' based on statistics collected about this partition batch.
-  @transient val buildFilter: PartialFunction[Expression, Expression] = {
+  @transient lazy val buildFilter: PartialFunction[Expression, Expression] = {
     case And(lhs: Expression, rhs: Expression)
       if buildFilter.isDefinedAt(lhs) || buildFilter.isDefinedAt(rhs) =>
       (buildFilter.lift(lhs) ++ buildFilter.lift(rhs)).reduce(_ && _)
@@ -221,14 +220,14 @@ case class InMemoryTableScanExec(
         l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _)
   }
 
-  val partitionFilters: Seq[Expression] = {
+  lazy val partitionFilters: Seq[Expression] = {
     predicates.flatMap { p =>
       val filter = buildFilter.lift(p)
       val boundFilter =
         filter.map(
           BindReferences.bindReference(
             _,
-            relation.partitionStatistics.schema,
+            stats.schema,
             allowFailures = true))
 
       boundFilter.foreach(_ =>
@@ -251,7 +250,7 @@ case class InMemoryTableScanExec(
   private def filteredCachedBatches(): RDD[CachedBatch] = {
     // Using these variables here to avoid serialization of entire objects (if referenced directly)
    // within the map Partitions closure.
-    val schema = relation.partitionStatistics.schema
+    val schema = stats.schema
     val schemaIndex = schema.zipWithIndex
     val buffers = relation.cachedColumnBuffers
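Note: PATCH 3 and PATCH 4 handle the fallout of canonicalizing the scan. A canonicalized InMemoryTableScanExec carries normalized attributes that are absent from the partition-statistics map; PATCH 3 works around the resulting lookup failure with a fallback, and PATCH 4 then reverts that fallback, instead making buildFilter and partitionFilters lazy so a canonicalized copy never evaluates them (the local stats val also keeps the closures from serializing the whole relation). End to end, the user-visible bug is the one captured by the new DatasetSuite test; as a standalone sketch under the same spark-shell assumptions as above, and again not part of the patch itself:

    import org.apache.spark.sql.functions.{col, min}
    import spark.implicits._

    case class TestDataUnion(x: Int, y: Int, z: Int)

    val cached = spark.createDataset(Seq(TestDataUnion(1, 2, 3), TestDataUnion(4, 5, 6))).cache()
    val group1 = cached.groupBy("x").agg(min(col("y")) as "value")
    val group2 = cached.groupBy("x").agg(min(col("z")) as "value")

    // Correct output: (1, 2) and (4, 5) from group1, then (1, 3) and (4, 6)
    // from group2. Before this series, the second aggregate reused the first
    // aggregate's shuffle, so min(y) appeared again where min(z) belonged.
    group1.union(group2).show()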