Pass callbacks cleanly.
rxin committed Aug 19, 2014
1 parent 603dce7 commit 4c6d0ee
Showing 6 changed files with 152 additions and 73 deletions.
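In short: BlockFetchingClient.fetchBlocks previously took a pair of callback functions (one for success, one for failure); this commit replaces them with a single BlockClientListener interface, and BlockFetchingClientHandler now tracks each outstanding block id together with its listener in a synchronized map, so responses and connection errors can be routed to the right caller.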
BlockClientListener.scala (new file, package org.apache.spark.network.netty.client)
@@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.network.netty.client

import java.util.EventListener


trait BlockClientListener extends EventListener {

  def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit

  def onFetchFailure(blockId: String, errorMsg: String): Unit

}
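To make the new API concrete, here is a minimal sketch of a call site; the `client` value and the block id are hypothetical, and everything else follows the trait above and the fetchBlocks signature below.

// Hypothetical call site; `client` is an already-connected BlockFetchingClient.
client.fetchBlocks(
  Seq("shuffle_0_0_0"),   // made-up block id
  new BlockClientListener {
    override def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit = {
      // Copy the payload out synchronously. If `data` were kept past this
      // callback, it would need data.retain() now and data.release() later,
      // as BlockFetcherIterator does.
      val bytes = new Array[Byte](data.byteBuffer().remaining)
      data.byteBuffer().get(bytes)
      println(s"fetched block $blockId: ${bytes.length} bytes")
    }
    override def onFetchFailure(blockId: String, errorMsg: String): Unit = {
      println(s"fetch of $blockId failed: $errorMsg")
    }
  }
)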
BlockFetchingClient.scala (package org.apache.spark.network.netty.client)
@@ -37,14 +37,14 @@ import org.apache.spark.Logging
  *
  * See [[org.apache.spark.network.netty.server.BlockServer]] for client/server protocol.
  *
- * Concurrency: [[BlockFetchingClient]] is not thread safe and should not be shared.
+ * Concurrency: thread safe and can be called from multiple threads.
  */
 @throws[TimeoutException]
 private[spark]
 class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String, port: Int)
   extends Logging {
 
-  val handler = new BlockFetchingClientHandler
+  private val handler = new BlockFetchingClientHandler
 
   /** Netty Bootstrap for creating the TCP connection. */
   private val bootstrap: Bootstrap = {
@@ -84,17 +84,9 @@ class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String, port: Int)
    * rate of fetching; otherwise we could run out of memory.
    *
    * @param blockIds sequence of block ids to fetch.
-   * @param blockFetchSuccessCallback callback function when a block is successfully fetched.
-   *                                  First argument is the block id, and second argument is the
-   *                                  raw data in a ByteBuffer.
-   * @param blockFetchFailureCallback callback function when we failed to fetch any of the blocks.
-   *                                  First argument is the block id, and second argument is the
-   *                                  error message.
+   * @param listener callback to fire on fetch success / failure.
    */
-  def fetchBlocks(
-      blockIds: Seq[String],
-      blockFetchSuccessCallback: (String, ReferenceCountedBuffer) => Unit,
-      blockFetchFailureCallback: (String, String) => Unit): Unit = {
+  def fetchBlocks(blockIds: Seq[String], listener: BlockClientListener): Unit = {
     // It's best to limit the number of "write" calls since it needs to traverse the whole pipeline.
     // It's also best to limit the number of "flush" calls since it requires system calls.
     // Let's concatenate the string and then call writeAndFlush once.
@@ -106,9 +98,9 @@ class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String, port: Int)
       s"Sending request $blockIds to $hostname:$port"
     }
 
-    // TODO: This is not the most elegant way to handle this ...
-    handler.blockFetchSuccessCallback = blockFetchSuccessCallback
-    handler.blockFetchFailureCallback = blockFetchFailureCallback
+    blockIds.foreach { blockId =>
+      handler.addRequest(blockId, listener)
+    }
 
     val writeFuture = cf.channel().writeAndFlush(blockIds.mkString("\n") + "\n")
     writeFuture.addListener(new ChannelFutureListener {
@@ -120,8 +112,13 @@ class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String, port: Int)
         }
       } else {
         // Fail all blocks.
-        logError(s"Failed to send request $blockIds to $hostname:$port", future.cause)
-        blockIds.foreach(blockFetchFailureCallback(_, future.cause.getMessage))
+        val errorMsg =
+          s"Failed to send request $blockIds to $hostname:$port: ${future.cause.getMessage}"
+        logError(errorMsg, future.cause)
+        blockIds.foreach { blockId =>
+          listener.onFetchFailure(blockId, errorMsg)
+          handler.removeRequest(blockId)
+        }
       }
     })
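As the comments in fetchBlocks note, the request is deliberately batched into a single write and flush. A small sketch of the resulting payload, using made-up block ids:

val blockIds = Seq("b1", "b2", "b3")
val request = blockIds.mkString("\n") + "\n"   // "b1\nb2\nb3\n": one write, one flush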
BlockFetchingClientHandler.scala (package org.apache.spark.network.netty.client)
@@ -26,15 +26,39 @@ import org.apache.spark.Logging
 /**
  * Handler that processes server responses. It uses the protocol documented in
  * [[org.apache.spark.network.netty.server.BlockServer]].
+ *
+ * Concurrency: thread safe and can be called from multiple threads.
  */
 private[client]
 class BlockFetchingClientHandler extends SimpleChannelInboundHandler[ByteBuf] with Logging {
 
-  var blockFetchSuccessCallback: (String, ReferenceCountedBuffer) => Unit = _
-  var blockFetchFailureCallback: (String, String) => Unit = _
+  /** Tracks the list of outstanding requests and their listeners on success/failure. */
+  private val outstandingRequests = java.util.Collections.synchronizedMap {
+    new java.util.HashMap[String, BlockClientListener]
+  }
+
+  def addRequest(blockId: String, listener: BlockClientListener): Unit = {
+    outstandingRequests.put(blockId, listener)
+  }
+
+  def removeRequest(blockId: String): Unit = {
+    outstandingRequests.remove(blockId)
+  }
 
   override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
-    logError(s"Exception in connection from ${ctx.channel.remoteAddress}", cause)
+    val errorMsg = s"Exception in connection from ${ctx.channel.remoteAddress}: ${cause.getMessage}"
+    logError(errorMsg, cause)
+
+    // Fire the failure callback for all outstanding blocks
+    outstandingRequests.synchronized {
+      val iter = outstandingRequests.entrySet().iterator()
+      while (iter.hasNext) {
+        val entry = iter.next()
+        entry.getValue.onFetchFailure(entry.getKey, errorMsg)
+      }
+      outstandingRequests.clear()
+    }
+
     ctx.close()
   }

@@ -54,10 +78,26 @@ class BlockFetchingClientHandler extends SimpleChannelInboundHandler[ByteBuf] with Logging {
       in.readBytes(errorMessageBytes)
       val errorMsg = new String(errorMessageBytes)
       logTrace(s"Received block $blockId ($blockSize B) with error $errorMsg from $server")
-      blockFetchFailureCallback(blockId, errorMsg)
+
+      val listener = outstandingRequests.get(blockId)
+      if (listener == null) {
+        // Ignore callback
+        logWarning(s"Got a response for block $blockId but it is not in our outstanding requests")
+      } else {
+        outstandingRequests.remove(blockId)
+        listener.onFetchFailure(blockId, errorMsg)
+      }
     } else {
       logTrace(s"Received block $blockId ($blockSize B) from $server")
-      blockFetchSuccessCallback(blockId, new ReferenceCountedBuffer(in))
+
+      val listener = outstandingRequests.get(blockId)
+      if (listener == null) {
+        // Ignore callback
+        logWarning(s"Got a response for block $blockId but it is not in our outstanding requests")
+      } else {
+        outstandingRequests.remove(blockId)
+        listener.onFetchSuccess(blockId, new ReferenceCountedBuffer(in))
+      }
     }
   }
 }
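A detail worth noting in the handler: Collections.synchronizedMap makes individual operations such as put, get, and remove atomic, but iteration still requires holding the map's own lock, which is why exceptionCaught wraps its loop in outstandingRequests.synchronized. A standalone sketch of the same pattern:

import java.util.Collections

val m = Collections.synchronizedMap(new java.util.HashMap[String, Int])
m.put("a", 1)          // single operations are already thread safe

m.synchronized {       // but iteration must lock the map itself
  val it = m.entrySet().iterator()
  while (it.hasNext) {
    val e = it.next()
    println(s"${e.getKey} -> ${e.getValue}")
  }
}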
BlockFetcherIterator.scala (package org.apache.spark.storage)
@@ -18,7 +18,7 @@
 package org.apache.spark.storage
 
 import java.util.concurrent.LinkedBlockingQueue
-import org.apache.spark.network.netty.client.{LazyInitIterator, ReferenceCountedBuffer}
+import org.apache.spark.network.netty.client.{BlockClientListener, LazyInitIterator, ReferenceCountedBuffer}
 
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashSet
@@ -285,37 +285,40 @@ object BlockFetcherIterator {
 
     client.fetchBlocks(
       blocks,
-      (blockId: String, refBuf: ReferenceCountedBuffer) => {
-        // Increment the reference count so the buffer won't be recycled.
-        // TODO: This could result in memory leaks when the task is stopped due to exception
-        // before the iterator is exhausted.
-        refBuf.retain()
-        val buf = refBuf.byteBuffer()
-        val blockSize = buf.remaining()
-        val bid = BlockId(blockId)
-
-        // TODO: remove code duplication between here and BlockManager.dataDeserialization.
-        results.put(new FetchResult(bid, sizeMap(bid), () => {
-          def createIterator: Iterator[Any] = {
-            val stream = blockManager.wrapForCompression(bid, refBuf.inputStream())
-            serializer.newInstance().deserializeStream(stream).asIterator
-          }
-          new LazyInitIterator(createIterator) {
-            // Release the buffer when we are done traversing it.
-            override def close(): Unit = refBuf.release()
-          }
-        }))
-
-        readMetrics.synchronized {
-          readMetrics.remoteBytesRead += blockSize
-          readMetrics.remoteBlocksFetched += 1
-        }
-        logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
-      },
-      (blockId: String, errorMsg: String) => {
-        logError(s"Could not get block(s) from $cmId with error: $errorMsg")
-        for ((blockId, size) <- req.blocks) {
-          results.put(new FetchResult(blockId, -1, null))
-        }
-      }
+      new BlockClientListener {
+        override def onFetchFailure(blockId: String, errorMsg: String): Unit = {
+          logError(s"Could not get block(s) from $cmId with error: $errorMsg")
+          for ((blockId, size) <- req.blocks) {
+            results.put(new FetchResult(blockId, -1, null))
+          }
+        }
+
+        override def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit = {
+          // Increment the reference count so the buffer won't be recycled.
+          // TODO: This could result in memory leaks when the task is stopped due to exception
+          // before the iterator is exhausted.
+          data.retain()
+          val buf = data.byteBuffer()
+          val blockSize = buf.remaining()
+          val bid = BlockId(blockId)
+
+          // TODO: remove code duplication between here and BlockManager.dataDeserialization.
+          results.put(new FetchResult(bid, sizeMap(bid), () => {
+            def createIterator: Iterator[Any] = {
+              val stream = blockManager.wrapForCompression(bid, data.inputStream())
+              serializer.newInstance().deserializeStream(stream).asIterator
+            }
+            new LazyInitIterator(createIterator) {
+              // Release the buffer when we are done traversing it.
+              override def close(): Unit = data.release()
+            }
+          }))
+
+          readMetrics.synchronized {
+            readMetrics.remoteBytesRead += blockSize
+            readMetrics.remoteBlocksFetched += 1
+          }
+          logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
+        }
+      }
     )
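Note the buffer lifecycle in onFetchSuccess above: the buffer is retained before the FetchResult is queued, handed to a lazily created deserialization iterator, and released only when that iterator's close() runs, so the bytes stay valid long after the Netty callback returns. The TODO records the known gap: a task that dies before exhausting the iterator never calls close() and leaks the retained buffer.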
ServerClientIntegrationSuite.scala
@@ -29,7 +29,7 @@ import io.netty.buffer.{ByteBufUtil, Unpooled}
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
 
 import org.apache.spark.SparkConf
-import org.apache.spark.network.netty.client.{ReferenceCountedBuffer, BlockFetchingClientFactory}
+import org.apache.spark.network.netty.client.{BlockClientListener, ReferenceCountedBuffer, BlockFetchingClientFactory}
 import org.apache.spark.network.netty.server.BlockServer
 import org.apache.spark.storage.{FileSegment, BlockDataProvider}
@@ -99,15 +99,18 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll {
 
     client.fetchBlocks(
       blockIds,
-      (blockId, buf) => {
-        receivedBlockIds.add(blockId)
-        buf.retain()
-        receivedBuffers.add(buf)
-        sem.release()
-      },
-      (blockId, errorMsg) => {
-        errorBlockIds.add(blockId)
-        sem.release()
+      new BlockClientListener {
+        override def onFetchFailure(blockId: String, errorMsg: String): Unit = {
+          errorBlockIds.add(blockId)
+          sem.release()
+        }
+
+        override def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit = {
+          receivedBlockIds.add(blockId)
+          data.retain()
+          receivedBuffers.add(data)
+          sem.release()
+        }
       }
     )
     if (!sem.tryAcquire(blockIds.size, 30, TimeUnit.SECONDS)) {
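Each listener callback releases the semaphore exactly once per block, so the tryAcquire above waits for blockIds.size permits and fails the test if all responses have not arrived within 30 seconds.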
BlockFetchingClientHandlerSuite.scala
@@ -35,12 +35,17 @@ class BlockFetchingClientHandlerSuite extends FunSuite {
     var parsedBlockId: String = ""
     var parsedBlockData: String = ""
     val handler = new BlockFetchingClientHandler
-    handler.blockFetchSuccessCallback = (bid, refCntBuf) => {
-      parsedBlockId = bid
-      val bytes = new Array[Byte](refCntBuf.byteBuffer().remaining)
-      refCntBuf.byteBuffer().get(bytes)
-      parsedBlockData = new String(bytes)
-    }
+    handler.addRequest(blockId,
+      new BlockClientListener {
+        override def onFetchFailure(blockId: String, errorMsg: String): Unit = ???
+        override def onFetchSuccess(bid: String, refCntBuf: ReferenceCountedBuffer): Unit = {
+          parsedBlockId = bid
+          val bytes = new Array[Byte](refCntBuf.byteBuffer().remaining)
+          refCntBuf.byteBuffer().get(bytes)
+          parsedBlockData = new String(bytes)
+        }
+      }
+    )
 
     val channel = new EmbeddedChannel(handler)
     val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself
@@ -65,11 +70,13 @@ class BlockFetchingClientHandlerSuite extends FunSuite {
     var parsedBlockId: String = ""
     var parsedErrorMsg: String = ""
     val handler = new BlockFetchingClientHandler
-    handler.blockFetchFailureCallback = (bid, msg) => {
-      parsedBlockId = bid
-      parsedErrorMsg = msg
-    }
-
+    handler.addRequest(blockId, new BlockClientListener {
+      override def onFetchFailure(bid: String, msg: String): Unit = {
+        parsedBlockId = bid
+        parsedErrorMsg = msg
+      }
+      override def onFetchSuccess(bid: String, refCntBuf: ReferenceCountedBuffer): Unit = ???
+    })
     val channel = new EmbeddedChannel(handler)
     val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself
     buf.putInt(totalLength)
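Both tests drive the handler through Netty's EmbeddedChannel, so no real socket is involved, and the ??? bodies mark callbacks that the test expects never to fire (they would throw NotImplementedError if invoked).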