[SPARK-46947][CORE] Delay memory manager initialization until Driver plugin is loaded

### What changes were proposed in this pull request?

This moves the initialization of `SparkEnv.memoryManager` to after the `DriverPlugin` is loaded, so that the plugin can customize memory-related configurations.
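
For reference, the new initialization order in `SparkContext` looks roughly like the sketch below (a simplified excerpt of the diff further down; field names are as in `SparkContext`):

```scala
// Simplified from the SparkContext change in this PR: plugins are loaded first, and the
// shuffle and memory managers are initialized afterwards, so a DriverPlugin can still
// adjust memory-related configurations before the MemoryManager reads them.
_plugins = PluginContainer(this, _resources.asJava)
_env.initializeShuffleManager()
_env.initializeMemoryManager(SparkContext.numDriverCores(master, conf))
```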

A minor fix has also been made to `Task` to make sure that it uses the same `BlockManager` throughout the task execution. Previously, a different `BlockManager` could be used in some corner cases. A test for the fix has been added as well.
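
The `Task` fix boils down to the pattern sketched below (a simplified excerpt of the `Task.scala` diff further down): the `BlockManager` is captured once when the task starts and reused for cleanup, instead of calling `SparkEnv.get.blockManager` again at the end of the task.

```scala
// Simplified from Task.run: capture the BlockManager once at task start ...
val blockManager = SparkEnv.get.blockManager
blockManager.registerTask(taskAttemptId)
try {
  // ... run the task body ...
} finally {
  // ... and reuse the same instance for cleanup, even if a new SparkEnv (and hence a new
  // BlockManager) was created in the meantime, e.g. after a local-mode restart.
  blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
  blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.OFF_HEAP)
}
```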

### Why are the changes needed?

Today, there is no way for a custom `DriverPlugin` to override memory configurations such as `spark.executor.memory`, `spark.executor.memoryOverhead`, `spark.memory.offHeap.size`, etc. This is because the memory manager is initialized before the `DriverPlugin` is loaded.
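
For illustration, a plugin that overrides off-heap memory settings could look roughly like the sketch below. `OffHeapOverridePlugin` is a hypothetical name, and the approach mirrors the `MemoryOverridePlugin` test added in this PR; since `SparkContext.conf` is `private[spark]`, the sketch assumes the plugin is compiled in an `org.apache.spark` subpackage, as the test itself is.

```scala
// Hypothetical example, adapted from the MemoryOverridePlugin test added in this PR.
package org.apache.spark.example

import java.util.{Collections, Map => JMap}

import org.apache.spark.SparkContext
import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext, SparkPlugin}

class OffHeapOverridePlugin extends SparkPlugin {
  override def driverPlugin(): DriverPlugin = new DriverPlugin {
    override def init(sc: SparkContext, ctx: PluginContext): JMap[String, String] = {
      // These writes take effect because, with this PR, the MemoryManager has not been
      // constructed yet when the driver plugin is initialized.
      sc.conf.set("spark.memory.offHeap.enabled", "true")
      sc.conf.set("spark.memory.offHeap.size", "1g")
      Collections.emptyMap[String, String]()
    }
  }

  override def executorPlugin(): ExecutorPlugin = new ExecutorPlugin {}
}
```

The plugin would then be enabled with `--conf spark.plugins=org.apache.spark.example.OffHeapOverridePlugin`.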

A similar change has been made to `shuffleManager` in apache#43627.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests. Also added new tests.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#45052 from sunchao/SPARK-46947.

Authored-by: Chao Sun <sunchao@apache.org>
Signed-off-by: Chao Sun <sunchao@apache.org>
sunchao authored and TakawaAkirayo committed Mar 4, 2024
1 parent abc09e5 commit 7437ad4
Showing 6 changed files with 163 additions and 17 deletions.
1 change: 1 addition & 0 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -577,6 +577,7 @@ class SparkContext(config: SparkConf) extends Logging {
// Initialize any plugins before the task scheduler is initialized.
_plugins = PluginContainer(this, _resources.asJava)
_env.initializeShuffleManager()
_env.initializeMemoryManager(SparkContext.numDriverCores(master, conf))

// Create and start the scheduler
val (sched, ts) = SparkContext.createTaskScheduler(this, master)
20 changes: 15 additions & 5 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -67,7 +67,6 @@ class SparkEnv (
val blockManager: BlockManager,
val securityManager: SecurityManager,
val metricsSystem: MetricsSystem,
val memoryManager: MemoryManager,
val outputCommitCoordinator: OutputCommitCoordinator,
val conf: SparkConf) extends Logging {

@@ -77,6 +76,12 @@

def shuffleManager: ShuffleManager = _shuffleManager

// We initialize the MemoryManager later in SparkContext after DriverPlugin is loaded
// to allow the plugin to overwrite executor memory configurations
private var _memoryManager: MemoryManager = _

def memoryManager: MemoryManager = _memoryManager

@volatile private[spark] var isStopped = false

/**
@@ -199,6 +204,12 @@
"Shuffle manager already initialized to %s", _shuffleManager)
_shuffleManager = ShuffleManager.create(conf, executorId == SparkContext.DRIVER_IDENTIFIER)
}

private[spark] def initializeMemoryManager(numUsableCores: Int): Unit = {
Preconditions.checkState(null == memoryManager,
"Memory manager already initialized to %s", _memoryManager)
_memoryManager = UnifiedMemoryManager(conf, numUsableCores)
}
}

object SparkEnv extends Logging {
@@ -276,6 +287,8 @@ object SparkEnv extends Logging {
numCores,
ioEncryptionKey
)
// Set the memory manager since it needs to be initialized explicitly
env.initializeMemoryManager(numCores)
SparkEnv.set(env)
env
}
@@ -358,8 +371,6 @@ object SparkEnv extends Logging {
new MapOutputTrackerMasterEndpoint(
rpcEnv, mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf))

val memoryManager: MemoryManager = UnifiedMemoryManager(conf, numUsableCores)

val blockManagerPort = if (isDriver) {
conf.get(DRIVER_BLOCK_MANAGER_PORT)
} else {
@@ -418,7 +429,7 @@
blockManagerMaster,
serializerManager,
conf,
memoryManager,
_memoryManager = null,
mapOutputTracker,
_shuffleManager = null,
blockTransferService,
@@ -463,7 +474,6 @@
blockManager,
securityManager,
metricsSystem,
memoryManager,
outputCommitCoordinator,
conf)

15 changes: 10 additions & 5 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -93,7 +93,12 @@ private[spark] abstract class Task[T](

require(cpus > 0, "CPUs per task should be > 0")

SparkEnv.get.blockManager.registerTask(taskAttemptId)
// Capture the blockManager at the start of the task and use it throughout the task -
// particularly in local mode, a new SparkEnv can be initialized when the spark context is
// restarted, and we want to ensure the right env and block manager is used (given the lazy
// initialization of the block manager)
val blockManager = SparkEnv.get.blockManager
blockManager.registerTask(taskAttemptId)
// TODO SPARK-24874 Allow create BarrierTaskContext based on partitions, instead of whether
// the stage is barrier.
val taskContext = new TaskContextImpl(
@@ -143,15 +148,15 @@
try {
Utils.tryLogNonFatalError {
// Release memory used by this thread for unrolling blocks
SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(
MemoryMode.OFF_HEAP)
blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.OFF_HEAP)
// Notify any tasks waiting for execution memory to be freed to wake up and try to
// acquire memory again. This makes impossible the scenario where a task sleeps forever
// because there are no other tasks left to notify it. Since this is safe to do but may
// not be strictly necessary, we should revisit whether we can remove this in the
// future.
val memoryManager = SparkEnv.get.memoryManager

val memoryManager = blockManager.memoryManager
memoryManager.synchronized { memoryManager.notifyAll() }
}
} finally {
20 changes: 14 additions & 6 deletions core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -185,7 +185,7 @@ private[spark] class BlockManager(
val master: BlockManagerMaster,
val serializerManager: SerializerManager,
val conf: SparkConf,
memoryManager: MemoryManager,
private val _memoryManager: MemoryManager,
mapOutputTracker: MapOutputTracker,
private val _shuffleManager: ShuffleManager,
val blockTransferService: BlockTransferService,
@@ -198,6 +198,12 @@
// (except for tests) and we ask for the instance from the SparkEnv.
private lazy val shuffleManager = Option(_shuffleManager).getOrElse(SparkEnv.get.shuffleManager)

// Similarly, we also initialize MemoryManager later after DriverPlugin is loaded, to
// allow the plugin to overwrite certain memory configurations. The `_memoryManager` will be
// null here and we ask for the instance from SparkEnv
private[spark] lazy val memoryManager =
Option(_memoryManager).getOrElse(SparkEnv.get.memoryManager)

// same as `conf.get(config.SHUFFLE_SERVICE_ENABLED)`
private[spark] val externalShuffleServiceEnabled: Boolean = externalBlockStoreClient.isDefined
private val isDriver = executorId == SparkContext.DRIVER_IDENTIFIER
@@ -224,17 +230,19 @@
ThreadUtils.newDaemonCachedThreadPool("block-manager-future", 128))

// Actual storage of where blocks are kept
private[spark] val memoryStore =
new MemoryStore(conf, blockInfoManager, serializerManager, memoryManager, this)
private[spark] lazy val memoryStore = {
val store = new MemoryStore(conf, blockInfoManager, serializerManager, memoryManager, this)
memoryManager.setMemoryStore(store)
store
}
private[spark] val diskStore = new DiskStore(conf, diskBlockManager, securityManager)
memoryManager.setMemoryStore(memoryStore)

// Note: depending on the memory manager, `maxMemory` may actually vary over time.
// However, since we use this only for reporting and logging, what we actually want here is
// the absolute maximum value that `maxMemory` can ever possibly reach. We may need
// to revisit whether reporting this value as the "max" is intuitive to the user.
private val maxOnHeapMemory = memoryManager.maxOnHeapStorageMemory
private val maxOffHeapMemory = memoryManager.maxOffHeapStorageMemory
private lazy val maxOnHeapMemory = memoryManager.maxOnHeapStorageMemory
private lazy val maxOffHeapMemory = memoryManager.maxOffHeapStorageMemory

private[spark] val externalShuffleServicePort = StorageUtils.externalShuffleServicePort(conf)

core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala
@@ -36,6 +36,7 @@ import org.apache.spark.TestUtils._
import org.apache.spark.api.plugin._
import org.apache.spark.internal.config._
import org.apache.spark.launcher.SparkLauncher
import org.apache.spark.memory.MemoryMode
import org.apache.spark.resource.ResourceInformation
import org.apache.spark.resource.ResourceUtils.GPU
import org.apache.spark.resource.TestResourceIDs.{DRIVER_GPU_ID, EXECUTOR_GPU_ID, WORKER_GPU_ID}
@@ -228,6 +229,58 @@ class PluginContainerSuite extends SparkFunSuite with LocalSparkContext {
assert(driverResources.get(GPU).name === GPU)
}
}

test("memory override in plugin") {
val conf = new SparkConf()
.setAppName(getClass().getName())
.set(SparkLauncher.SPARK_MASTER, "local-cluster[2,1,1024]")
.set(PLUGINS, Seq(classOf[MemoryOverridePlugin].getName()))

var sc: SparkContext = null
try {
sc = new SparkContext(conf)
val memoryManager = sc.env.memoryManager

assert(memoryManager.tungstenMemoryMode == MemoryMode.OFF_HEAP)
assert(memoryManager.maxOffHeapStorageMemory == MemoryOverridePlugin.offHeapMemory)

// Ensure all executors have started
TestUtils.waitUntilExecutorsUp(sc, 1, 60000)

// Check executor memory is also updated
val execInfo = sc.statusTracker.getExecutorInfos.head
assert(execInfo.totalOffHeapStorageMemory() == MemoryOverridePlugin.offHeapMemory)
} finally {
if (sc != null) {
sc.stop()
}
}
}
}

class MemoryOverridePlugin extends SparkPlugin {
override def driverPlugin(): DriverPlugin = {
new DriverPlugin {
override def init(sc: SparkContext, pluginContext: PluginContext): JMap[String, String] = {
// Take the original executor memory, and set `spark.memory.offHeap.size` to be the
// same value. Also set `spark.memory.offHeap.enabled` to true.
val originalExecutorMemBytes =
sc.conf.getSizeAsMb(EXECUTOR_MEMORY.key, EXECUTOR_MEMORY.defaultValueString)
sc.conf.set(MEMORY_OFFHEAP_ENABLED.key, "true")
sc.conf.set(MEMORY_OFFHEAP_SIZE.key, s"${originalExecutorMemBytes}M")
MemoryOverridePlugin.offHeapMemory = sc.conf.getSizeAsBytes(MEMORY_OFFHEAP_SIZE.key)
Map.empty[String, String].asJava
}
}
}

override def executorPlugin(): ExecutorPlugin = {
new ExecutorPlugin {}
}
}

object MemoryOverridePlugin {
var offHeapMemory: Long = _
}

class NonLocalModeSparkPlugin extends SparkPlugin {
core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.scheduler

import java.util.Properties
import java.util.concurrent.Semaphore
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.mutable.ArrayBuffer
@@ -27,7 +28,9 @@ import org.mockito.Mockito._
import org.scalatest.BeforeAndAfter

import org.apache.spark._
import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, SparkPlugin}
import org.apache.spark.executor.{Executor, TaskMetrics, TaskMetricsSuite}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.METRICS_CONF
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.source.JvmSource
@@ -680,14 +683,80 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
context.markTaskCompleted(None)
assert(isFailed)
}

test("SPARK-46947: ensure the correct block manager is used to unroll memory for task") {
import BlockManagerValidationPlugin._
BlockManagerValidationPlugin.resetState()

// run a task which ignores thread interruption when spark context is shutdown
sc = new SparkContext("local", "test")

val rdd = new RDD[String](sc, List()) {
override def getPartitions = Array[Partition](StubPartition(0))

override def compute(split: Partition, context: TaskContext): Iterator[String] = {
context.addTaskCompletionListener(new TaskCompletionListener {
override def onTaskCompletion(context: TaskContext): Unit = {
try {
releaseTaskSem.acquire()
} catch {
case _: InterruptedException =>
// ignore thread interruption
}
}
})
taskStartedSem.release()
Iterator.empty
}
}
// submit the job, but don't block this thread
rdd.collectAsync()
// wait for task to start
taskStartedSem.acquire()

sc.stop()
assert(sc.isStopped)

// create a new SparkContext which will be blocked for a certain amount of time
// while initializing the driver plugin below
val conf = new SparkConf()
conf.set("spark.plugins", classOf[BlockManagerValidationPlugin].getName)
sc = new SparkContext("local", "test", conf)
}
}

private object TaskContextSuite {
private object TaskContextSuite extends Logging {
@volatile var completed = false

@volatile var lastError: Throwable = _

class FakeTaskFailureException extends Exception("Fake task failure")
}

class BlockManagerValidationPlugin extends SparkPlugin {
override def driverPlugin(): DriverPlugin = {
new DriverPlugin() {
// does nothing but block the current thread for a certain amount of time, so that
// the task thread can progress and reproduce the issue.
BlockManagerValidationPlugin.releaseTaskSem.release()
Thread.sleep(2500)
}
}
override def executorPlugin(): ExecutorPlugin = {
new ExecutorPlugin() {
// do nothing
}
}
}

object BlockManagerValidationPlugin {
val releaseTaskSem = new Semaphore(0)
val taskStartedSem = new Semaphore(0)

def resetState(): Unit = {
releaseTaskSem.drainPermits()
taskStartedSem.drainPermits()
}
}

private case class StubPartition(index: Int) extends Partition
