[SPARK-47910][CORE] close stream when DiskBlockObjectWriter closeResources to avoid memory leak #46131

Closed · wants to merge 5 commits
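For orientation, a minimal standalone sketch of the close-before-null pattern this PR introduces in closeResources(): close the serialization stream first (which flushes into the wrapped stream), close the wrapped stream if it is still open, and only then drop the references. The object and stream types below are illustrative stand-ins for objOut and bs in the diff, not the Spark code itself.

import java.io.{BufferedOutputStream, ByteArrayOutputStream, ObjectOutputStream, OutputStream}

// Illustrative stand-in for DiskBlockObjectWriter's new close path: close the
// streams (releasing any buffers they hold) before dropping the references,
// instead of only nulling them out. Spark wraps the equivalent logic in
// Utils.tryWithSafeFinally so a failure in one close does not hide another.
object CloseBeforeNullSketch {
  private var bs: OutputStream = _            // wrapped (e.g. compressed) stream
  private var objOut: ObjectOutputStream = _  // serialization stream on top of bs

  def open(sink: OutputStream): Unit = {
    bs = new BufferedOutputStream(sink)
    objOut = new ObjectOutputStream(bs)
  }

  def closeResources(): Unit = {
    try {
      if (objOut != null) objOut.close() // also flushes and closes the wrapped stream
      bs = null                          // closed through objOut above; skip the second close
    } finally {
      objOut = null
      if (bs != null) bs.close()         // objOut was null or its close() failed: close bs directly
      bs = null
    }
  }

  def main(args: Array[String]): Unit = {
    open(new ByteArrayOutputStream())
    closeResources()
    println(s"objOut=$objOut, bs=$bs") // both references cleared, both streams closed
  }
}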
DiskBlockObjectWriter.scala
@@ -126,6 +126,12 @@ private[spark] class DiskBlockObjectWriter(
*/
private var numRecordsCommitted = 0L

// For testing only.
private[storage] def getSerializerWrappedStream: OutputStream = bs

// For testing only.
private[storage] def getSerializationStream: SerializationStream = objOut

/**
* Set the checksum that the checksumOutputStream should use
*/
@@ -174,19 +180,36 @@ private[spark] class DiskBlockObjectWriter(
* Should call after committing or reverting partial writes.
*/
private def closeResources(): Unit = {
if (initialized) {
Utils.tryWithSafeFinally {
mcs.manualClose()
} {
channel = null
mcs = null
bs = null
fos = null
ts = null
objOut = null
initialized = false
streamOpen = false
hasBeenClosed = true
try {
if (streamOpen) {
Utils.tryWithSafeFinally {
if (null != objOut) objOut.close()
bs = null
} {
objOut = null
if (null != bs) bs.close()
bs = null
}
}
} catch {
case e: IOException =>
logInfo(log"Exception occurred while closing the output stream " +
log"${MDC(ERROR, e.getMessage)}")
Review comment from a Contributor on the logging above: "error -> info; and the exception stack trace would be useful." (A hedged sketch of the stack-trace suggestion follows this file's diff.)
} finally {
if (initialized) {
Utils.tryWithSafeFinally {
mcs.manualClose()
} {
channel = null
mcs = null
bs = null
fos = null
ts = null
objOut = null
initialized = false
streamOpen = false
hasBeenClosed = true
}
}
}
}
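On the review comment in the hunk above (include the exception's stack trace, not only its message): a minimal sketch of the idea using plain SLF4J, which Spark's logging ultimately delegates to. The logger, object name, and the simulated IOException are illustrative stand-ins, not the PR's code; Spark's internal Logging trait has analogous message-plus-throwable overloads that would achieve the same effect.

import java.io.IOException

import org.slf4j.LoggerFactory

// Illustrative only: contrasts logging just the message with logging the
// Throwable itself, which preserves the stack trace in the output.
object CloseLoggingSketch {
  private val logger = LoggerFactory.getLogger(getClass)

  def main(args: Array[String]): Unit = {
    try {
      // Stand-in for a failure while closing the wrapped output stream.
      throw new IOException("Stream closed")
    } catch {
      case e: IOException =>
        // Message only: the stack trace is lost.
        logger.info(s"Exception occurred while closing the output stream ${e.getMessage}")
        // Message plus Throwable: the full stack trace is recorded.
        logger.info("Exception occurred while closing the output stream", e)
    }
  }
}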
DiskBlockObjectWriterSuite.scala
@@ -16,11 +16,14 @@
*/
package org.apache.spark.storage

import java.io.File
import java.io.{File, InputStream, OutputStream}
import java.nio.ByteBuffer

import scala.reflect.ClassTag

import org.apache.spark.{SparkConf, SparkException, SparkFunSuite}
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
import org.apache.spark.serializer.{DeserializationStream, JavaSerializer, SerializationStream, Serializer, SerializerInstance, SerializerManager}
import org.apache.spark.util.Utils

class DiskBlockObjectWriterSuite extends SparkFunSuite {
@@ -43,10 +46,14 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite {
private def createWriter(): (DiskBlockObjectWriter, File, ShuffleWriteMetrics) = {
val file = new File(tempDir, "somefile")
val conf = new SparkConf()
val serializerManager = new SerializerManager(new JavaSerializer(conf), conf)
val serializerManager = new CustomSerializerManager(new JavaSerializer(conf), conf, None)
val writeMetrics = new ShuffleWriteMetrics()
val writer = new DiskBlockObjectWriter(
file, serializerManager, new JavaSerializer(new SparkConf()).newInstance(), 1024, true,
file,
serializerManager,
new CustomJavaSerializer(new SparkConf()).newInstance(),
1024,
true,
writeMetrics)
(writer, file, writeMetrics)
}
@@ -196,9 +203,76 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite {
for (i <- 1 to 500) {
writer.write(i, i)
}

val bs = writer.getSerializerWrappedStream.asInstanceOf[OutputStreamWithCloseDetecting]
val objOut = writer.getSerializationStream.asInstanceOf[SerializationStreamWithCloseDetecting]

writer.closeAndDelete()
assert(!file.exists())
assert(writeMetrics.bytesWritten == 0)
assert(writeMetrics.recordsWritten == 0)
assert(bs.isClosed)
assert(objOut.isClosed)
}
}

trait CloseDetecting {
var isClosed = false
}

class OutputStreamWithCloseDetecting(outputStream: OutputStream)
extends OutputStream
with CloseDetecting {
override def write(b: Int): Unit = outputStream.write(b)

override def close(): Unit = {
isClosed = true
outputStream.close()
}
}

class CustomSerializerManager(
defaultSerializer: Serializer,
conf: SparkConf,
encryptionKey: Option[Array[Byte]])
extends SerializerManager(defaultSerializer, conf, encryptionKey) {
override def wrapStream(blockId: BlockId, s: OutputStream): OutputStream = {
new OutputStreamWithCloseDetecting(wrapForCompression(blockId, wrapForEncryption(s)))
}
}

class CustomJavaSerializer(conf: SparkConf) extends JavaSerializer(conf) {

override def newInstance(): SerializerInstance = {
new CustomJavaSerializerInstance(super.newInstance())
}
}

class SerializationStreamWithCloseDetecting(serializationStream: SerializationStream)
extends SerializationStream with CloseDetecting {

override def close(): Unit = {
isClosed = true
serializationStream.close()
}

override def writeObject[T: ClassTag](t: T): SerializationStream =
serializationStream.writeObject(t)

override def flush(): Unit = serializationStream.flush()
}

class CustomJavaSerializerInstance(instance: SerializerInstance) extends SerializerInstance {
override def serializeStream(s: OutputStream): SerializationStream =
new SerializationStreamWithCloseDetecting(instance.serializeStream(s))

override def serialize[T: ClassTag](t: T): ByteBuffer = instance.serialize(t)

override def deserialize[T: ClassTag](bytes: ByteBuffer): T = instance.deserialize(bytes)

override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T =
instance.deserialize(bytes, loader)

override def deserializeStream(s: InputStream): DeserializationStream =
instance.deserializeStream(s)
}