Skip to content

Commit

Permalink
[SPARK-28709][DSTREAMS] Fix StreamingContext leak through StreamingJobProgressListener on stop
Browse files Browse the repository at this point in the history
  • Loading branch information
Nikita Gorbachevsky authored and choojoyq committed Aug 14, 2019
1 parent 247bebc commit 4d5965e
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 30 deletions.
3 changes: 3 additions & 0 deletions core/src/main/scala/org/apache/spark/ui/SparkUI.scala
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ private[spark] class SparkUI private (
streamingJobProgressListener = Option(sparkListener)
}

/**
 * Drops the reference to the streaming job progress listener so that a stopped
 * StreamingContext (which owns the listener) is no longer reachable from this UI.
 */
def clearStreamingJobProgressListener(): Unit = {
  streamingJobProgressListener = Option.empty
}
}

private[spark] abstract class SparkUITab(parent: SparkUI, prefix: String)
Expand Down
1 change: 1 addition & 0 deletions core/src/main/scala/org/apache/spark/ui/WebUI.scala
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ private[spark] abstract class WebUI(
attachHandler(renderJsonHandler)
val handlers = pageToHandlers.getOrElseUpdate(page, ArrayBuffer[ServletContextHandler]())
handlers += renderHandler
handlers += renderJsonHandler
}

/** Attaches a handler to this UI. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.input.FixedLengthBinaryInputFormat
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.UI._
import org.apache.spark.rdd.{RDD, RDDOperationScope}
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.SerializationDebugger
Expand Down Expand Up @@ -189,10 +188,9 @@ class StreamingContext private[streaming] (
private[streaming] val progressListener = new StreamingJobProgressListener(this)

// Build the Streaming tab only when the parent SparkContext actually has a UI;
// `Option.map` replaces the `match { Some/None }` deconstruction of the same Option.
private[streaming] val uiTab: Option[StreamingTab] =
  sparkContext.ui.map(ui => new StreamingTab(this, ui))

/* Initializing a streamingSource to register metrics */
Expand Down Expand Up @@ -511,6 +509,10 @@ class StreamingContext private[streaming] (
scheduler.listenerBus.addListener(streamingListener)
}

/**
 * Remove a previously added [[org.apache.spark.streaming.scheduler.StreamingListener]]
 * from the streaming scheduler's listener bus. Counterpart of `addStreamingListener`.
 *
 * @param streamingListener the listener to detach; a no-op if it was never added
 */
def removeStreamingListener(streamingListener: StreamingListener): Unit = {
scheduler.listenerBus.removeListener(streamingListener)
}

private def validate() {
assert(graph != null, "Graph is null")
graph.validate()
Expand Down Expand Up @@ -575,6 +577,8 @@ class StreamingContext private[streaming] (
try {
validate()

registerProgressListener()

// Start the streaming scheduler in a new thread, so that thread local properties
// like call sites and job groups can be reset without affecting those of the
// current thread.
Expand Down Expand Up @@ -690,6 +694,9 @@ class StreamingContext private[streaming] (
Utils.tryLogNonFatalError {
uiTab.foreach(_.detach())
}
Utils.tryLogNonFatalError {
unregisterProgressListener()
}
StreamingContext.setActiveContext(null)
Utils.tryLogNonFatalError {
waiter.notifyStop()
Expand All @@ -716,6 +723,18 @@ class StreamingContext private[streaming] (
// Do not stop SparkContext, let its own shutdown hook stop it
stop(stopSparkContext = false, stopGracefully = stopGracefully)
}

/**
 * Wires `progressListener` into everything that reports progress: the streaming
 * listener bus, the SparkContext's listener bus, and (when a UI exists) the SparkUI.
 * Mirrored by `unregisterProgressListener` on stop.
 */
private def registerProgressListener(): Unit = {
  val listener = progressListener
  addStreamingListener(listener)
  sc.addSparkListener(listener)
  sc.ui.foreach(ui => ui.setStreamingJobProgressListener(listener))
}

/**
 * Detaches `progressListener` from both listener buses and from the SparkUI (if any),
 * so the stopped StreamingContext does not stay reachable through them.
 */
private def unregisterProgressListener(): Unit = {
  val listener = progressListener
  removeStreamingListener(listener)
  sc.removeSparkListener(listener)
  sc.ui.foreach(ui => ui.clearStreamingJobProgressListener())
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.spark.streaming.ui

import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.streaming.StreamingContext
import org.apache.spark.ui.{SparkUI, SparkUITab}
Expand All @@ -26,37 +25,24 @@ import org.apache.spark.ui.{SparkUI, SparkUITab}
* Spark Web UI tab that shows statistics of a streaming job.
* The parent [[SparkUI]] to attach to is passed in explicitly, so a tab is only
* ever constructed when the SparkContext has its UI enabled.
*/
private[spark] class StreamingTab(val ssc: StreamingContext, sparkUI: SparkUI)
  extends SparkUITab(sparkUI, "streaming") with Logging {

  private val STATIC_RESOURCE_DIR = "org/apache/spark/streaming/ui/static"

  // The UI this tab attaches to; kept as `parent` for source compatibility with callers.
  val parent = sparkUI
  // Listener registration is now owned by StreamingContext, not performed here.
  val listener = ssc.progressListener

  attachPage(new StreamingPage(this))
  attachPage(new BatchPage(this))

  /** Adds this tab and its static resource handler to the parent UI. */
  def attach() {
    parent.attachTab(this)
    parent.addStaticHandler(STATIC_RESOURCE_DIR, "/static/streaming")
  }

  /** Removes this tab and its static resource handler from the parent UI. */
  def detach() {
    parent.detachTab(this)
    parent.detachHandler("/static/streaming")
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {

// Set up the streaming context and input streams
withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
ssc.addStreamingListener(ssc.progressListener)

val input = Seq(1, 2, 3, 4, 5)
// Use "batchCount" to make sure we check the result after all batches finish
val batchCounter = new BatchCounter(ssc)
Expand Down Expand Up @@ -106,8 +104,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
testServer.start()

withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
ssc.addStreamingListener(ssc.progressListener)

val batchCounter = new BatchCounter(ssc)
val networkStream = ssc.socketTextStream(
"localhost", testServer.port, StorageLevel.MEMORY_AND_DISK)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import org.scalatest.time.SpanSugar._

import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.UI.UI_ENABLED
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.metrics.source.Source
import org.apache.spark.storage.StorageLevel
Expand Down Expand Up @@ -392,6 +393,29 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with TimeL
assert(!sourcesAfterStop.contains(streamingSourceAfterStop))
}

test("SPARK-28709 registering and de-registering of progressListener") {
  val conf = new SparkConf().setMaster(master).setAppName(appName)
  conf.set(UI_ENABLED, true)

  ssc = new StreamingContext(conf, batchDuration)

  assert(ssc.sc.ui.isDefined, "Spark UI is not started!")
  val sparkUI = ssc.sc.ui.get

  addInputStream(ssc).register()
  ssc.start()

  // While running, the listener must be on both buses and exposed through the UI.
  assert(ssc.scheduler.listenerBus.listeners.contains(ssc.progressListener))
  assert(ssc.sc.listenerBus.listeners.contains(ssc.progressListener))
  // `Option.contains` instead of `.get ==`: a missing listener now fails the
  // assertion instead of throwing NoSuchElementException out of the test body.
  assert(sparkUI.getStreamingJobProgressListener.contains(ssc.progressListener))

  ssc.stop()

  // After stop, every registration must have been undone.
  assert(!ssc.scheduler.listenerBus.listeners.contains(ssc.progressListener))
  assert(!ssc.sc.listenerBus.listeners.contains(ssc.progressListener))
  assert(sparkUI.getStreamingJobProgressListener.isEmpty)
}

test("awaitTermination") {
ssc = new StreamingContext(master, appName, batchDuration)
val inputStream = addInputStream(ssc)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ class UISeleniumSuite

val sparkUI = ssc.sparkContext.ui.get

sparkUI.getHandlers.count(_.getContextPath.contains("/streaming")) should be (5)

eventually(timeout(10.seconds), interval(50.milliseconds)) {
go to (sparkUI.webUrl.stripSuffix("/"))
find(cssSelector( """ul li a[href*="streaming"]""")) should not be (None)
Expand Down Expand Up @@ -196,6 +198,8 @@ class UISeleniumSuite

ssc.stop(false)

sparkUI.getHandlers.count(_.getContextPath.contains("/streaming")) should be (0)

eventually(timeout(10.seconds), interval(50.milliseconds)) {
go to (sparkUI.webUrl.stripSuffix("/"))
find(cssSelector( """ul li a[href*="streaming"]""")) should be(None)
Expand Down

0 comments on commit 4d5965e

Please sign in to comment.