From 645e108eeb6364e57f5d7213dbbd42dbcf1124d3 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 11 Oct 2017 13:51:33 -0700 Subject: [PATCH] [SPARK-21988][SS] Implement StreamingRelation.computeStats to fix explain ## What changes were proposed in this pull request? Implement StreamingRelation.computeStats to fix explain ## How was this patch tested? - unit tests: `StreamingRelation.computeStats` and `StreamingExecutionRelation.computeStats`. - regression tests: `explain join with a normal source` and `explain join with MemoryStream`. Author: Shixiong Zhu Closes #19465 from zsxwing/SPARK-21988. --- .../streaming/StreamingRelation.scala | 8 +++ .../spark/sql/streaming/StreamSuite.scala | 65 ++++++++++++++++--- 2 files changed, 63 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index ab716052c28ba..6b82c78ea653d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -44,6 +44,14 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output: extends LeafNode { override def isStreaming: Boolean = true override def toString: String = sourceName + + // There's no sensible value here. On the execution path, this relation will be + // swapped out with microbatches. But some dataframe operations (in particular explain) do lead + // to this node surviving analysis. So we satisfy the LeafNode contract with the session default + // value. + override def computeStats(): Statistics = Statistics( + sizeInBytes = BigInt(dataSource.sparkSession.sessionState.conf.defaultSizeInBytes) + ) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 9c901062d570a..3d687d2214e90 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -76,20 +76,65 @@ class StreamSuite extends StreamTest { CheckAnswer(Row(1, 1, "one"), Row(2, 2, "two"), Row(4, 4, "four"))) } + test("StreamingRelation.computeStats") { + val streamingRelation = spark.readStream.format("rate").load().logicalPlan collect { + case s: StreamingRelation => s + } + assert(streamingRelation.nonEmpty, "cannot find StreamingRelation") + assert( + streamingRelation.head.computeStats.sizeInBytes == spark.sessionState.conf.defaultSizeInBytes) + } - test("explain join") { - // Make a table and ensure it will be broadcast. - val smallTable = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + test("StreamingExecutionRelation.computeStats") { + val streamingExecutionRelation = MemoryStream[Int].toDF.logicalPlan collect { + case s: StreamingExecutionRelation => s + } + assert(streamingExecutionRelation.nonEmpty, "cannot find StreamingExecutionRelation") + assert(streamingExecutionRelation.head.computeStats.sizeInBytes + == spark.sessionState.conf.defaultSizeInBytes) + } - // Join the input stream with a table. - val inputData = MemoryStream[Int] - val joined = inputData.toDF().join(smallTable, smallTable("number") === $"value") + test("explain join with a normal source") { + // This test triggers CostBasedJoinReorder to call `computeStats` + withSQLConf(SQLConf.CBO_ENABLED.key -> "true", SQLConf.JOIN_REORDER_ENABLED.key -> "true") { + val smallTable = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + val smallTable2 = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + val smallTable3 = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + + // Join the input stream with a table. + val df = spark.readStream.format("rate").load() + val joined = df.join(smallTable, smallTable("number") === $"value") + .join(smallTable2, smallTable2("number") === $"value") + .join(smallTable3, smallTable3("number") === $"value") + + val outputStream = new java.io.ByteArrayOutputStream() + Console.withOut(outputStream) { + joined.explain(true) + } + assert(outputStream.toString.contains("StreamingRelation")) + } + } - val outputStream = new java.io.ByteArrayOutputStream() - Console.withOut(outputStream) { - joined.explain() + test("explain join with MemoryStream") { + // This test triggers CostBasedJoinReorder to call `computeStats` + // Because MemoryStream doesn't use DataSource code path, we need a separate test. + withSQLConf(SQLConf.CBO_ENABLED.key -> "true", SQLConf.JOIN_REORDER_ENABLED.key -> "true") { + val smallTable = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + val smallTable2 = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + val smallTable3 = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + + // Join the input stream with a table. + val df = MemoryStream[Int].toDF + val joined = df.join(smallTable, smallTable("number") === $"value") + .join(smallTable2, smallTable2("number") === $"value") + .join(smallTable3, smallTable3("number") === $"value") + + val outputStream = new java.io.ByteArrayOutputStream() + Console.withOut(outputStream) { + joined.explain(true) + } + assert(outputStream.toString.contains("StreamingRelation")) } - assert(outputStream.toString.contains("StreamingRelation")) } test("SPARK-20432: union one stream with itself") {