Skip to content

Commit

Permalink
[SPARK-28709][DSTREAMS] Fix StreamingContext leak through Streaming
Browse files Browse the repository at this point in the history
In my application, Spark Streaming is restarted programmatically by stopping the StreamingContext (without stopping the SparkContext) and then creating and starting a new one. I use this for automatic detection of Kafka topic/partition changes and for automatic failover in case of non-fatal exceptions.

However, I noticed that after multiple restarts the driver fails with an OOM error. While investigating a heap dump, I found that the StreamingContext object is not garbage-collected after it is stopped.

<img width="1901" alt="Screen Shot 2019-08-14 at 12 23 33" src="https://user-images.githubusercontent.com/13151161/63010149-83f4c200-be8e-11e9-9f48-12b6e97839f4.png">

There are several places that hold a reference to it:

1. StreamingTab registers StreamingJobProgressListener, which holds a reference to the StreamingContext, directly on the LiveListenerBus shared queue by calling ssc.sc.addSparkListener(listener). However, this listener is never unregistered in the stop method.
2. The JSON handlers (/streaming/json and /streaming/batch/json) are not unregistered from SparkUI, even though they hold a reference to StreamingJobProgressListener. The same issue affects all pages; I assume renderJsonHandler should be added to the pageToHandlers cache when attachPage is invoked, so that it is also unregistered on detachPage.
3. SparkUI holds a reference to StreamingJobProgressListener in its streamingJobProgressListener field, which is not cleared after the StreamingContext is stopped.

Added tests to existing test suites.
After I applied these changes via reflection in my application, the OOM on the driver side went away.

Closes #25439 from choojoyq/SPARK-28709-fix-streaming-context-leak-on-stop.

Authored-by: Nikita Gorbachevsky <nikitag@playtika.com>
Signed-off-by: Sean Owen <sean.owen@databricks.com>
  • Loading branch information
Nikita Gorbachevsky authored and srowen committed Sep 3, 2019
1 parent 1a5858f commit 3f3f524
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 29 deletions.
3 changes: 3 additions & 0 deletions core/src/main/scala/org/apache/spark/ui/SparkUI.scala
Expand Up @@ -140,6 +140,9 @@ private[spark] class SparkUI private (
streamingJobProgressListener = Option(sparkListener)
}

/** Resets the streaming job progress listener held by this UI to `None`. */
def clearStreamingJobProgressListener(): Unit =
  streamingJobProgressListener = None
}

private[spark] abstract class SparkUITab(parent: SparkUI, prefix: String)
Expand Down
1 change: 1 addition & 0 deletions core/src/main/scala/org/apache/spark/ui/WebUI.scala
Expand Up @@ -88,6 +88,7 @@ private[spark] abstract class WebUI(
attachHandler(renderJsonHandler)
val handlers = pageToHandlers.getOrElseUpdate(page, ArrayBuffer[ServletContextHandler]())
handlers += renderHandler
handlers += renderJsonHandler
}

/** Attaches a handler to this UI. */
Expand Down
Expand Up @@ -188,10 +188,9 @@ class StreamingContext private[streaming] (
private[streaming] val progressListener = new StreamingJobProgressListener(this)

private[streaming] val uiTab: Option[StreamingTab] =
if (conf.getBoolean("spark.ui.enabled", true)) {
Some(new StreamingTab(this))
} else {
None
sparkContext.ui match {
case Some(ui) => Some(new StreamingTab(this, ui))
case None => None
}

/* Initializing a streamingSource to register metrics */
Expand Down Expand Up @@ -508,6 +507,10 @@ class StreamingContext private[streaming] (
scheduler.listenerBus.addListener(streamingListener)
}

/**
 * Removes the given listener from the scheduler's listener bus.
 *
 * @param streamingListener the listener to remove
 */
def removeStreamingListener(streamingListener: StreamingListener): Unit =
  scheduler.listenerBus.removeListener(streamingListener)

private def validate() {
assert(graph != null, "Graph is null")
graph.validate()
Expand Down Expand Up @@ -572,6 +575,8 @@ class StreamingContext private[streaming] (
try {
validate()

registerProgressListener()

// Start the streaming scheduler in a new thread, so that thread local properties
// like call sites and job groups can be reset without affecting those of the
// current thread.
Expand Down Expand Up @@ -687,6 +692,9 @@ class StreamingContext private[streaming] (
Utils.tryLogNonFatalError {
uiTab.foreach(_.detach())
}
Utils.tryLogNonFatalError {
unregisterProgressListener()
}
StreamingContext.setActiveContext(null)
Utils.tryLogNonFatalError {
waiter.notifyStop()
Expand All @@ -713,6 +721,18 @@ class StreamingContext private[streaming] (
// Do not stop SparkContext, let its own shutdown hook stop it
stop(stopSparkContext = false, stopGracefully = stopGracefully)
}

/**
 * Registers `progressListener` as a streaming listener and as a Spark listener, and, when the
 * SparkContext has a UI, hands it to that UI as well.
 */
private def registerProgressListener(): Unit = {
  val listener = progressListener
  addStreamingListener(listener)
  sc.addSparkListener(listener)
  // sc.ui is an Option: nothing happens here when no UI is present.
  for (ui <- sc.ui) {
    ui.setStreamingJobProgressListener(listener)
  }
}

/**
 * Undoes `registerProgressListener`: removes `progressListener` as a streaming listener and as
 * a Spark listener, and, when the SparkContext has a UI, clears the listener held by that UI.
 */
private def unregisterProgressListener(): Unit = {
  val listener = progressListener
  removeStreamingListener(listener)
  sc.removeSparkListener(listener)
  // sc.ui is an Option: nothing happens here when no UI is present.
  for (ui <- sc.ui) {
    ui.clearStreamingJobProgressListener()
  }
}
}

/**
Expand Down
Expand Up @@ -17,7 +17,6 @@

package org.apache.spark.streaming.ui

import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.streaming.StreamingContext
import org.apache.spark.ui.{SparkUI, SparkUITab}
Expand All @@ -26,37 +25,24 @@ import org.apache.spark.ui.{SparkUI, SparkUITab}
* Spark Web UI tab that shows statistics of a streaming job.
* This assumes the given SparkContext has enabled its SparkUI.
*/
private[spark] class StreamingTab(val ssc: StreamingContext)
extends SparkUITab(StreamingTab.getSparkUI(ssc), "streaming") with Logging {

import StreamingTab._
private[spark] class StreamingTab(val ssc: StreamingContext, sparkUI: SparkUI)
extends SparkUITab(sparkUI, "streaming") with Logging {

private val STATIC_RESOURCE_DIR = "org/apache/spark/streaming/ui/static"

val parent = getSparkUI(ssc)
val parent = sparkUI
val listener = ssc.progressListener

ssc.addStreamingListener(listener)
ssc.sc.addSparkListener(listener)
parent.setStreamingJobProgressListener(listener)
attachPage(new StreamingPage(this))
attachPage(new BatchPage(this))

def attach() {
getSparkUI(ssc).attachTab(this)
getSparkUI(ssc).addStaticHandler(STATIC_RESOURCE_DIR, "/static/streaming")
parent.attachTab(this)
parent.addStaticHandler(STATIC_RESOURCE_DIR, "/static/streaming")
}

def detach() {
getSparkUI(ssc).detachTab(this)
getSparkUI(ssc).detachHandler("/static/streaming")
}
}

private object StreamingTab {
def getSparkUI(ssc: StreamingContext): SparkUI = {
ssc.sc.ui.getOrElse {
throw new SparkException("Parent SparkUI to attach this tab to not found!")
}
parent.detachTab(this)
parent.detachHandler("/static/streaming")
}
}
Expand Up @@ -50,8 +50,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {

// Set up the streaming context and input streams
withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
ssc.addStreamingListener(ssc.progressListener)

val input = Seq(1, 2, 3, 4, 5)
// Use "batchCount" to make sure we check the result after all batches finish
val batchCounter = new BatchCounter(ssc)
Expand Down Expand Up @@ -104,8 +102,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
testServer.start()

withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
ssc.addStreamingListener(ssc.progressListener)

val batchCounter = new BatchCounter(ssc)
val networkStream = ssc.socketTextStream(
"localhost", testServer.port, StorageLevel.MEMORY_AND_DISK)
Expand Down
Expand Up @@ -393,6 +393,29 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with TimeL
assert(!sourcesAfterStop.contains(streamingSourceAfterStop))
}

test("SPARK-28709 registering and de-registering of progressListener") {
  // The UI is explicitly enabled so the SparkUI side of the registration can be checked.
  val conf = new SparkConf()
    .setMaster(master)
    .setAppName(appName)
    .set("spark.ui.enabled", "true")

  ssc = new StreamingContext(conf, batchDuration)

  assert(ssc.sc.ui.isDefined, "Spark UI is not started!")
  val sparkUI = ssc.sc.ui.get
  val listener = ssc.progressListener

  addInputStream(ssc).register()
  ssc.start()

  // After start: the listener is on both listener buses and is the one held by the UI.
  assert(ssc.scheduler.listenerBus.listeners.contains(listener))
  assert(ssc.sc.listenerBus.listeners.contains(listener))
  assert(sparkUI.getStreamingJobProgressListener.get == listener)

  ssc.stop()

  // After stop: the listener is gone from both buses and the UI holds no listener.
  assert(!ssc.scheduler.listenerBus.listeners.contains(listener))
  assert(!ssc.sc.listenerBus.listeners.contains(listener))
  assert(sparkUI.getStreamingJobProgressListener.isEmpty)
}

test("awaitTermination") {
ssc = new StreamingContext(master, appName, batchDuration)
val inputStream = addInputStream(ssc)
Expand Down
Expand Up @@ -96,6 +96,8 @@ class UISeleniumSuite

val sparkUI = ssc.sparkContext.ui.get

sparkUI.getHandlers.count(_.getContextPath.contains("/streaming")) should be (5)

eventually(timeout(10 seconds), interval(50 milliseconds)) {
go to (sparkUI.webUrl.stripSuffix("/"))
find(cssSelector( """ul li a[href*="streaming"]""")) should not be (None)
Expand Down Expand Up @@ -195,6 +197,8 @@ class UISeleniumSuite

ssc.stop(false)

sparkUI.getHandlers.count(_.getContextPath.contains("/streaming")) should be (0)

eventually(timeout(10 seconds), interval(50 milliseconds)) {
go to (sparkUI.webUrl.stripSuffix("/"))
find(cssSelector( """ul li a[href*="streaming"]""")) should be(None)
Expand Down

0 comments on commit 3f3f524

Please sign in to comment.