Merge pull request Kyligence#332 from JialeHe/kyspark-3.1.1.x-4.x
KE-31310 Fix Memory leak of ExecutionListenerBus
JialeHe committed Oct 18, 2021
2 parents e175c27 + 2e0f391 commit 24b9854
Showing 4 changed files with 66 additions and 5 deletions.
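
Note: the root cause is that every SparkSession owns an ExecutionListenerManager, whose ExecutionListenerBus registers itself on the SparkContext's application-wide listener bus. Because the bus held a strong reference to its session, a session the application had already dropped stayed strongly reachable and could never be garbage collected. Below is a minimal, self-contained model of that retention chain and of the id-only fix; every name in it is an illustrative stand-in, not Spark's API.

import scala.collection.mutable

// Self-contained model of the leak; all names are illustrative.
object LeakModel {
  final class Session {
    val uuid: String = java.util.UUID.randomUUID().toString
  }

  // Stands in for SparkContext.listenerBus: lives as long as the application.
  val globalBus = mutable.ListBuffer.empty[AnyRef]

  // Leaky variant: holding the session pins it in memory for as long as the
  // global bus holds the listener, which is forever.
  final class LeakyListener(val session: Session)

  // Fixed variant: only the immutable uuid string is captured, so dropping
  // the session makes it unreachable and eligible for collection.
  final class FixedListener(val sessionUUID: String)

  def main(args: Array[String]): Unit = {
    var s = new Session
    globalBus += new LeakyListener(s)      // retains the whole Session
    globalBus += new FixedListener(s.uuid) // retains only a short string
    s = null // the application is done with the session
    // With LeakyListener registered, the Session stays strongly reachable
    // through globalBus; with only FixedListener it could now be reclaimed.
  }
}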
19 changes: 19 additions & 0 deletions core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -27,6 +27,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData}
+import org.apache.spark.scheduler.SparkListener
import org.apache.spark.shuffle.api.ShuffleDriverComponents
import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, ThreadUtils, Utils}

@@ -39,6 +40,7 @@ private case class CleanShuffle(shuffleId: Int) extends CleanupTask
private case class CleanBroadcast(broadcastId: Long) extends CleanupTask
private case class CleanAccum(accId: Long) extends CleanupTask
private case class CleanCheckpoint(rddId: Int) extends CleanupTask
+private case class CleanSparkListener(listener: SparkListener) extends CleanupTask

/**
* A WeakReference associated with a CleanupTask.
@@ -183,6 +185,11 @@ private[spark] class ContextCleaner(
registerForCleanup(rdd, CleanCheckpoint(parentId))
}

+/** Register a SparkListener to be cleaned up when its owner is garbage collected. */
+def registerSparkListenerForCleanup(listenerOwner: AnyRef, listener: SparkListener): Unit = {
+  registerForCleanup(listenerOwner, CleanSparkListener(listener))
+}
+
/** Register an object for cleanup. */
private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask): Unit = {
referenceBuffer.add(new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue))
@@ -224,6 +231,8 @@ private[spark] class ContextCleaner(
doCleanupAccum(accId, blocking = blockOnCleanupTasks)
case CleanCheckpoint(rddId) =>
doCleanCheckpoint(rddId)
+case CleanSparkListener(listener) =>
+  doCleanSparkListener(listener)
}
} catch {
case ie: InterruptedException if stopped => // ignore
@@ -302,6 +311,16 @@ private[spark] class ContextCleaner(
}
}

+def doCleanSparkListener(listener: SparkListener): Unit = {
+  try {
+    logDebug(s"Cleaning Spark Listener $listener")
+    sc.listenerBus.removeListener(listener)
+    logDebug(s"Cleaned Spark Listener $listener")
+  } catch {
+    case e: Exception => logError(s"Error cleaning Spark Listener $listener", e)
+  }
+}
+
private def broadcastManager = sc.env.broadcastManager
private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
}
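
Note: the removal path rides on ContextCleaner's existing weak-reference machinery: each CleanupTask is attached to a weak reference to its owner, and once the owner is collected the reference surfaces on a reference queue, where the cleaning thread picks it up and dispatches to the matching doClean* method (here doCleanSparkListener). The following standalone sketch shows that pattern under simplified, assumed types; it mirrors the real CleanupTaskWeakReference and referenceQueue in name only.

import java.lang.ref.{ReferenceQueue, WeakReference}

// Standalone sketch of the weak-reference cleanup pattern (simplified).
object WeakCleanupSketch {
  sealed trait CleanupTask
  final case class CleanSparkListener(name: String) extends CleanupTask

  // The task piggybacks on the weak reference to the owner, so it survives
  // the owner's collection and can still be executed afterwards.
  final class CleanupTaskWeakReference(
      val task: CleanupTask,
      owner: AnyRef,
      queue: ReferenceQueue[AnyRef])
    extends WeakReference[AnyRef](owner, queue)

  def main(args: Array[String]): Unit = {
    val queue = new ReferenceQueue[AnyRef]
    var owner: AnyRef = new AnyRef // stands in for ExecutionListenerManager
    val ref = new CleanupTaskWeakReference(CleanSparkListener("bus-1"), owner, queue)

    owner = null // drop the last strong reference to the owner
    var enqueued: AnyRef = null
    while (enqueued == null) { // retry, since a single System.gc() is only a hint
      System.gc()
      enqueued = queue.remove(100)
    }
    enqueued.asInstanceOf[CleanupTaskWeakReference].task match {
      case CleanSparkListener(n) => println(s"removing listener $n from the bus")
    }
    assert(enqueued eq ref) // also keeps `ref` itself reachable until here
  }
}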
3 changes: 3 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql

import java.io.Closeable
+import java.util.UUID
import java.util.concurrent.TimeUnit._
import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}

@@ -102,6 +103,8 @@ class SparkSession private(
new SparkSessionExtensions), Map.empty)
}

+private[sql] val sessionUUID: String = UUID.randomUUID().toString
+
sparkContext.assertNotStopped()

// If there is no active SparkSession, uses the default SQL conf. Otherwise, use the session's.
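
Note: sessionUUID gives every session a cheap, immutable identity token that other components can hold without keeping the session itself alive. A tiny sketch of the idea, using a hypothetical Session class rather than Spark's:

import java.util.UUID

// Hypothetical stand-in: each instance mints one random, immutable id.
final class Session {
  val sessionUUID: String = UUID.randomUUID().toString
}

object SessionIdDemo extends App {
  val a = new Session
  val b = new Session
  println(a.sessionUUID == a.sessionUUID) // true: stable for one instance
  println(a.sessionUUID == b.sessionUUID) // false: unique across instances
}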
14 changes: 10 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
@@ -76,7 +76,7 @@ trait QueryExecutionListener
class ExecutionListenerManager private[sql](session: SparkSession, loadExtensions: Boolean)
extends Logging {

-private val listenerBus = new ExecutionListenerBus(session)
+private val listenerBus = new ExecutionListenerBus(this, session)

if (loadExtensions) {
val conf = session.sparkContext.conf
@@ -124,10 +124,16 @@ class ExecutionListenerManager private[sql](session: SparkSession, loadExtension
}
}

-private[sql] class ExecutionListenerBus(session: SparkSession)
+private[sql] class ExecutionListenerBus private(sessionUUID: String)
extends SparkListener with ListenerBus[QueryExecutionListener, SparkListenerSQLExecutionEnd] {

-session.sparkContext.listenerBus.addToSharedQueue(this)
+def this(manager: ExecutionListenerManager, session: SparkSession) = {
+  this(session.sessionUUID)
+  session.sparkContext.listenerBus.addToSharedQueue(this)
+  session.sparkContext.cleaner.foreach(cleaner => {
+    cleaner.registerSparkListenerForCleanup(manager, this)
+  })
+}

override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
case e: SparkListenerSQLExecutionEnd => postToAll(e)
@@ -158,6 +164,6 @@ private[sql] class ExecutionListenerBus(session: SparkSession)
private def shouldReport(e: SparkListenerSQLExecutionEnd): Boolean = {
// Only catch SQL execution with a name, and triggered by the same spark session that this
// listener manager belongs.
-e.executionName.isDefined && e.qe.sparkSession.eq(this.session)
+e.executionName.isDefined && e.qe.sparkSession.sessionUUID == sessionUUID
}
}
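
Note: two changes work together in this file. The primary constructor now captures only the session's UUID, so shouldReport can test session identity by string equality instead of holding the session for a reference comparison, and the new auxiliary constructor both registers the bus on the shared queue and hands it to the ContextCleaner for eventual removal. A hedged sketch of that wiring with stand-in types (nothing below is Spark's real API):

import scala.collection.mutable

// Stand-ins for SparkSession, ContextCleaner and the shared listener bus.
object WiringSketch {
  final class Cleaner {
    // The real cleaner stores a WeakReference to `owner` plus a removal
    // task; that bookkeeping is elided here.
    def registerSparkListenerForCleanup(owner: AnyRef, listener: AnyRef): Unit = ()
  }

  final class Session {
    val sessionUUID: String = java.util.UUID.randomUUID().toString
    val sharedQueue = mutable.ListBuffer.empty[AnyRef] // ~ sparkContext.listenerBus
    val cleaner: Option[Cleaner] = Some(new Cleaner)   // ~ sparkContext.cleaner
  }

  // The primary constructor is private and session-free; the public
  // auxiliary constructor performs both side effects in one place.
  class Bus private (val sessionUUID: String) {
    def this(manager: AnyRef, session: Session) = {
      this(session.sessionUUID)   // capture the id, never the session
      session.sharedQueue += this // long-lived registration, paired with
      session.cleaner.foreach(_.registerSparkListenerForCleanup(manager, this)) // an undo hook
    }
    def shouldReport(eventSessionUUID: String): Boolean =
      eventSessionUUID == sessionUUID // string equality replaces session.eq
  }

  def main(args: Array[String]): Unit = {
    val session = new Session
    val bus = new Bus(new AnyRef, session)
    println(bus.shouldReport(session.sessionUUID)) // prints: true
  }
}

The UUID change alone would stop sessions from leaking, but stale buses would still pile up on the shared queue; the cleaner registration removes those as well.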
35 changes: 34 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
@@ -17,19 +17,24 @@

package org.apache.spark.sql

+import scala.collection.JavaConverters._
+
import org.scalatest.BeforeAndAfterEach
+import org.scalatest.concurrent.Eventually
+import org.scalatest.time.SpanSugar._

import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite}
import org.apache.spark.internal.config.EXECUTOR_ALLOW_SPARK_CONTEXT
import org.apache.spark.internal.config.UI.UI_ENABLED
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.StaticSQLConf._
+import org.apache.spark.sql.util.ExecutionListenerBus
import org.apache.spark.util.ThreadUtils

/**
* Test cases for the builder pattern of [[SparkSession]].
*/
-class SparkSessionBuilderSuite extends SparkFunSuite with BeforeAndAfterEach {
+class SparkSessionBuilderSuite extends SparkFunSuite with BeforeAndAfterEach with Eventually {

override def afterEach(): Unit = {
// This suite should not interfere with the other test suites.
@@ -39,6 +44,34 @@ class SparkSessionBuilderSuite extends SparkFunSuite with BeforeAndAfterEach {
SparkSession.clearDefaultSession()
}

test("Fix Memory leak of ExecutionListenerBus") {
val spark = SparkSession.builder()
.master("local")
.getOrCreate()

@inline def listenerNum(): Int = {
spark.sparkContext
.listenerBus
.listeners
.asScala
.count(_.isInstanceOf[ExecutionListenerBus])
}

(1 to 10).foreach(_ => {
spark.cloneSession()
SparkSession.clearActiveSession()
})

eventually(timeout(10.seconds), interval(1.seconds)) {
System.gc()
// After GC, the number of ExecutionListenerBus should be less than 11 (we created 11
// SparkSessions in total).
// Since GC can't 100% guarantee all out-of-referenced objects be cleaned at one time,
// here, we check at least one listener is cleaned up to prove the mechanism works.
assert(listenerNum() < 11)
}
}

test("create with config options and propagate them to SparkContext and SparkSession") {
val session = SparkSession.builder()
.master("local")
