Commit

Buffer streams up to a threshold.
Signed-off-by: Pascal Spörri <psp@zurich.ibm.com>
pspoerri committed Aug 31, 2023
1 parent f5fd647 commit 1c05385
Showing 4 changed files with 147 additions and 44 deletions.
6 changes: 2 additions & 4 deletions README.md
@@ -38,10 +38,8 @@ These configuration values need to be passed to Spark to load and configure the

Changing these values might have an impact on performance.

- `spark.shuffle.s3.bufferSize`: Default size of the buffered output streams (default: `32768`,
uses `spark.shuffle.file.buffer` as default)
- `spark.shuffle.s3.bufferInputSize`: Maximum size of buffered input streams (default: `209715200`,
uses `spark.network.maxRemoteBlockSizeFetchToMem` as default)
- `spark.shuffle.s3.bufferSize`: Default buffer size when writing (default: `8388608`)
- `spark.shuffle.s3.maxBufferSizeTask`: Maximum size of the buffered output streams per task (default: `134217728`)
- `spark.shuffle.s3.cachePartitionLengths`: Cache partition lengths in memory (default: `true`)
- `spark.shuffle.s3.cacheChecksums`: Cache checksums in memory (default: `true`)
- `spark.shuffle.s3.cleanup`: Cleanup the shuffle files (default: `true`)
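For illustration, a minimal sketch of how the two buffer options might be set when building a Spark session (the values shown are just the documented defaults; the mandatory options for loading the shuffle plugin, covered earlier in the README, are omitted):

```scala
import org.apache.spark.sql.SparkSession

// Sketch only: override the write-buffer size and the per-task prefetch cap.
val spark = SparkSession.builder()
  .config("spark.shuffle.s3.bufferSize", "8388608")           // 8 MiB write buffer
  .config("spark.shuffle.s3.maxBufferSizeTask", "134217728")  // 128 MiB buffered per task
  .getOrCreate()
```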
@@ -36,8 +36,8 @@ class S3ShuffleDispatcher extends Logging {
private val isS3A = rootDir.startsWith("s3a://")

// Optional
val bufferSize: Int = conf.getInt("spark.shuffle.s3.bufferSize", defaultValue = conf.get(SHUFFLE_FILE_BUFFER_SIZE).toInt * 1024)
val bufferInputSize: Int = conf.getInt("spark.shuffle.s3.bufferInputSize", defaultValue = conf.get(MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM).toInt)
val bufferSize: Int = conf.getInt("spark.shuffle.s3.bufferSize", defaultValue = 8 * 1024 * 1024)
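// Upper bound on the bytes a single task keeps buffered while prefetching shuffle blocks (see S3BufferedPrefetchIterator).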
val maxBufferSizeTask: Int = conf.getInt("spark.shuffle.s3.maxBufferSizeTask", defaultValue = 128 * 1024 * 1024)
val cachePartitionLengths: Boolean = conf.getBoolean("spark.shuffle.s3.cachePartitionLengths", defaultValue = true)
val cacheChecksums: Boolean = conf.getBoolean("spark.shuffle.s3.cacheChecksums", defaultValue = true)
val cleanupShuffleFiles: Boolean = conf.getBoolean("spark.shuffle.s3.cleanup", defaultValue = true)
@@ -64,7 +64,7 @@ class S3ShuffleDispatcher extends Logging {

// Optional
logInfo(s"- spark.shuffle.s3.bufferSize=${bufferSize}")
logInfo(s"- spark.shuffle.s3.bufferInputSize=${bufferInputSize}")
logInfo(s"- spark.shuffle.s3.maxBufferSizeTask=${maxBufferSizeTask}")
logInfo(s"- spark.shuffle.s3.cachePartitionLengths=${cachePartitionLengths}")
logInfo(s"- spark.shuffle.s3.cacheChecksums=${cacheChecksums}")
logInfo(s"- spark.shuffle.s3.cleanup=${cleanupShuffleFiles}")
@@ -0,0 +1,130 @@
/**
* Copyright 2023- IBM Inc. All rights reserved
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.storage

import org.apache.spark.internal.Logging

import java.io.{BufferedInputStream, InputStream}
import java.util

class S3BufferedPrefetchIterator(iter: Iterator[(BlockId, S3ShuffleBlockStream)], maxBufferSize: Long) extends Iterator[(BlockId, InputStream)] with Logging {
@volatile private var memoryUsage: Long = 0
@volatile private var hasItem: Boolean = iter.hasNext
private var timeWaiting: Long = 0
private var timePrefetching: Long = 0
private var timeNext: Long = 0
private var numStreams: Long = 0
private var bytesRead: Long = 0

private var nextElement: (BlockId, S3ShuffleBlockStream) = null

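// Buffered streams ready to be handed out by next(): (buffered stream, block id, buffer size in bytes).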
private val completed = new util.LinkedList[(InputStream, BlockId, Long)]()

private def prefetchThread(): Unit = {
while (iter.hasNext || nextElement != null) {
if (nextElement == null) {
val now = System.nanoTime()
nextElement = iter.next()
timeNext += System.nanoTime() - now
}
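// Buffer at most maxBufferSize bytes per stream; smaller blocks only need their own size.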
val bsize = scala.math.min(maxBufferSize, nextElement._2.maxBytes).toInt

var fetchNext = false
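// Back-pressure: if buffering this stream would push usage past the per-task limit, wait for a consumer to free memory.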
synchronized {
if (memoryUsage + bsize > maxBufferSize) {
try {
wait()
}
catch {
case _: InterruptedException =>
Thread.currentThread.interrupt()
}
} else {
fetchNext = true
}
}

if (fetchNext) {
val block = nextElement._1
val s = nextElement._2
nextElement = null
val now = System.nanoTime()
val stream = new BufferedInputStream(s, bsize)
// Fill the buffered input stream by reading and then resetting the stream.
stream.mark(bsize)
stream.read()
stream.reset()
timePrefetching += System.nanoTime() - now
bytesRead += bsize
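// Publish the buffered stream and wake up a consumer that may be blocked in next().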
synchronized {
memoryUsage += bsize
completed.push((stream, block, bsize))
hasItem = iter.hasNext
notifyAll()
}
}
}
}

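// Background producer thread; it runs until the input iterator is exhausted.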
private val self = this
private val thread = new Thread {
override def run(): Unit = {
self.prefetchThread()
}
}
thread.start()

private def printStatistics(): Unit = synchronized {
try {
val tW = timeWaiting / 1000000
val tP = timePrefetching / 1000000
val tN = timeNext / 1000000
val bR = bytesRead
val r = numStreams
// Average time per prefetch
val atP = tP / r
// Average time waiting
val atW = tW / r
// Average time next
val atN = tN / r
// Average read bandwidth
val bW = bR.toDouble / (tP.toDouble / 1000) / (1024 * 1024)
// Block size
val bs = bR / r
logInfo(s"Statistics: ${bR} bytes, ${tW} ms waiting (${atW} avg), " +
s"${tP} ms prefetching (avg: ${atP} ms - ${bs} block size - ${bW} MiB/s) " +
s"${tN} ms for next (${atN} avg)")
} catch {
case e: Exception => logError(f"Unable to print statistics: ${e.getMessage}.")
}
}

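// Exhausted once the input is drained and every buffered stream has been handed out; statistics are logged at the end.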
override def hasNext: Boolean = synchronized {
val result = hasItem || (completed.size() > 0)
if (!result) {
printStatistics()
}
result
}

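// Blocks until the prefetch thread has published at least one buffered stream.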
override def next(): (BlockId, InputStream) = synchronized {
val now = System.nanoTime()
while (completed.isEmpty) {
try {
wait()
} catch {
case _: InterruptedException =>
Thread.currentThread.interrupt()
}
}
timeWaiting += System.nanoTime() - now
numStreams += 1
val result = completed.pop()
memoryUsage -= result._3
notifyAll()
return (result._2, result._1)
}
}
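As a rough usage sketch (the `drain` helper and the `blockStreams` input are hypothetical; this mirrors how S3ShuffleReader below wires the iterator up, assuming the caller closes each stream when done):

```scala
// Sketch: drain shuffle block streams through the bounded prefetcher.
// `blockStreams` is assumed to be an Iterator[(BlockId, S3ShuffleBlockStream)],
// e.g. an S3ShuffleBlockIterator with empty blocks filtered out.
def drain(blockStreams: Iterator[(BlockId, S3ShuffleBlockStream)], maxBufferSizeTask: Long): Unit = {
  val prefetched = new S3BufferedPrefetchIterator(blockStreams, maxBufferSizeTask)
  prefetched.foreach { case (blockId, in) =>
    try {
      // `in` is a BufferedInputStream already filled by the background thread;
      // deserialize or copy it here (S3ShuffleReader wraps it for checksums and decompression).
      var b = in.read()
      while (b != -1) { b = in.read() }
    } finally {
      in.close()
    }
  }
}
```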
49 changes: 12 additions & 37 deletions src/main/scala/org/apache/spark/storage/S3ShuffleReader.scala
@@ -30,12 +30,9 @@ import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReadMetricsReporter,
import org.apache.spark.storage.ShuffleBlockFetcherIterator.FetchBlockInfo
import org.apache.spark.util.{CompletionIterator, ThreadUtils}
import org.apache.spark.util.collection.ExternalSorter
import org.apache.spark.{InterruptibleIterator, SparkConf, SparkEnv, SparkException, TaskContext}
import org.apache.spark.{InterruptibleIterator, SparkConf, SparkEnv, TaskContext}

import java.io.{BufferedInputStream, InputStream}
import java.util.zip.{CheckedInputStream, Checksum}
import scala.concurrent.duration.Duration
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.ExecutionContext

/**
* This class was adapted from Apache Spark: BlockStoreShuffleReader.
@@ -55,7 +52,7 @@ class S3ShuffleReader[K, C](

private val dispatcher = S3ShuffleDispatcher.get
private val dep = handle.dependency
private val bufferInputSize = dispatcher.bufferInputSize
private val maxBufferSizeTask = dispatcher.maxBufferSizeTask

private val fetchContinousBlocksInBatch: Boolean = {
val serializerRelocatable = dep.serializer.supportsRelocationOfSerializedObjects
@@ -77,17 +74,6 @@ class S3ShuffleReader[K, C](
doBatchFetch
}

// Source: Cassandra connector for Apache Spark (https://github.com/datastax/spark-cassandra-connector)
// com.datastax.spark.connector.datasource.JoinHelper
// License: Apache 2.0
// See here for an explanation: http://www.russellspitzer.com/2017/02/27/Concurrency-In-Spark/
def slidingPrefetchIterator[T](it: Iterator[Future[T]], batchSize: Int): Iterator[T] = {
val (firstElements, lastElement) = it.grouped(batchSize)
.sliding(2)
.span(_ => it.hasNext)
(firstElements.map(_.head) ++ lastElement.flatten).flatten.map(Await.result(_, Duration.Inf))
}

override def read(): Iterator[Product2[K, C]] = {
val serializerInstance = dep.serializer.newInstance()
val blocks = computeShuffleBlocks(handle.shuffleId,
@@ -98,35 +84,24 @@

val wrappedStreams = new S3ShuffleBlockIterator(blocks)

// Create a key/value iterator for each stream
val recordIterPromise = wrappedStreams.filterNot(_._2.maxBytes == 0).map { case (blockId, wrappedStream) =>
readMetrics.incRemoteBytesRead(wrappedStream.maxBytes) // increase byte count.
val filteredStream = wrappedStreams.filterNot(_._2.maxBytes == 0).map(f => {
readMetrics.incRemoteBytesRead(f._2.maxBytes) // increase byte count.
readMetrics.incRemoteBlocksFetched(1)
// Note: the asKeyValueIterator below wraps a key/value iterator inside of a
// NextIterator. The NextIterator makes sure that close() is called on the
// underlying InputStream when all records have been read.
Future {
val bufferSize = scala.math.min(wrappedStream.maxBytes, bufferInputSize).toInt
val stream = new BufferedInputStream(wrappedStream, bufferSize)

// Fill the buffered input stream by reading and then resetting the stream.
stream.mark(bufferSize)
stream.read()
stream.reset()

f
})
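// Prefetch the filtered streams on a background thread, keeping at most maxBufferSizeTask bytes buffered at a time.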
val recordIter = new S3BufferedPrefetchIterator(filteredStream, maxBufferSizeTask)
.flatMap(s => {
val stream = s._2
val blockId = s._1
val checkedStream = if (dispatcher.checksumEnabled) {
new S3ChecksumValidationStream(blockId, stream, dispatcher.checksumAlgorithm)
} else {
stream
}

serializerInstance
.deserializeStream(serializerManager.wrapStream(blockId, checkedStream))
.asKeyValueIterator
}(S3ShuffleReader.asyncExecutionContext)
}

val recordIter = slidingPrefetchIterator(recordIterPromise, dispatcher.prefetchBatchSize).flatten
})

// Update the context task metrics for each record read.
val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
