[SPARK-24307][CORE] Support reading remote cached partitions > 2gb
(1) Netty's ByteBuf cannot hold data > 2gb, so to transfer data from a
ChunkedByteBuffer over the network we use a custom FileRegion implementation
which is backed by the ChunkedByteBuffer.
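A minimal sender-side sketch (illustrative only; the tiny buffers below stand in for a cached partition that could exceed 2gb):

import java.nio.ByteBuffer
import org.apache.spark.util.io.ChunkedByteBuffer

// Two small chunks stand in for a block that may be far larger in practice.
val cbb = new ChunkedByteBuffer(Array(ByteBuffer.allocate(16), ByteBuffer.allocate(16)))

// toNetty now wraps the chunks in a ChunkedByteBufferFileRegion (a Netty FileRegion)
// rather than a ByteBuf, so its count() is a Long and is not capped at Int.MaxValue.
val region = cbb.toNetty
assert(region.count() == cbb.size)
// The region is then written to the Netty channel like any other FileRegion.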

(2) On the receiving end, we need to expose all the data in a
FileSegmentManagedBuffer as a ChunkedByteBuffer. We do that by memory-mapping
the entire file in chunks.
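A minimal receiving-side sketch (a small temp file stands in for a block that was fetched to disk; the file name and sizes are illustrative):

import java.nio.file.Files
import org.apache.spark.util.io.ChunkedByteBuffer

// A 64-byte temp file stands in for a remote block that was streamed to disk.
val tmp = Files.createTempFile("fetched-block", ".data").toFile
Files.write(tmp.toPath, Array.fill[Byte](64)(1.toByte))

// ChunkedByteBuffer.map memory-maps the file in chunks of at most maxChunkSize bytes,
// so no single mapped buffer has to cover more than 2gb; fromManagedBuffer delegates
// to it when the ManagedBuffer is a FileSegmentManagedBuffer.
val cbb = ChunkedByteBuffer.map(tmp, maxChunkSize = 16, offset = 0, length = tmp.length())
assert(cbb.size == 64 && cbb.getChunks().length == 4)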

Added unit tests.  Ran the randomized test a couple of hundred times on my laptop.  Tests cover the equivalent of SPARK-24107 for the ChunkedByteBufferFileRegion.  Also tested on a cluster with remote cache reads >2gb (in memory and on disk).

Author: Imran Rashid <irashid@cloudera.com>

Closes #21440 from squito/chunked_bb_file_region.
squito authored and jerryshao committed Jul 20, 2018
1 parent 67e108d commit 7e84764
Showing 5 changed files with 286 additions and 9 deletions.
11 changes: 9 additions & 2 deletions core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -130,6 +130,8 @@ private[spark] class BlockManager(

private[spark] val externalShuffleServiceEnabled =
conf.getBoolean("spark.shuffle.service.enabled", false)
private val chunkSize =
conf.getSizeAsBytes("spark.storage.memoryMapLimitForTests", Int.MaxValue.toString).toInt

val diskBlockManager = {
// Only perform cleanup if an external service is not serving our shuffle files.
@@ -660,6 +662,11 @@ private[spark] class BlockManager(
* Get block from remote block managers as serialized bytes.
*/
def getRemoteBytes(blockId: BlockId): Option[ChunkedByteBuffer] = {
// TODO if we change this method to return the ManagedBuffer, then getRemoteValues
// could just use the inputStream on the temp file, rather than memory-mapping the file.
// Until then, replication can cause the process to use too much memory and get killed
// by the OS / cluster manager (not a java OOM, since it's a memory-mapped file) even though
// we've read the data to disk.
logDebug(s"Getting remote block $blockId")
require(blockId != null, "BlockId is null")
var runningFailureCount = 0
@@ -690,7 +697,7 @@ private[spark] class BlockManager(
logDebug(s"Getting remote block $blockId from $loc")
val data = try {
blockTransferService.fetchBlockSync(
loc.host, loc.port, loc.executorId, blockId.toString, tempFileManager).nioByteBuffer()
loc.host, loc.port, loc.executorId, blockId.toString, tempFileManager)
} catch {
case NonFatal(e) =>
runningFailureCount += 1
@@ -724,7 +731,7 @@ private[spark] class BlockManager(
}

if (data != null) {
return Some(new ChunkedByteBuffer(data))
return Some(ChunkedByteBuffer.fromManagedBuffer(data, chunkSize))
}
logDebug(s"The value of block $blockId is null")
}
44 changes: 38 additions & 6 deletions core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
@@ -17,17 +17,21 @@

package org.apache.spark.util.io

import java.io.InputStream
import java.io.{File, FileInputStream, InputStream}
import java.nio.ByteBuffer
import java.nio.channels.WritableByteChannel
import java.nio.channels.{FileChannel, WritableByteChannel}
import java.nio.file.StandardOpenOption

import scala.collection.mutable.ListBuffer

import com.google.common.primitives.UnsignedBytes
import io.netty.buffer.{ByteBuf, Unpooled}

import org.apache.spark.SparkEnv
import org.apache.spark.internal.config
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.network.util.ByteArrayWritableChannel
import org.apache.spark.storage.StorageUtils
import org.apache.spark.util.Utils

/**
* Read-only byte buffer which is physically stored as multiple chunks rather than a single
@@ -81,10 +85,10 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
}

/**
* Wrap this buffer to view it as a Netty ByteBuf.
* Wrap this in a custom "FileRegion" which allows us to transfer over 2 GB.
*/
def toNetty: ByteBuf = {
Unpooled.wrappedBuffer(chunks.length, getChunks(): _*)
def toNetty: ChunkedByteBufferFileRegion = {
new ChunkedByteBufferFileRegion(this, bufferWriteChunkSize)
}

/**
@@ -166,6 +170,34 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {

}

object ChunkedByteBuffer {
// TODO eliminate this method if we switch BlockManager to getting InputStreams
def fromManagedBuffer(data: ManagedBuffer, maxChunkSize: Int): ChunkedByteBuffer = {
data match {
case f: FileSegmentManagedBuffer =>
map(f.getFile, maxChunkSize, f.getOffset, f.getLength)
case other =>
new ChunkedByteBuffer(other.nioByteBuffer())
}
}

def map(file: File, maxChunkSize: Int, offset: Long, length: Long): ChunkedByteBuffer = {
Utils.tryWithResource(FileChannel.open(file.toPath, StandardOpenOption.READ)) { channel =>
var remaining = length
var pos = offset
val chunks = new ListBuffer[ByteBuffer]()
while (remaining > 0) {
val chunkSize = math.min(remaining, maxChunkSize)
val chunk = channel.map(FileChannel.MapMode.READ_ONLY, pos, chunkSize)
pos += chunkSize
remaining -= chunkSize
chunks += chunk
}
new ChunkedByteBuffer(chunks.toArray)
}
}
}

/**
* Reads data from a ChunkedByteBuffer.
*
86 changes: 86 additions & 0 deletions core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferFileRegion.scala
@@ -0,0 +1,86 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util.io

import java.nio.channels.WritableByteChannel

import io.netty.channel.FileRegion
import io.netty.util.AbstractReferenceCounted

import org.apache.spark.internal.Logging
import org.apache.spark.network.util.AbstractFileRegion


/**
* This exposes a ChunkedByteBuffer as a netty FileRegion, just to allow sending > 2gb in one netty
* message. This is because netty cannot send a ByteBuf > 2g, but it can send a large FileRegion,
* even though the data is not backed by a file.
*/
private[io] class ChunkedByteBufferFileRegion(
private val chunkedByteBuffer: ChunkedByteBuffer,
private val ioChunkSize: Int) extends AbstractFileRegion {

private var _transferred: Long = 0
// this duplicates the original chunks, so we're free to modify the position, limit, etc.
private val chunks = chunkedByteBuffer.getChunks()
private val size = chunks.foldLeft(0L) { _ + _.remaining() }

protected def deallocate: Unit = {}

override def count(): Long = size

// this is the "start position" of the overall data in the backing file, not our current position
override def position(): Long = 0

override def transferred(): Long = _transferred

private var currentChunkIdx = 0

def transferTo(target: WritableByteChannel, position: Long): Long = {
assert(position == _transferred)
if (position == size) return 0L
var keepGoing = true
var written = 0L
var currentChunk = chunks(currentChunkIdx)
while (keepGoing) {
while (currentChunk.hasRemaining && keepGoing) {
val ioSize = Math.min(currentChunk.remaining(), ioChunkSize)
val originalLimit = currentChunk.limit()
currentChunk.limit(currentChunk.position() + ioSize)
val thisWriteSize = target.write(currentChunk)
currentChunk.limit(originalLimit)
written += thisWriteSize
if (thisWriteSize < ioSize) {
// the channel did not accept our entire write. We do *not* keep trying -- netty wants
// us to just stop, and report how much we've written.
keepGoing = false
}
}
if (keepGoing) {
// advance to the next chunk (if there are any more)
currentChunkIdx += 1
if (currentChunkIdx == chunks.size) {
keepGoing = false
} else {
currentChunk = chunks(currentChunkIdx)
}
}
}
_transferred += written
written
}
}
152 changes: 152 additions & 0 deletions core/src/test/scala/org/apache/spark/io/ChunkedByteBufferFileRegionSuite.scala
@@ -0,0 +1,152 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.io

import java.nio.ByteBuffer
import java.nio.channels.WritableByteChannel

import scala.util.Random

import org.mockito.Mockito.when
import org.scalatest.BeforeAndAfterEach
import org.scalatest.mockito.MockitoSugar

import org.apache.spark.{SparkConf, SparkEnv, SparkFunSuite}
import org.apache.spark.internal.config
import org.apache.spark.util.io.ChunkedByteBuffer

class ChunkedByteBufferFileRegionSuite extends SparkFunSuite with MockitoSugar
with BeforeAndAfterEach {

override protected def beforeEach(): Unit = {
super.beforeEach()
val conf = new SparkConf()
val env = mock[SparkEnv]
SparkEnv.set(env)
when(env.conf).thenReturn(conf)
}

override protected def afterEach(): Unit = {
SparkEnv.set(null)
}

private def generateChunkedByteBuffer(nChunks: Int, perChunk: Int): ChunkedByteBuffer = {
val bytes = (0 until nChunks).map { chunkIdx =>
val bb = ByteBuffer.allocate(perChunk)
(0 until perChunk).foreach { idx =>
bb.put((chunkIdx * perChunk + idx).toByte)
}
bb.position(0)
bb
}.toArray
new ChunkedByteBuffer(bytes)
}

test("transferTo can stop and resume correctly") {
SparkEnv.get.conf.set(config.BUFFER_WRITE_CHUNK_SIZE, 9L)
val cbb = generateChunkedByteBuffer(4, 10)
val fileRegion = cbb.toNetty

val targetChannel = new LimitedWritableByteChannel(40)

var pos = 0L
// write the fileregion to the channel, but with the transfer limited at various spots along
// the way.

// limit to within the first chunk
targetChannel.acceptNBytes = 5
pos = fileRegion.transferTo(targetChannel, pos)
assert(targetChannel.pos === 5)

// a little bit further within the first chunk
targetChannel.acceptNBytes = 2
pos += fileRegion.transferTo(targetChannel, pos)
assert(targetChannel.pos === 7)

// past the first chunk, into the 2nd
targetChannel.acceptNBytes = 6
pos += fileRegion.transferTo(targetChannel, pos)
assert(targetChannel.pos === 13)

// right to the end of the 2nd chunk
targetChannel.acceptNBytes = 7
pos += fileRegion.transferTo(targetChannel, pos)
assert(targetChannel.pos === 20)

// rest of 2nd chunk, all of 3rd, some of 4th
targetChannel.acceptNBytes = 15
pos += fileRegion.transferTo(targetChannel, pos)
assert(targetChannel.pos === 35)

// now till the end
targetChannel.acceptNBytes = 5
pos += fileRegion.transferTo(targetChannel, pos)
assert(targetChannel.pos === 40)

// calling again at the end should be OK
targetChannel.acceptNBytes = 20
fileRegion.transferTo(targetChannel, pos)
assert(targetChannel.pos === 40)
}

test(s"transfer to with random limits") {
val rng = new Random()
val seed = System.currentTimeMillis()
logInfo(s"seed = $seed")
rng.setSeed(seed)
val chunkSize = 1e4.toInt
SparkEnv.get.conf.set(config.BUFFER_WRITE_CHUNK_SIZE, rng.nextInt(chunkSize).toLong)

val cbb = generateChunkedByteBuffer(50, chunkSize)
val fileRegion = cbb.toNetty
val transferLimit = 1e5.toInt
val targetChannel = new LimitedWritableByteChannel(transferLimit)
while (targetChannel.pos < cbb.size) {
val nextTransferSize = rng.nextInt(transferLimit)
targetChannel.acceptNBytes = nextTransferSize
fileRegion.transferTo(targetChannel, targetChannel.pos)
}
assert(0 === fileRegion.transferTo(targetChannel, targetChannel.pos))
}

/**
* This mocks a channel which only accepts a limited number of bytes at a time. It also verifies
* the written data matches our expectations as the data is received.
*/
private class LimitedWritableByteChannel(maxWriteSize: Int) extends WritableByteChannel {
val bytes = new Array[Byte](maxWriteSize)
var acceptNBytes = 0
var pos = 0

override def write(src: ByteBuffer): Int = {
val length = math.min(acceptNBytes, src.remaining())
src.get(bytes, 0, length)
acceptNBytes -= length
// verify we got the right data
(0 until length).foreach { idx =>
assert(bytes(idx) === (pos + idx).toByte, s"; wrong data at ${pos + idx}")
}
pos += length
length
}

override def isOpen: Boolean = true

override def close(): Unit = {}
}

}
2 changes: 1 addition & 1 deletion core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala
@@ -34,7 +34,7 @@ class ChunkedByteBufferSuite extends SparkFunSuite with SharedSparkContext {
assert(emptyChunkedByteBuffer.getChunks().isEmpty)
assert(emptyChunkedByteBuffer.toArray === Array.empty)
assert(emptyChunkedByteBuffer.toByteBuffer.capacity() === 0)
assert(emptyChunkedByteBuffer.toNetty.capacity() === 0)
assert(emptyChunkedByteBuffer.toNetty.count() === 0)
emptyChunkedByteBuffer.toInputStream(dispose = false).close()
emptyChunkedByteBuffer.toInputStream(dispose = true).close()
}
