# [SPARK-19603][SS] Fix StreamingQuery explain command
## What changes were proposed in this pull request?

`StreamingQuery.explain` doesn't show the correct streaming physical plan right now because `ExplainCommand` receives the runtime batch plan, for which `logicalPlan.isStreaming` is always false.

This PR adds a dedicated `StreamingExplainCommand` so that `StreamExecution` can explain the `IncrementalExecution` of its most recent batch as a streaming plan.
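
For context, a minimal sketch of the kind of query that produces the plans below (modeled on the updated test; the local session setup, the sample data, and names such as `memory_explain` are illustrative, not part of this change):

```
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().master("local[2]").appName("explain-demo").getOrCreate()
import spark.implicits._
implicit val sqlContext = spark.sqlContext  // MemoryStream needs an implicit SQLContext

// A streaming aggregation over an in-memory test source.
val inputData = MemoryStream[String]
val df = inputData.toDS().map(_ + "foo").groupBy("value").agg(count("*"))

// Explaining the unstarted streaming DataFrame: the plan ends in StreamingRelation.
df.explain()

// Start the query against the memory sink, feed it a batch, then explain the
// running query; with this fix it reports the streaming physical plan of the
// last executed batch (StateStoreRestore/StateStoreSave over the batch's data).
val query = df.writeStream
  .queryName("memory_explain")  // illustrative name
  .outputMode("complete")
  .format("memory")
  .start()

inputData.addData("spark", "streaming")
query.processAllAvailable()

query.explain()                 // physical plan only
query.explain(extended = true)  // parsed/analyzed/optimized + physical plans
query.stop()
```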

Examples of the explain outputs:

- streaming `DataFrame.explain()`
```
== Physical Plan ==
*HashAggregate(keys=[value#518], functions=[count(1)])
+- StateStoreSave [value#518], OperatorStateId(<unknown>,0,0), Append, 0
   +- *HashAggregate(keys=[value#518], functions=[merge_count(1)])
      +- StateStoreRestore [value#518], OperatorStateId(<unknown>,0,0)
         +- *HashAggregate(keys=[value#518], functions=[merge_count(1)])
            +- Exchange hashpartitioning(value#518, 5)
               +- *HashAggregate(keys=[value#518], functions=[partial_count(1)])
                  +- *SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true) AS value#518]
                     +- *MapElements <function1>, obj#517: java.lang.String
                        +- *DeserializeToObject value#513.toString, obj#516: java.lang.String
                           +- StreamingRelation MemoryStream[value#513], [value#513]
```

- `StreamingQuery.explain(extended = false)`
```
== Physical Plan ==
*HashAggregate(keys=[value#518], functions=[count(1)])
+- StateStoreSave [value#518], OperatorStateId(...,0,0), Complete, 0
   +- *HashAggregate(keys=[value#518], functions=[merge_count(1)])
      +- StateStoreRestore [value#518], OperatorStateId(...,0,0)
         +- *HashAggregate(keys=[value#518], functions=[merge_count(1)])
            +- Exchange hashpartitioning(value#518, 5)
               +- *HashAggregate(keys=[value#518], functions=[partial_count(1)])
                  +- *SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true) AS value#518]
                     +- *MapElements <function1>, obj#517: java.lang.String
                        +- *DeserializeToObject value#543.toString, obj#516: java.lang.String
                           +- LocalTableScan [value#543]
```

- `StreamingQuery.explain(extended = true)`
```
== Parsed Logical Plan ==
Aggregate [value#518], [value#518, count(1) AS count(1)#524L]
+- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true) AS value#518]
   +- MapElements <function1>, class java.lang.String, [StructField(value,StringType,true)], obj#517: java.lang.String
      +- DeserializeToObject cast(value#543 as string).toString, obj#516: java.lang.String
         +- LocalRelation [value#543]

== Analyzed Logical Plan ==
value: string, count(1): bigint
Aggregate [value#518], [value#518, count(1) AS count(1)#524L]
+- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true) AS value#518]
   +- MapElements <function1>, class java.lang.String, [StructField(value,StringType,true)], obj#517: java.lang.String
      +- DeserializeToObject cast(value#543 as string).toString, obj#516: java.lang.String
         +- LocalRelation [value#543]

== Optimized Logical Plan ==
Aggregate [value#518], [value#518, count(1) AS count(1)#524L]
+- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true) AS value#518]
   +- MapElements <function1>, class java.lang.String, [StructField(value,StringType,true)], obj#517: java.lang.String
      +- DeserializeToObject value#543.toString, obj#516: java.lang.String
         +- LocalRelation [value#543]

== Physical Plan ==
*HashAggregate(keys=[value#518], functions=[count(1)], output=[value#518, count(1)#524L])
+- StateStoreSave [value#518], OperatorStateId(...,0,0), Complete, 0
   +- *HashAggregate(keys=[value#518], functions=[merge_count(1)], output=[value#518, count#530L])
      +- StateStoreRestore [value#518], OperatorStateId(...,0,0)
         +- *HashAggregate(keys=[value#518], functions=[merge_count(1)], output=[value#518, count#530L])
            +- Exchange hashpartitioning(value#518, 5)
               +- *HashAggregate(keys=[value#518], functions=[partial_count(1)], output=[value#518, count#530L])
                  +- *SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true) AS value#518]
                     +- *MapElements <function1>, obj#517: java.lang.String
                        +- *DeserializeToObject value#543.toString, obj#516: java.lang.String
                           +- LocalTableScan [value#543]
```

## How was this patch tested?

The updated `explain` unit test.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #16934 from zsxwing/SPARK-19603.
Committed by zsxwing on Feb 16, 2017 (commit fc02ef9, parent 08c1972).

3 changed files with 52 additions and 11 deletions:

- `ExplainCommand`, plus the new `StreamingExplainCommand` (`org.apache.spark.sql.execution.command`):

```
@@ -86,18 +86,18 @@ case class ExecutedCommandExec(cmd: RunnableCommand) extends SparkPlan {
  * }}}
  *
  * @param logicalPlan plan to explain
- * @param output output schema
  * @param extended whether to do extended explain or not
  * @param codegen whether to output generated code from whole-stage codegen or not
  */
 case class ExplainCommand(
     logicalPlan: LogicalPlan,
-    override val output: Seq[Attribute] =
-      Seq(AttributeReference("plan", StringType, nullable = true)()),
     extended: Boolean = false,
     codegen: Boolean = false)
   extends RunnableCommand {
 
+  override val output: Seq[Attribute] =
+    Seq(AttributeReference("plan", StringType, nullable = true)())
+
   // Run through the optimizer to generate the physical plan.
   override def run(sparkSession: SparkSession): Seq[Row] = try {
     val queryExecution =
@@ -121,3 +121,25 @@ case class ExplainCommand(
     ("Error occurred during query planning: \n" + cause.getMessage).split("\n").map(Row(_))
   }
 }
+
+/** An explain command for users to see how a streaming batch is executed. */
+case class StreamingExplainCommand(
+    queryExecution: IncrementalExecution,
+    extended: Boolean) extends RunnableCommand {
+
+  override val output: Seq[Attribute] =
+    Seq(AttributeReference("plan", StringType, nullable = true)())
+
+  // Run through the optimizer to generate the physical plan.
+  override def run(sparkSession: SparkSession): Seq[Row] = try {
+    val outputString =
+      if (extended) {
+        queryExecution.toString
+      } else {
+        queryExecution.simpleString
+      }
+    Seq(Row(outputString))
+  } catch { case cause: TreeNodeException[_] =>
+    ("Error occurred during query planning: \n" + cause.getMessage).split("\n").map(Row(_))
+  }
+}
```

- `StreamExecution` (`org.apache.spark.sql.execution.streaming`):

```
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.execution.streaming
 
-import java.io.IOException
 import java.util.UUID
 import java.util.concurrent.{CountDownLatch, TimeUnit}
 import java.util.concurrent.locks.ReentrantLock
@@ -33,7 +32,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp}
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.execution.QueryExecution
-import org.apache.spark.sql.execution.command.ExplainCommand
+import org.apache.spark.sql.execution.command.StreamingExplainCommand
 import org.apache.spark.sql.streaming._
 import org.apache.spark.util.{Clock, UninterruptibleThread, Utils}
 
@@ -162,7 +161,7 @@ class StreamExecution(
   private var state: State = INITIALIZING
 
   @volatile
-  var lastExecution: QueryExecution = _
+  var lastExecution: IncrementalExecution = _
 
   /** Holds the most recent input data for each source. */
   protected var newData: Map[Source, DataFrame] = _
@@ -673,7 +672,7 @@ class StreamExecution(
     if (lastExecution == null) {
       "No physical plan. Waiting for data."
     } else {
-      val explain = ExplainCommand(lastExecution.logical, extended = extended)
+      val explain = StreamingExplainCommand(lastExecution, extended = extended)
       sparkSession.sessionState.executePlan(explain).executedPlan.executeCollect()
         .map(_.getString(0)).mkString("\n")
     }
```

- The `"explain"` test in the streaming test suite:

```
@@ -22,7 +22,9 @@ import scala.util.control.ControlThrowable
 
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
+import org.apache.spark.sql.execution.command.ExplainCommand
 import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.sources.StreamSourceProvider
 import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
 
@@ -277,10 +279,24 @@
 
   test("explain") {
     val inputData = MemoryStream[String]
-    val df = inputData.toDS().map(_ + "foo")
-    // Test `explain` not throwing errors
-    df.explain()
-    val q = df.writeStream.queryName("memory_explain").format("memory").start()
+    val df = inputData.toDS().map(_ + "foo").groupBy("value").agg(count("*"))
+
+    // Test `df.explain`
+    val explain = ExplainCommand(df.queryExecution.logical, extended = false)
+    val explainString =
+      spark.sessionState
+        .executePlan(explain)
+        .executedPlan
+        .executeCollect()
+        .map(_.getString(0))
+        .mkString("\n")
+    assert(explainString.contains("StateStoreRestore"))
+    assert(explainString.contains("StreamingRelation"))
+    assert(!explainString.contains("LocalTableScan"))
+
+    // Test StreamingQuery.display
+    val q = df.writeStream.queryName("memory_explain").outputMode("complete").format("memory")
+      .start()
       .asInstanceOf[StreamingQueryWrapper]
       .streamingQuery
     try {
@@ -294,12 +310,16 @@
       // `extended = false` only displays the physical plan.
       assert("LocalRelation".r.findAllMatchIn(explainWithoutExtended).size === 0)
       assert("LocalTableScan".r.findAllMatchIn(explainWithoutExtended).size === 1)
+      // Use "StateStoreRestore" to verify that it does output a streaming physical plan
+      assert(explainWithoutExtended.contains("StateStoreRestore"))
 
       val explainWithExtended = q.explainInternal(true)
       // `extended = true` displays 3 logical plans (Parsed/Optimized/Optimized) and 1 physical
       // plan.
       assert("LocalRelation".r.findAllMatchIn(explainWithExtended).size === 3)
       assert("LocalTableScan".r.findAllMatchIn(explainWithExtended).size === 1)
+      // Use "StateStoreRestore" to verify that it does output a streaming physical plan
+      assert(explainWithExtended.contains("StateStoreRestore"))
     } finally {
       q.stop()
     }
```
