
[SPARK-45025][CORE] Allow block manager memory store iterator to handle thread interrupt and perform task completion gracefully#42742

Closed

anishshri-db wants to merge 3 commits into apache:master from anishshri-db:task/SPARK-45025


Conversation

@anishshri-db (Contributor) commented Aug 31, 2023

What changes were proposed in this pull request?

Allow block manager memory store iterator to handle thread interrupt and perform task completion gracefully

Why are the changes needed?

Currently, putIteratorAsBytes can remain hung even after a thread interrupt is received on task cancellation, which eventually leads the task reaper to kill the executor JVM. The reason is that the interrupt is never processed within the while loop for the unroll block, so the task keeps running beyond the reaper timeout.

The logs for one affected task/thread are attached below:

10.1.121.105/app-20230824190614-0000/735/stderr.txt:55427:23/08/29 12:01:51 INFO CoarseGrainedExecutorBackend: Got assigned task 222564684
10.1.121.105/app-20230824190614-0000/735/stderr.txt:55494:23/08/29 12:01:51 INFO Executor: Running task 6.0 in stage 900216.0 (TID 222564684)
10.1.121.105/app-20230824190614-0000/735/stderr.txt:55983:23/08/29 12:03:22 INFO Executor: Executor is trying to kill task 6.0 in stage 900216.0 (TID 222564684), reason: another attempt succeeded
10.1.121.105/app-20230824190614-0000/735/stderr.txt:55987:23/08/29 12:03:22 INFO ShuffleMapTask: Trying to kill task 6.0 in stage 900216.0 (TID 222564684) with reason=another attempt succeeded and current stackTrace:
        net.jpountz.lz4.LZ4JNI.LZ4_compress_limitedOutput(Native Method)
        at net.jpountz.lz4.LZ4JNICompressor.compress(LZ4JNICompressor.java:36)
        at net.jpountz.lz4.LZ4Compressor.compress(LZ4Compressor.java:95)
        at net.jpountz.lz4.LZ4BlockOutputStream.flushBufferedData(LZ4BlockOutputStream.java:208)
        at net.jpountz.lz4.LZ4BlockOutputStream.write(LZ4BlockOutputStream.java:176)
        at java.io.ObjectOutputStream$BlockDataOutputStream.drain(ObjectOutputStream.java:1877)
        at java.io.ObjectOutputStream$BlockDataOutputStream.setBlockDataMode(ObjectOutputStream.java:1786)
        at java.io.ObjectOutputStream.writeExternalData(ObjectOutputStream.java:1460)
        at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1430)
        at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1178)
        at java.io.ObjectOutputStream.writeObject(ObjectOutputStream.java:348)
        at org.apache.spark.serializer.JavaSerializationStream.writeObject(JavaSerializer.scala:44)
        at org.apache.spark.storage.memory.SerializedValuesHolder.storeValue(MemoryStore.scala:728)
        at org.apache.spark.storage.memory.MemoryStore.putIterator(MemoryStore.scala:224)
        at org.apache.spark.storage.memory.MemoryStore.putIteratorAsBytes(MemoryStore.scala:352)
        at org.apache.spark.storage.BlockManager.$anonfun$doPutIterator$1(BlockManager.scala:1447)
        at org.apache.spark.storage.BlockManager$$Lambda$732/1315363341.apply(Unknown Source)
        at org.apache.spark.storage.BlockManager.org$apache$spark$storage$BlockManager$$doPut(BlockManager.scala:1357)
        at org.apache.spark.storage.BlockManager.doPutIterator(BlockManager.scala:1421)
        at org.apache.spark.storage.BlockManager.getOrElseUpdate(BlockManager.scala:1240)
        at org.apache.spark.rdd.RDD.getOrCompute(RDD.scala:391)
        at org.apache.spark.rdd.RDD.iterator(RDD.scala:342)
        at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:60)
        at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:380)
        at org.apache.spark.rdd.RDD.iterator(RDD.scala:344)
        at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:60)
        at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:380)
        at org.apache.spark.rdd.RDD.iterator(RDD.scala:344)
        at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:60)
        at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:380)
        at org.apache.spark.rdd.RDD.iterator(RDD.scala:344)
        at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
        at org.apache.spark.scheduler.ShuffleMapTask.$anonfun$runTask$3(ShuffleMapTask.scala:81)
        at org.apache.spark.scheduler.ShuffleMapTask$$Lambda$1195/1841467537.apply(Unknown Source)
        at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
        at org.apache.spark.scheduler.ShuffleMapTask.$anonfun$runTask$1(ShuffleMapTask.scala:81)
        at org.apache.spark.scheduler.ShuffleMapTask$$Lambda$1050/445760861.apply(Unknown Source)
        at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
        at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)
        at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:41)
        at org.apache.spark.scheduler.Task.doRunTask(Task.scala:153)
        at org.apache.spark.scheduler.Task.$anonfun$run$1(Task.scala:122)
        at org.apache.spark.scheduler.Task$$Lambda$925/288368281.apply(Unknown Source)
        at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
        at org.apache.spark.scheduler.Task.run(Task.scala:94)
        at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$13(Executor.scala:819)
        at org.apache.spark.executor.Executor$TaskRunner$$Lambda$909/144086205.apply(Unknown Source)
        at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1657)
        at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:822)
        at org.apache.spark.executor.Executor$TaskRunner$$Lambda$793/213482103.apply$mcV$sp(Unknown Source)
        at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
        at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
        at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:678)
        at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
        at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
        at java.lang.Thread.run(Thread.java:750)

...
10.1.121.105/app-20230824190614-0000/735/stderr.txt:58310:23/08/29 12:04:31 WARN Executor: Killed task 222564684 is still running after 68908 ms
10.1.121.105/app-20230824190614-0000/735/stderr.txt:58361:23/08/29 12:04:31 WARN Executor: Thread dump from task 222564684:
java.lang.System.identityHashCode(Native Method)
java.io.ObjectOutputStream$HandleTable.hash(ObjectOutputStream.java:2360)
java.io.ObjectOutputStream$HandleTable.lookup(ObjectOutputStream.java:2293)
java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1116)
java.io.ObjectOutputStream.writeObject(ObjectOutputStream.java:348)
org.apache.spark.serializer.JavaSerializationStream.writeObject(JavaSerializer.scala:44)
org.apache.spark.storage.memory.SerializedValuesHolder.storeValue(MemoryStore.scala:728)
org.apache.spark.storage.memory.MemoryStore.putIterator(MemoryStore.scala:224)
org.apache.spark.storage.memory.MemoryStore.putIteratorAsBytes(MemoryStore.scala:352)
org.apache.spark.storage.BlockManager.$anonfun$doPutIterator$1(BlockManager.scala:1447)
org.apache.spark.storage.BlockManager$$Lambda$732/1315363341.apply(Unknown Source)
org.apache.spark.storage.BlockManager.org$apache$spark$storage$BlockManager$$doPut(BlockManager.scala:1357)
...
10.1.121.105/app-20230824190614-0000/735/stderr.txt:58388:org.apache.spark.SparkException: Killing executor JVM because killed task 222564684 could not be stopped within 60000 ms.
10.1.121.105/app-20230824190614-0000/735/stderr.txt:58810:java.lang.Error: org.apache.spark.SparkException: Killing executor JVM because killed task 222564684 could not be stopped within 60000 ms.
10.1.121.105/app-20230824190614-0000/735/stderr.txt:58814:Caused by: org.apache.spark.SparkException: Killing executor JVM because killed task 222564684 could not be stopped within 60000 ms.

We can see that the stack traces differ over time but always remain within the while loop in putIterator.
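The hang can be illustrated with a minimal, self-contained Java sketch (a hypothetical stand-in for the unroll loop, not Spark code): a CPU-bound loop that performs no IO and no blocking calls ignores interrupt() entirely unless it polls the interrupt flag, which is exactly the check this PR adds.

```java
// Hypothetical stand-in for the MemoryStore unroll loop: a CPU-bound loop
// never sees interrupt() unless it polls the flag itself.
public class InterruptDemo {
    public static void main(String[] args) throws InterruptedException {
        Thread worker = new Thread(() -> {
            long elements = 0;
            // Polling the interrupt flag lets the loop exit on cancellation;
            // without this check the loop would spin forever.
            while (!Thread.currentThread().isInterrupted()) {
                elements++; // stand-in for valuesHolder.storeValue(values.next())
            }
            System.out.println("unroll loop exited");
        });
        worker.start();
        Thread.sleep(100);
        worker.interrupt();  // what task cancellation delivers
        worker.join(5000);   // would time out if the flag were never polled
        System.out.println("worker alive: " + worker.isAlive());
    }
}
```

Removing the isInterrupted() check reproduces the reported behavior: the worker spins forever and join() times out, much as the task above outlived the reaper's 60000 ms deadline.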

Does this PR introduce any user-facing change?

No

How was this patch tested?

Ran MemoryStoreSuite

[info] MemoryStoreSuite:
[info] - reserve/release unroll memory (36 milliseconds)
[info] - safely unroll blocks (70 milliseconds)
[info] - safely unroll blocks through putIteratorAsValues (10 milliseconds)
[info] - safely unroll blocks through putIteratorAsValues off-heap (21 milliseconds)
[info] - safely unroll blocks through putIteratorAsBytes (138 milliseconds)
[info] - PartiallySerializedBlock.valuesIterator (6 milliseconds)
[info] - PartiallySerializedBlock.finishWritingToStream (5 milliseconds)
[info] - multiple unrolls by the same thread (8 milliseconds)
[info] - lazily create a big ByteBuffer to avoid OOM if it cannot be put into MemoryStore (3 milliseconds)
[info] - put a small ByteBuffer to MemoryStore (3 milliseconds)
[info] - SPARK-22083: Release all locks in evictBlocksToFreeSpace (43 milliseconds)
[info] - put user-defined objects to MemoryStore and remove (5 milliseconds)
[info] - put user-defined objects to MemoryStore and clear (4 milliseconds)
[info] Run completed in 1 second, 587 milliseconds.
[info] Total number of tests run: 13
[info] Suites: completed 1, aborted 0
[info] Tests: succeeded 13, failed 0, canceled 0, ignored 0, pending 0
[info] All tests passed.

Was this patch authored or co-authored using generative AI tooling?

No

@github-actions github-actions bot added the CORE label Aug 31, 2023
@anishshri-db (Contributor, Author):

cc @JoshRosen, @HeartSaVioR - PTAL, thx!

@HeartSaVioR (Contributor):

cc. @jiangxb1987 @Ngone51

@HeartSaVioR (Contributor) left a comment:

The rationale and the fix seem OK to me, but I doubt I'm the right one to approve and sign off. I'd wait for experts in this area to chime in.

@JoshRosen (Contributor) left a comment:

If a cancelled task is getting stuck in the unroll while loop and isn't exiting in a timely manner then that must imply that neither the upstream values.next() calculation nor any memory reservation / spill calls are performing IO or are calling TaskContext.killTaskIfInterrupted() because otherwise the task would have been killed. In other words, there's no interruptible code running in the loop, hence the need to manually check for interrupts.

An alternative fix could be to wrap an InterruptibleIterator call around the upstream iterator, but that's much more likely to have adverse performance impacts per the linked comment because checking Thread.interrupted is cheaper than context.killTaskIfInterrupted.

Given that, I'm okay with adding logic to check the thread's interrupt status, but I'm slightly wary of this PR's current approach of returning a value on interrupt: returning Left(unrollMemoryUsedByThisBlock) doesn't directly cause the task to exit; it will continue running until it hits killTaskIfInterrupted() (which isn't guaranteed to be present) or until it tries to perform IO, and neither of those operations is guaranteed to happen in a timely fashion (although in practice they probably will).

Instead, what do you think about replacing the Left(unrollMemoryUsedByThisBlock) with a throw new InterruptedException()? That exception will bubble up and help the task to exit sooner.

@anishshri-db (Contributor, Author):

Instead, what do you think about replacing the Left(unrollMemoryUsedByThisBlock) with a throw new InterruptedException()? That exception will bubble up and help the task to exit sooner.

@JoshRosen - I was going to do exactly this initially :) But I was not sure whether it would be entirely safe in terms of disposing of the byte buffers.

It seems the caller relies on the result here to create a PartiallySerializedBlock in case of failure.

And within that, we add a task completion listener to ensure that we call dispose on the byte buffers:

  private lazy val unrolledBuffer: ChunkedByteBuffer = {
    bbos.close()
    bbos.toChunkedByteBuffer
  }

  // If the task does not fully consume `valuesIterator` or otherwise fails to consume or dispose of
  // this PartiallySerializedBlock then we risk leaking of direct buffers, so we use a task
  // completion listener here in order to ensure that `unrolled.dispose()` is called at least once.
  // The dispose() method is idempotent, so it's safe to call it unconditionally.
  Option(TaskContext.get()).foreach { taskContext =>
    taskContext.addTaskCompletionListener[Unit] { _ =>
      // When a task completes, its unroll memory will automatically be freed. Thus we do not call
      // releaseUnrollMemoryForThisTask() here because we want to avoid double-freeing.
      unrolledBuffer.dispose()
    }
  }

So it seems freeing the memory is not a problem, but if we throw the InterruptedException, we would still risk leaking the direct buffers, since we won't get a chance to register this task completion listener. Do you think this is safe/handled elsewhere in the caller when the exception is received?
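The safety of calling dispose() from several cleanup paths rests on its idempotence, as the comment in the quoted code notes. A minimal Java sketch of that pattern (class and method names hypothetical, not the actual ChunkedByteBuffer API):

```java
import java.util.concurrent.atomic.AtomicBoolean;

// Hypothetical buffer holder illustrating the idempotent-dispose pattern that
// PartiallySerializedBlock relies on: dispose() may be called from a task
// completion listener and from error-handling code without double-freeing.
class DisposableBuffer {
    private final AtomicBoolean disposed = new AtomicBoolean(false);
    private int freeCount = 0;

    void dispose() {
        // compareAndSet ensures the release logic runs at most once,
        // no matter how many callers invoke dispose().
        if (disposed.compareAndSet(false, true)) {
            freeCount++; // stand-in for freeing the direct ByteBuffers
        }
    }

    int timesFreed() { return freeCount; }
}

public class DisposeDemo {
    public static void main(String[] args) {
        DisposableBuffer buf = new DisposableBuffer();
        buf.dispose(); // e.g. from a catch block
        buf.dispose(); // e.g. from the task completion listener
        System.out.println("freed " + buf.timesFreed() + " time(s)");
    }
}
```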

@JoshRosen (Contributor):

So it seems freeing the memory is not a problem, but if we return the InterruptedException, we would still be risking leaking the direct buffers, since we won't get a chance to register this task completion listener ? Do you think this is safe/handled elsewhere in the caller when the exception is received ?

If I understand correctly, I think this might be a pre-existing risk that we're making worse: there's nothing that prevented the old code from throwing arbitrary exceptions when computing the iterator elements.

I wonder whether we should aim to fix that pre-existing bug at a higher level. In putIteratorAsBytes, we have

  val valuesHolder = new SerializedValuesHolder[T](blockId, chunkSize, classTag,
    memoryMode, serializerManager)

  putIterator(blockId, values, classTag, memoryMode, valuesHolder) match {
    case Right(storedSize) => Right(storedSize)
    case Left(unrollMemoryUsedByThisBlock) =>
      Left(new PartiallySerializedBlock(
        this,
        serializerManager,
        blockId,
        valuesHolder.serializationStream,
        valuesHolder.redirectableStream,
        unrollMemoryUsedByThisBlock,
        memoryMode,
        valuesHolder.bbos,
        values,
        classTag))
  }
}

I'm wondering whether we can restructure that code in order to wrap the putIterator call and dispose of the valuesHolder in case putIterator fails. Something along these lines (borrowing some code and comments from elsewhere in this part of Spark):

   val putIteratorResult = Utils.tryWithSafeFinallyAndFailureCallbacks {
      putIterator(blockId, values, classTag, memoryMode, valuesHolder)
    }(catchBlock = {
      // We want to close the output stream in order to free any resources associated with the
      // serializer itself (such as Kryo's internal buffers). close() might cause data to be
      // written, so redirect the output stream to discard that data.
      valuesHolder.redirectableStream.setOutputStream(ByteStreams.nullOutputStream())
      valuesHolder.serializationStream.close()
      valuesHolder.bbos.close()
      valuesHolder.bbos.toChunkedByteBuffer.dispose()
    })

    putIteratorResult match {
      case Right(storedSize) => Right(storedSize)
      case Left(unrollMemoryUsedByThisBlock) =>
        Left(new PartiallySerializedBlock(
          this,
          serializerManager,
          blockId,
          valuesHolder.serializationStream,
          valuesHolder.redirectableStream,
          unrollMemoryUsedByThisBlock,
          memoryMode,
          valuesHolder.bbos,
          values,
          classTag))
    }

If the putIterator() call fails then the catchBlock will try to close the serialization stream and dispose of the block. I used tryWithSafeFinallyAndFailureCallbacks because there's non-trivial cleanup work taking place in the catch block and I didn't want that to suppress the original task exception.

I scoped the try block to exclude the new PartiallySerializedBlock because I wanted to avoid the possibility that two different pieces of cleanup logic (the task completion callback and the catch block) both call toChunkedByteBuffer().
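For readers unfamiliar with Utils.tryWithSafeFinallyAndFailureCallbacks, the intent described above can be sketched in plain Java (a hypothetical helper, not Spark's actual implementation): run the cleanup only on failure, and attach any exception thrown by the cleanup as suppressed so the original task exception stays primary.

```java
// Hypothetical sketch of a "failure callback" helper in the spirit of
// Utils.tryWithSafeFinallyAndFailureCallbacks: cleanup runs only on failure,
// and a cleanup failure must not mask the original exception.
public class SafeCatchDemo {
    static <T> T tryWithFailureCallback(java.util.concurrent.Callable<T> body,
                                        Runnable catchBlock) throws Exception {
        try {
            return body.call();
        } catch (Exception original) {
            try {
                catchBlock.run();
            } catch (Exception fromCleanup) {
                original.addSuppressed(fromCleanup); // keep the root cause primary
            }
            throw original;
        }
    }

    public static void main(String[] args) {
        try {
            tryWithFailureCallback(
                () -> { throw new InterruptedException("task killed"); },
                () -> { throw new IllegalStateException("cleanup failed"); });
        } catch (Exception e) {
            System.out.println("primary: " + e.getMessage());
            System.out.println("suppressed: " + e.getSuppressed()[0].getMessage());
        }
    }
}
```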


As I look further into the pre-existing code, I'm spotting a couple of other cases where it looks like we're not guaranteed to perform proper cleanup. For example, it looks like we're not guaranteed to close the serialization stream if downstream partial unrolling code fails (or at least it's not straightforwardly obvious that cleanup will happen).

To better test those cases, I think we should add some new unit test cases to MemoryStoreSuite to test scenarios where the iterator being stored throws exceptions at various points.

@anishshri-db (Contributor, Author):

@JoshRosen - agreed, I don't think it's being handled currently anyway. If we want to add the catch block in the caller, I think we need to do it for putIteratorAsValues as well? Do you think it's ok to split this issue into 2 parts, since the other change might need some more work? For the current PR, I'll just modify the return value to throw the exception and then file a follow-up JIRA ticket for the cleanup-related work.

@JoshRosen (Contributor):

If we want to add the catch block in the caller, I think we need to do it for putIteratorAsValues as well?

I don't think we need it there: putIteratorAsValues stores deserialized values, so there's no off-heap memory to be freed or serialization streams to be closed.

Do you think it's ok to split this issue into 2 parts, since the other change might need some more work? For the current PR, I'll just modify the return value to throw the exception and then file a follow-up JIRA ticket for the cleanup-related work.

As a compromise, what do you think about keeping this as-is (i.e. not throwing InterruptedException) and instead add a single taskContext.killTaskIfInterrupted call right before we return from putIteratorAsBytes? e.g.

    val res = putIteratorResult match {
      case Right(storedSize) => Right(storedSize)
      case Left(unrollMemoryUsedByThisBlock) =>
        Left(new PartiallySerializedBlock(
          this,
          serializerManager,
          blockId,
          valuesHolder.serializationStream,
          valuesHolder.redirectableStream,
          unrollMemoryUsedByThisBlock,
          memoryMode,
          valuesHolder.bbos,
          values,
          classTag))
    }
    Option(TaskContext.get()).foreach(_.killTaskIfInterrupted())
    res

This approach will ensure that the bytebuffer cleanup logic is run (because the task completion callback will have been registered) but also ensures that we'll exit in a timely manner rather than trying to continue processing the rest of the task's rows.
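A rough Java sketch of this compromise (names hypothetical; Spark's killTaskIfInterrupted actually throws a TaskKilledException, approximated here with InterruptedException): first finish the construction that registers the cleanup listener, then check the interrupt flag once and fail fast.

```java
// Hypothetical sketch of the compromise: complete the bookkeeping that
// arranges cleanup (the completion listener registered during construction),
// then check the interrupt flag once and fail fast, instead of aborting
// mid-construction.
public class SafePointDemo {
    static String buildResult() throws InterruptedException {
        String result = "PartiallySerializedBlock"; // listener registered in its constructor
        if (Thread.currentThread().isInterrupted()) {
            throw new InterruptedException("task killed"); // exit promptly; cleanup already arranged
        }
        return result;
    }

    public static void main(String[] args) throws Exception {
        System.out.println(buildResult()); // not interrupted: returns normally
        Thread.currentThread().interrupt();
        try {
            buildResult();
        } catch (InterruptedException e) {
            System.out.println("caught: " + e.getMessage());
        }
    }
}
```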

@anishshri-db (Contributor, Author):

As a compromise, what do you think about keeping this as-is (i.e. not throwing InterruptedException) and instead add a single taskContext.killTaskIfInterrupted call right before we return from putIteratorAsBytes? e.g.

Done

// Unroll this block safely, checking whether we have exceeded our threshold periodically
// and if no thread interrupts have been received.
while (values.hasNext && keepUnrolling && !Thread.currentThread().isInterrupted) {
@mridulm (Contributor) commented Sep 1, 2023:

I am slightly wary of calling isInterrupted on every iteration.
Given we already have the if (elementsUnrolled ... ) condition below, do we want to move the check there?

IIRC isInterrupted is a couple of orders more expensive than a boolean check.

Something like:

var interrupted = false
while (values.hasNext && keepUnrolling && !interrupted) {
  valuesHolder ...
  if (elementsUnrolled ... ){
    interrupted = Thread.currentThread().isInterrupted
   ...
  }
  elementsUnrolled += 1
}

if (interrupted) {
 ...
} else if (keepUnrolling) {

@anishshri-db (Contributor, Author):

@mridulm - if we put this check inside the if condition below, we will only set the interrupted flag the next time the memory check period expires. The default value for the check period is 16. So if the interrupt is received after the first element is unrolled, we will wait for 15 more elements and do extra work in the interim that we probably have to dispose of later anyway. So it might be better to check within the while loop and exit earlier, as we are doing in the PR?

      valuesHolder.storeValue(values.next())
      if (elementsUnrolled % memoryCheckPeriod == 0) {
        val currentSize = valuesHolder.estimatedSize()
        // If our vector's size has exceeded the threshold, request more memory
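A small Java sketch of the trade-off under discussion (the loop body and counts are illustrative, not Spark code): with the flag checked only every memoryCheckPeriod-th element, an interrupt delivered before the loop still lets up to a full period of elements be processed, while a per-element check exits immediately.

```java
// Sketch of the trade-off: checking the interrupt flag only every Nth
// element (N = memoryCheckPeriod, 16 by default) delays the exit by up to
// N-1 elements after the interrupt arrives; a per-element check does not.
public class PeriodicCheckDemo {
    public static void main(String[] args) {
        final int memoryCheckPeriod = 16;
        Thread.currentThread().interrupt(); // simulate cancellation before the loop

        int processedPeriodic = 0;
        for (int i = 0; i < 1000; i++) {
            processedPeriodic++; // stand-in for valuesHolder.storeValue(values.next())
            if (processedPeriodic % memoryCheckPeriod == 0
                    && Thread.currentThread().isInterrupted()) {
                break;
            }
        }

        int processedPerElement = 0;
        for (int i = 0; i < 1000 && !Thread.currentThread().isInterrupted(); i++) {
            processedPerElement++;
        }

        System.out.println("periodic check processed: " + processedPeriodic);
        System.out.println("per-element check processed: " + processedPerElement);
        Thread.interrupted(); // clear the flag before exiting
    }
}
```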

@JoshRosen (Contributor):

IIRC isInterrupted is a couple of orders more expensive than a boolean check.

In older JDKs it is implemented via an intrinsic, and in newer ones (14+) it's just a volatile boolean read (see https://bugs.openjdk.org/browse/JDK-8229516), so I don't think it's too expensive. There's an old StackOverflow answer at https://stackoverflow.com/a/5158441/590203 offering a plausible explanation for why the intrinsic should be cheap.

On JDK 11 on my laptop, I ran a toy JMH benchmark and measured < 1ns per call.

I think these costs are very small in comparison to the per-element serialization costs.

Given this, I'm not too worried about performance regressions due to this change and think it's probably okay to check on every element rather than every 16th.
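A toy (non-JMH) harness along the lines of that benchmark is sketched below; absolute numbers are machine- and JDK-dependent, and a rigorous measurement should use JMH to control for warm-up and dead-code elimination.

```java
// Toy, non-rigorous timing sketch comparing isInterrupted() against a local
// boolean read. Treat the numbers as indicative only; use JMH for real data.
public class IsInterruptedCost {
    public static void main(String[] args) {
        final int iters = 10_000_000;
        boolean sink = false;

        long t0 = System.nanoTime();
        for (int i = 0; i < iters; i++) {
            sink ^= Thread.currentThread().isInterrupted();
        }
        long t1 = System.nanoTime();

        boolean local = false;
        for (int i = 0; i < iters; i++) {
            sink ^= local;
        }
        long t2 = System.nanoTime();

        System.out.printf("isInterrupted: %.2f ns/op%n", (t1 - t0) / (double) iters);
        System.out.printf("local boolean: %.2f ns/op%n", (t2 - t1) / (double) iters);
        if (sink) System.out.println(); // defeat dead-code elimination of the loops
    }
}
```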

@mridulm (Contributor) commented Sep 3, 2023:

@anishshri-db Thread interruption is a rare occurrence when compared to number of putIterator calls - so we should try to minimize the impact for the common case, as long as it is a reasonable additional cost when interruption does occur (having said that, please see more below).

@JoshRosen It is way faster than I expected it to be [1] - while it is still slower than a local variable [2], the cost is really low when compared to the cost of other operations in the loop ... this should be just noise.

Given this, while in general I prefer to minimize unnecessary cost, I am fine with the change.

[1] Thanks for your comments, went digging into this - was fun !
I knew it was intrinsic, but the uncontended reads are much faster than I had expected (I seem to be misremembering some stats) - and thanks for jdk14 ref, was not aware of that change.
My comment "IIRC isInterrupted is a couple of orders more expensive than a boolean check." is definitely incorrect !

[2]
On my linux desktop: 2.175 +- 0.001 ns/op versus 2.424 +- 0.001 ns/op
On my mac, the difference is higher - but so is the variance ... so I am discounting that

@mridulm (Contributor) left a comment:

The proposal looks good to me - just a minor suggestion.

@anishshri-db (Contributor, Author):

@JoshRosen - any more comments here? Or do you think we can merge this?

@dongjoon-hyun dongjoon-hyun changed the title [SPARK-45025] Allow block manager memory store iterator to handle thread interrupt and perform task completion gracefully [SPARK-45025][CORE] Allow block manager memory store iterator to handle thread interrupt and perform task completion gracefully Sep 1, 2023
@anishshri-db (Contributor, Author) commented Sep 6, 2023:

@JoshRosen - seems like we are all in agreement? Is it ok to merge the change? Thx

cc- @HeartSaVioR

@JoshRosen (Contributor):

I'm traveling without a laptop right now, so I won't be able to merge this until next week.

@anishshri-db (Contributor, Author):

Thanks @JoshRosen. @mridulm, @HeartSaVioR - would one of you be able to merge, then?

@HeartSaVioR (Contributor):

As long as @mridulm is OK with the fix, I can help merge this.

Thanks! Merging to master.

@mridulm (Contributor) commented Sep 7, 2023:

Thanks @HeartSaVioR !
