From 165eab1518f5184ef9609f26d374c5ccefd05472 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 9 Sep 2014 00:29:33 -0700 Subject: [PATCH 01/46] [SPARK-3453] Refactor Netty module to use BlockTransferService. Also includes some partial support for uploading blocks. --- .../apache/spark/network/ManagedBuffer.scala | 19 +- ...FetchingClient.scala => BlockClient.scala} | 41 ++- ...Factory.scala => BlockClientFactory.scala} | 38 +-- .../network/netty/BlockClientHandler.scala | 86 +++++++ .../netty/{server => }/BlockServer.scala | 39 +-- .../network/netty/BlockServerHandler.scala | 98 +++++++ .../netty/NettyBlockTransferService.scala | 83 ++++++ .../spark/network/netty/PathResolver.scala | 25 -- .../netty/client/BlockClientListener.scala | 29 --- .../client/BlockFetchingClientHandler.scala | 103 -------- .../netty/client/LazyInitIterator.scala | 44 ---- .../netty/client/ReferenceCountedBuffer.scala | 47 ---- .../apache/spark/network/netty/protocol.scala | 243 ++++++++++++++++++ .../network/netty/server/BlockHeader.scala | 32 --- .../netty/server/BlockHeaderEncoder.scala | 47 ---- .../BlockServerChannelInitializer.scala | 40 --- .../netty/server/BlockServerHandler.scala | 140 ---------- .../spark/storage/BlockDataProvider.scala | 32 --- .../netty/BlockClientHandlerSuite.scala | 129 ++++++++++ .../spark/network/netty/ProtocolSuite.scala | 88 +++++++ .../netty/ServerClientIntegrationSuite.scala | 90 ++++--- .../network/netty/TestManagedBuffer.scala | 68 +++++ .../BlockFetchingClientHandlerSuite.scala | 105 -------- .../server/BlockHeaderEncoderSuite.scala | 64 ----- .../server/BlockServerHandlerSuite.scala | 107 -------- 25 files changed, 898 insertions(+), 939 deletions(-) rename core/src/main/scala/org/apache/spark/network/netty/{client/BlockFetchingClient.scala => BlockClient.scala} (70%) rename core/src/main/scala/org/apache/spark/network/netty/{client/BlockFetchingClientFactory.scala => BlockClientFactory.scala} (66%) create mode 100644 core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala rename core/src/main/scala/org/apache/spark/network/netty/{server => }/BlockServer.scala (77%) create mode 100644 core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/protocol.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala delete mode 100644 core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala create mode 100644 core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala create mode 100644 
core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala delete mode 100644 core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala index a4409181ec907..9c298132fcfba 100644 --- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala @@ -25,7 +25,8 @@ import java.nio.channels.FileChannel.MapMode import scala.util.Try import com.google.common.io.ByteStreams -import io.netty.buffer.{ByteBufInputStream, ByteBuf} +import io.netty.buffer.{Unpooled, ByteBufInputStream, ByteBuf} +import io.netty.channel.DefaultFileRegion import org.apache.spark.util.{ByteBufferInputStream, Utils} @@ -38,7 +39,7 @@ import org.apache.spark.util.{ByteBufferInputStream, Utils} * - NioByteBufferManagedBuffer: data backed by a NIO ByteBuffer * - NettyByteBufManagedBuffer: data backed by a Netty ByteBuf */ -sealed abstract class ManagedBuffer { +abstract class ManagedBuffer { // Note that all the methods are defined with parenthesis because their implementations can // have side effects (io operations). @@ -57,6 +58,11 @@ sealed abstract class ManagedBuffer { * it does not go over the limit. */ def inputStream(): InputStream + + /** + * Convert the buffer into an Netty object, used to write the data out. + */ + private[network] def convertToNetty(): AnyRef } @@ -113,7 +119,10 @@ final class FileSegmentManagedBuffer(val file: File, val offset: Long, val lengt } } - override def toString: String = s"${getClass.getName}($file, $offset, $length)" + private[network] override def convertToNetty(): AnyRef = { + val fileChannel = new FileInputStream(file).getChannel + new DefaultFileRegion(fileChannel, offset, length) + } } @@ -127,6 +136,8 @@ final class NioByteBufferManagedBuffer(buf: ByteBuffer) extends ManagedBuffer { override def nioByteBuffer() = buf.duplicate() override def inputStream() = new ByteBufferInputStream(buf) + + private[network] override def convertToNetty(): AnyRef = Unpooled.wrappedBuffer(buf) } @@ -141,6 +152,8 @@ final class NettyByteBufManagedBuffer(buf: ByteBuf) extends ManagedBuffer { override def inputStream() = new ByteBufInputStream(buf) + private[network] override def convertToNetty(): AnyRef = buf + // TODO(rxin): Promote this to top level ManagedBuffer interface and add documentation for it. def release(): Unit = buf.release() } diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala similarity index 70% rename from core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala rename to core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala index 5aea7ba2f3673..95af2565bcc39 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala @@ -15,36 +15,35 @@ * limitations under the License. 
*/ -package org.apache.spark.network.netty.client +package org.apache.spark.network.netty import java.util.concurrent.TimeoutException import io.netty.bootstrap.Bootstrap import io.netty.buffer.PooledByteBufAllocator import io.netty.channel.socket.SocketChannel -import io.netty.channel.{ChannelFutureListener, ChannelFuture, ChannelInitializer, ChannelOption} -import io.netty.handler.codec.LengthFieldBasedFrameDecoder -import io.netty.handler.codec.string.StringEncoder -import io.netty.util.CharsetUtil +import io.netty.channel.{ChannelFuture, ChannelFutureListener, ChannelInitializer, ChannelOption} import org.apache.spark.Logging +import org.apache.spark.network.BlockFetchingListener + /** - * Client for fetching data blocks from [[org.apache.spark.network.netty.server.BlockServer]]. - * Use [[BlockFetchingClientFactory]] to instantiate this client. + * Client for [[NettyBlockTransferService]]. Use [[BlockClientFactory]] to + * instantiate this client. * * The constructor blocks until a connection is successfully established. * - * See [[org.apache.spark.network.netty.server.BlockServer]] for client/server protocol. - * * Concurrency: thread safe and can be called from multiple threads. */ @throws[TimeoutException] -private[spark] -class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String, port: Int) +private[netty] +class BlockClient(factory: BlockClientFactory, hostname: String, port: Int) extends Logging { - private val handler = new BlockFetchingClientHandler + private val handler = new BlockClientHandler + private val encoder = new ClientRequestEncoder + private val decoder = new ServerResponseDecoder /** Netty Bootstrap for creating the TCP connection. */ private val bootstrap: Bootstrap = { @@ -61,9 +60,9 @@ class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String, b.handler(new ChannelInitializer[SocketChannel] { override def initChannel(ch: SocketChannel): Unit = { ch.pipeline - .addLast("encoder", new StringEncoder(CharsetUtil.UTF_8)) - // maxFrameLength = 2G, lengthFieldOffset = 0, lengthFieldLength = 4 - .addLast("framedLengthDecoder", new LengthFieldBasedFrameDecoder(Int.MaxValue, 0, 4)) + .addLast("clientRequestEncoder", encoder) + .addLast("frameDecoder", ProtocolUtils.createFrameDecoder()) + .addLast("serverResponseDecoder", decoder) .addLast("handler", handler) } }) @@ -86,12 +85,7 @@ class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String, * @param blockIds sequence of block ids to fetch. * @param listener callback to fire on fetch success / failure. */ - def fetchBlocks(blockIds: Seq[String], listener: BlockClientListener): Unit = { - // It's best to limit the number of "write" calls since it needs to traverse the whole pipeline. - // It's also best to limit the number of "flush" calls since it requires system calls. - // Let's concatenate the string and then call writeAndFlush once. - // This is also why this implementation might be more efficient than multiple, separate - // fetch block calls. 
+ def fetchBlocks(blockIds: Seq[String], listener: BlockFetchingListener): Unit = { var startTime: Long = 0 logTrace { startTime = System.nanoTime @@ -102,8 +96,7 @@ class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String, handler.addRequest(blockId, listener) } - val writeFuture = cf.channel().writeAndFlush(blockIds.mkString("\n") + "\n") - writeFuture.addListener(new ChannelFutureListener { + cf.channel().writeAndFlush(BlockFetchRequest(blockIds)).addListener(new ChannelFutureListener { override def operationComplete(future: ChannelFuture): Unit = { if (future.isSuccess) { logTrace { @@ -116,9 +109,9 @@ class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String, s"Failed to send request $blockIds to $hostname:$port: ${future.cause.getMessage}" logError(errorMsg, future.cause) blockIds.foreach { blockId => - listener.onFetchFailure(blockId, errorMsg) handler.removeRequest(blockId) } + listener.onBlockFetchFailure(new RuntimeException(errorMsg)) } } }) diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala similarity index 66% rename from core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala rename to core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala index 2b28402c52b49..0777275cd4fe3 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala @@ -15,36 +15,34 @@ * limitations under the License. */ -package org.apache.spark.network.netty.client +package org.apache.spark.network.netty -import io.netty.channel.epoll.{EpollEventLoopGroup, EpollSocketChannel} +import io.netty.channel.epoll.{Epoll, EpollEventLoopGroup, EpollSocketChannel} import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.oio.OioEventLoopGroup import io.netty.channel.socket.nio.NioSocketChannel import io.netty.channel.socket.oio.OioSocketChannel -import io.netty.channel.{EventLoopGroup, Channel} +import io.netty.channel.{Channel, EventLoopGroup} import org.apache.spark.SparkConf -import org.apache.spark.network.netty.NettyConfig import org.apache.spark.util.Utils + /** - * Factory for creating [[BlockFetchingClient]] by using createClient. This factory reuses + * Factory for creating [[BlockClient]] by using createClient. This factory reuses * the worker thread pool for Netty. - * - * Concurrency: createClient is safe to be called from multiple threads concurrently. */ -private[spark] -class BlockFetchingClientFactory(val conf: NettyConfig) { +private[netty] +class BlockClientFactory(val conf: NettyConfig) { def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf)) /** A thread factory so the threads are named (for debugging). */ - val threadFactory = Utils.namedThreadFactory("spark-shuffle-client") + private[netty] val threadFactory = Utils.namedThreadFactory("spark-shuffle-client") /** The following two are instantiated by the [[init]] method, depending ioMode. 
*/ - var socketChannelClass: Class[_ <: Channel] = _ - var workerGroup: EventLoopGroup = _ + private[netty] var socketChannelClass: Class[_ <: Channel] = _ + private[netty] var workerGroup: EventLoopGroup = _ init() @@ -63,20 +61,12 @@ class BlockFetchingClientFactory(val conf: NettyConfig) { workerGroup = new EpollEventLoopGroup(0, threadFactory) } + // For auto mode, first try epoll (only available on Linux), then nio. conf.ioMode match { case "nio" => initNio() case "oio" => initOio() case "epoll" => initEpoll() - case "auto" => - // For auto mode, first try epoll (only available on Linux), then nio. - try { - initEpoll() - } catch { - // TODO: Should we log the throwable? But that always happen on non-Linux systems. - // Perhaps the right thing to do is to check whether the system is Linux, and then only - // call initEpoll on Linux. - case e: Throwable => initNio() - } + case "auto" => if (Epoll.isAvailable) initEpoll() else initNio() } } @@ -87,8 +77,8 @@ class BlockFetchingClientFactory(val conf: NettyConfig) { * * Concurrency: This method is safe to call from multiple threads. */ - def createClient(remoteHost: String, remotePort: Int): BlockFetchingClient = { - new BlockFetchingClient(this, remoteHost, remotePort) + def createClient(remoteHost: String, remotePort: Int): BlockClient = { + new BlockClient(this, remoteHost, remotePort) } def stop(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala new file mode 100644 index 0000000000000..b41c831f3d7e5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} + +import org.apache.spark.Logging +import org.apache.spark.network.BlockFetchingListener + + +/** + * Handler that processes server responses. + * + * Concurrency: thread safe and can be called from multiple threads. + */ +private[netty] +class BlockClientHandler extends SimpleChannelInboundHandler[ServerResponse] with Logging { + + /** Tracks the list of outstanding requests and their listeners on success/failure. 
*/ + private val outstandingRequests = java.util.Collections.synchronizedMap { + new java.util.HashMap[String, BlockFetchingListener] + } + + def addRequest(blockId: String, listener: BlockFetchingListener): Unit = { + outstandingRequests.put(blockId, listener) + } + + def removeRequest(blockId: String): Unit = { + outstandingRequests.remove(blockId) + } + + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + val errorMsg = s"Exception in connection from ${ctx.channel.remoteAddress}: ${cause.getMessage}" + logError(errorMsg, cause) + + // Fire the failure callback for all outstanding blocks + outstandingRequests.synchronized { + val iter = outstandingRequests.entrySet().iterator() + while (iter.hasNext) { + val entry = iter.next() + entry.getValue.onBlockFetchFailure(cause) + } + outstandingRequests.clear() + } + + ctx.close() + } + + override def channelRead0(ctx: ChannelHandlerContext, response: ServerResponse) { + val server = ctx.channel.remoteAddress.toString + response match { + case BlockFetchSuccess(blockId, buf) => + val listener = outstandingRequests.get(blockId) + if (listener == null) { + logWarning(s"Got a response for block $blockId from $server but it is not outstanding") + } else { + outstandingRequests.remove(blockId) + listener.onBlockFetchSuccess(blockId, buf) + } + case BlockFetchFailure(blockId, errorMsg) => + val listener = outstandingRequests.get(blockId) + if (listener == null) { + logWarning( + s"Got a response for block $blockId from $server ($errorMsg) but it is not outstanding") + } else { + outstandingRequests.remove(blockId) + listener.onBlockFetchFailure(new RuntimeException(errorMsg)) + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala similarity index 77% rename from core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala rename to core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala index 7b2f9a8d4dfd0..76f28aa00112e 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala @@ -15,48 +15,33 @@ * limitations under the License. */ -package org.apache.spark.network.netty.server +package org.apache.spark.network.netty import java.net.InetSocketAddress import io.netty.bootstrap.ServerBootstrap import io.netty.buffer.PooledByteBufAllocator -import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption} import io.netty.channel.epoll.{EpollEventLoopGroup, EpollServerSocketChannel} import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.oio.OioEventLoopGroup import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.nio.NioServerSocketChannel import io.netty.channel.socket.oio.OioServerSocketChannel -import io.netty.handler.codec.LineBasedFrameDecoder -import io.netty.handler.codec.string.StringDecoder -import io.netty.util.CharsetUtil +import io.netty.channel.{ChannelInitializer, ChannelFuture, ChannelOption} +import io.netty.handler.codec.LengthFieldBasedFrameDecoder import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.network.netty.NettyConfig -import org.apache.spark.storage.BlockDataProvider +import org.apache.spark.network.BlockDataManager import org.apache.spark.util.Utils /** - * Server for serving Spark data blocks. 
- * This should be used together with [[org.apache.spark.network.netty.client.BlockFetchingClient]]. - * - * Protocol for requesting blocks (client to server): - * One block id per line, e.g. to request 3 blocks: "block1\nblock2\nblock3\n" - * - * Protocol for sending blocks (server to client): - * frame-length (4 bytes), block-id-length (4 bytes), block-id, block-data. - * - * frame-length should not include the length of itself. - * If block-id-length is negative, then this is an error message rather than block-data. The real - * length is the absolute value of the frame-length. - * + * Server for the [[NettyBlockTransferService]]. */ -private[spark] -class BlockServer(conf: NettyConfig, dataProvider: BlockDataProvider) extends Logging { +private[netty] +class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) extends Logging { - def this(sparkConf: SparkConf, dataProvider: BlockDataProvider) = { + def this(sparkConf: SparkConf, dataProvider: BlockDataManager) = { this(new NettyConfig(sparkConf), dataProvider) } @@ -129,10 +114,10 @@ class BlockServer(conf: NettyConfig, dataProvider: BlockDataProvider) extends Lo bootstrap.childHandler(new ChannelInitializer[SocketChannel] { override def initChannel(ch: SocketChannel): Unit = { - ch.pipeline - .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024 - .addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8)) - .addLast("blockHeaderEncoder", new BlockHeaderEncoder) + val p = ch.pipeline + .addLast("frameDecoder", ProtocolUtils.createFrameDecoder()) + .addLast("clientRequestDecoder", new ClientRequestDecoder) + .addLast("serverResponseEncoder", new ServerResponseEncoder) .addLast("handler", new BlockServerHandler(dataProvider)) } }) diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala new file mode 100644 index 0000000000000..739526a4fc6bc --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import io.netty.channel._ + +import org.apache.spark.Logging +import org.apache.spark.network.{ManagedBuffer, BlockDataManager} + + +/** + * A handler that processes requests from clients and writes block data back. + * + * The messages should have been processed by the pipeline setup by BlockServerChannelInitializer. 
+ */ +private[netty] class BlockServerHandler(dataProvider: BlockDataManager) + extends SimpleChannelInboundHandler[ClientRequest] with Logging { + + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + logError(s"Exception in connection from ${ctx.channel.remoteAddress}", cause) + ctx.close() + } + + override def channelRead0(ctx: ChannelHandlerContext, request: ClientRequest): Unit = { + request match { + case BlockFetchRequest(blockIds) => + blockIds.foreach(processBlockRequest(ctx, _)) + case BlockUploadRequest(blockId, data) => + // TODO(rxin): handle upload. + } + } // end of channelRead0 + + private def processBlockRequest(ctx: ChannelHandlerContext, blockId: String): Unit = { + // A helper function to send error message back to the client. + def client = ctx.channel.remoteAddress.toString + + def respondWithError(error: String): Unit = { + ctx.writeAndFlush(new BlockFetchFailure(blockId, error)).addListener( + new ChannelFutureListener { + override def operationComplete(future: ChannelFuture) { + if (!future.isSuccess) { + // TODO: Maybe log the success case as well. + logError(s"Error sending error back to $client", future.cause) + ctx.close() + } + } + } + ) + } + + logTrace(s"Received request from $client to fetch block $blockId") + + // First make sure we can find the block. If not, send error back to the user. + var blockData: Option[ManagedBuffer] = null + try { + blockData = dataProvider.getBlockData(blockId) + } catch { + case e: Exception => + logError(s"Error opening block $blockId for request from $client", e) + respondWithError(e.getMessage) + return + } + + blockData match { + case Some(buf) => + ctx.writeAndFlush(new BlockFetchSuccess(blockId, buf)).addListener( + new ChannelFutureListener { + override def operationComplete(future: ChannelFuture): Unit = { + if (future.isSuccess) { + logTrace(s"Sent block $blockId (${buf.size} B) back to $client") + } else { + logError( + s"Error sending block $blockId to $client; closing connection", future.cause) + ctx.close() + } + } + } + ) + case None => + respondWithError("Block not found") + } + } // end of processBlockRequest +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala new file mode 100644 index 0000000000000..fa8bdfc96e8b8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.netty + +import scala.concurrent.Future + +import org.apache.spark.SparkConf +import org.apache.spark.network._ +import org.apache.spark.storage.StorageLevel + + +/** + * A [[BlockTransferService]] implementation based on Netty. + * + * See protocol.scala for the communication protocol between server and client + */ +final class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService { + + private[this] val nettyConf: NettyConfig = new NettyConfig(conf) + + private[this] var server: BlockServer = _ + private[this] var clientFactory: BlockClientFactory = _ + + override def init(blockDataManager: BlockDataManager): Unit = { + server = new BlockServer(nettyConf, blockDataManager) + clientFactory = new BlockClientFactory(nettyConf) + } + + override def stop(): Unit = { + if (server != null) { + server.stop() + } + if (clientFactory != null) { + clientFactory.stop() + } + } + + override def fetchBlocks( + hostName: String, + port: Int, + blockIds: Seq[String], + listener: BlockFetchingListener): Unit = { + clientFactory.createClient(hostName, port).fetchBlocks(blockIds, listener) + } + + override def uploadBlock( + hostname: String, + port: Int, + blockId: String, + blockData: ManagedBuffer, level: StorageLevel): Future[Unit] = { + // TODO(rxin): Implement uploadBlock. + ??? + } + + override def hostName: String = { + if (server == null) { + throw new IllegalStateException("Server has not been started") + } + server.hostName + } + + override def port: Int = { + if (server == null) { + throw new IllegalStateException("Server has not been started") + } + server.port + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala b/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala deleted file mode 100644 index 0d7695072a7b1..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import org.apache.spark.storage.{BlockId, FileSegment} - -trait PathResolver { - /** Get the file segment in which the given block resides. */ - def getBlockLocation(blockId: BlockId): FileSegment -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala deleted file mode 100644 index e28219dd7745b..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.client - -import java.util.EventListener - - -trait BlockClientListener extends EventListener { - - def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit - - def onFetchFailure(blockId: String, errorMsg: String): Unit - -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala deleted file mode 100644 index 83265b164299d..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.client - -import io.netty.buffer.ByteBuf -import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} - -import org.apache.spark.Logging - - -/** - * Handler that processes server responses. It uses the protocol documented in - * [[org.apache.spark.network.netty.server.BlockServer]]. - * - * Concurrency: thread safe and can be called from multiple threads. - */ -private[client] -class BlockFetchingClientHandler extends SimpleChannelInboundHandler[ByteBuf] with Logging { - - /** Tracks the list of outstanding requests and their listeners on success/failure. 
*/ - private val outstandingRequests = java.util.Collections.synchronizedMap { - new java.util.HashMap[String, BlockClientListener] - } - - def addRequest(blockId: String, listener: BlockClientListener): Unit = { - outstandingRequests.put(blockId, listener) - } - - def removeRequest(blockId: String): Unit = { - outstandingRequests.remove(blockId) - } - - override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - val errorMsg = s"Exception in connection from ${ctx.channel.remoteAddress}: ${cause.getMessage}" - logError(errorMsg, cause) - - // Fire the failure callback for all outstanding blocks - outstandingRequests.synchronized { - val iter = outstandingRequests.entrySet().iterator() - while (iter.hasNext) { - val entry = iter.next() - entry.getValue.onFetchFailure(entry.getKey, errorMsg) - } - outstandingRequests.clear() - } - - ctx.close() - } - - override def channelRead0(ctx: ChannelHandlerContext, in: ByteBuf) { - val totalLen = in.readInt() - val blockIdLen = in.readInt() - val blockIdBytes = new Array[Byte](math.abs(blockIdLen)) - in.readBytes(blockIdBytes) - val blockId = new String(blockIdBytes) - val blockSize = totalLen - math.abs(blockIdLen) - 4 - - def server = ctx.channel.remoteAddress.toString - - // blockIdLen is negative when it is an error message. - if (blockIdLen < 0) { - val errorMessageBytes = new Array[Byte](blockSize) - in.readBytes(errorMessageBytes) - val errorMsg = new String(errorMessageBytes) - logTrace(s"Received block $blockId ($blockSize B) with error $errorMsg from $server") - - val listener = outstandingRequests.get(blockId) - if (listener == null) { - // Ignore callback - logWarning(s"Got a response for block $blockId but it is not in our outstanding requests") - } else { - outstandingRequests.remove(blockId) - listener.onFetchFailure(blockId, errorMsg) - } - } else { - logTrace(s"Received block $blockId ($blockSize B) from $server") - - val listener = outstandingRequests.get(blockId) - if (listener == null) { - // Ignore callback - logWarning(s"Got a response for block $blockId but it is not in our outstanding requests") - } else { - outstandingRequests.remove(blockId) - listener.onFetchSuccess(blockId, new ReferenceCountedBuffer(in)) - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala b/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala deleted file mode 100644 index 9740ee64d1f2d..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.client - -/** - * A simple iterator that lazily initializes the underlying iterator. 
- * - * The use case is that sometimes we might have many iterators open at the same time, and each of - * the iterator might initialize its own buffer (e.g. decompression buffer, deserialization buffer). - * This could lead to too many buffers open. If this iterator is used, we lazily initialize those - * buffers. - */ -private[spark] -class LazyInitIterator(createIterator: => Iterator[Any]) extends Iterator[Any] { - - lazy val proxy = createIterator - - override def hasNext: Boolean = { - val gotNext = proxy.hasNext - if (!gotNext) { - close() - } - gotNext - } - - override def next(): Any = proxy.next() - - def close(): Unit = Unit -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala b/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala deleted file mode 100644 index ea1abf5eccc26..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.client - -import java.io.InputStream -import java.nio.ByteBuffer - -import io.netty.buffer.{ByteBuf, ByteBufInputStream} - - -/** - * A buffer abstraction based on Netty's ByteBuf so we don't expose Netty. - * This is a Scala value class. - * - * The buffer's life cycle is NOT managed by the JVM, and thus requiring explicit declaration of - * reference by the retain method and release method. - */ -private[spark] -class ReferenceCountedBuffer(val underlying: ByteBuf) extends AnyVal { - - /** Return the nio ByteBuffer view of the underlying buffer. */ - def byteBuffer(): ByteBuffer = underlying.nioBuffer - - /** Creates a new input stream that starts from the current position of the buffer. */ - def inputStream(): InputStream = new ByteBufInputStream(underlying) - - /** Increment the reference counter by one. */ - def retain(): Unit = underlying.retain() - - /** Decrement the reference counter by one and release the buffer if the ref count is 0. */ - def release(): Unit = underlying.release() -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/protocol.scala b/core/src/main/scala/org/apache/spark/network/netty/protocol.scala new file mode 100644 index 0000000000000..0159eca1d3b41 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/protocol.scala @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.util.{List => JList} + +import io.netty.buffer.ByteBuf +import io.netty.channel.ChannelHandlerContext +import io.netty.channel.ChannelHandler.Sharable +import io.netty.handler.codec._ + +import org.apache.spark.Logging +import org.apache.spark.network.{NettyByteBufManagedBuffer, ManagedBuffer} + + +sealed trait ClientRequest { + def id: Byte +} + +final case class BlockFetchRequest(blocks: Seq[String]) extends ClientRequest { + override def id = 0 +} + +final case class BlockUploadRequest(blockId: String, data: ManagedBuffer) extends ClientRequest { + require(blockId.length <= Byte.MaxValue) + override def id = 1 +} + + +sealed trait ServerResponse { + def id: Byte +} + +final case class BlockFetchSuccess(blockId: String, data: ManagedBuffer) extends ServerResponse { + require(blockId.length <= Byte.MaxValue) + override def id = 0 +} + +final case class BlockFetchFailure(blockId: String, error: String) extends ServerResponse { + require(blockId.length <= Byte.MaxValue) + override def id = 1 +} + + +/** + * Encoder used by the client side to encode client-to-server responses. + */ +@Sharable +final class ClientRequestEncoder extends MessageToMessageEncoder[ClientRequest] { + override def encode(ctx: ChannelHandlerContext, in: ClientRequest, out: JList[Object]): Unit = { + in match { + case BlockFetchRequest(blocks) => + // 8 bytes: frame size + // 1 byte: BlockFetchRequest vs BlockUploadRequest + // 4 byte: num blocks + // then for each block id write 1 byte for blockId.length and then blockId itself + val frameLength = 8 + 1 + 4 + blocks.size + blocks.map(_.size).fold(0)(_ + _) + val buf = ctx.alloc().buffer(frameLength) + + buf.writeLong(frameLength) + buf.writeByte(in.id) + buf.writeInt(blocks.size) + blocks.foreach { blockId => + ProtocolUtils.writeBlockId(buf, blockId) + } + + assert(buf.writableBytes() == 0) + out.add(buf) + + case BlockUploadRequest(blockId, data) => + // 8 bytes: frame size + // 1 byte: msg id (BlockFetchRequest vs BlockUploadRequest) + // 1 byte: blockId.length + // data itself (length can be derived from: frame size - 1 - blockId.length) + val headerLength = 8 + 1 + 1 + blockId.length + val frameLength = headerLength + data.size + val header = ctx.alloc().buffer(headerLength) + + // Call this before we add header to out so in case of exceptions + // we don't send anything at all. + val body = data.convertToNetty() + + header.writeLong(frameLength) + header.writeByte(in.id) + ProtocolUtils.writeBlockId(header, blockId) + + assert(header.writableBytes() == 0) + out.add(header) + out.add(body) + } + } +} + + +/** + * Decoder in the server side to decode client requests. + * + * This assumes the inbound messages have been processed by a frame decoder created by + * [[ProtocolUtils.createFrameDecoder()]]. 
+ */ +@Sharable +final class ClientRequestDecoder extends MessageToMessageDecoder[ByteBuf] { + override protected def decode(ctx: ChannelHandlerContext, in: ByteBuf, out: JList[AnyRef]): Unit = + { + val msgTypeId = in.readByte() + val decoded = msgTypeId match { + case 0 => // BlockFetchRequest + val numBlocks = in.readInt() + val blockIds = Seq.fill(numBlocks) { ProtocolUtils.readBlockId(in) } + BlockFetchRequest(blockIds) + + case 1 => // BlockUploadRequest + val blockId = ProtocolUtils.readBlockId(in) + in.retain() // retain the bytebuf so we don't recycle it immediately. + BlockUploadRequest(blockId, new NettyByteBufManagedBuffer(in)) + } + + assert(decoded.id == msgTypeId) + out.add(decoded) + } +} + + +/** + * Encoder used by the server side to encode server-to-client responses. + */ +@Sharable +final class ServerResponseEncoder extends MessageToMessageEncoder[ServerResponse] with Logging { + override def encode(ctx: ChannelHandlerContext, in: ServerResponse, out: JList[Object]): Unit = { + in match { + case BlockFetchSuccess(blockId, data) => + // Handle the body first so if we encounter an error getting the body, we can respond + // with an error instead. + var body: AnyRef = null + try { + body = data.convertToNetty() + } catch { + case e: Exception => + // Re-encode this message as BlockFetchFailure. + logError(s"Error opening block $blockId for client ${ctx.channel.remoteAddress}", e) + encode(ctx, new BlockFetchFailure(blockId, e.getMessage), out) + return + } + + // If we got here, body cannot be null + // 8 bytes = long for frame length + // 1 byte = message id (type) + // 1 byte = block id length + // followed by block id itself + val headerLength = 8 + 1 + 1 + blockId.length + val frameLength = headerLength + data.size + val header = ctx.alloc().buffer(headerLength) + header.writeLong(frameLength) + header.writeByte(in.id) + ProtocolUtils.writeBlockId(header, blockId) + + assert(header.writableBytes() == 0) + out.add(header) + out.add(body) + + case BlockFetchFailure(blockId, error) => + val frameLength = 8 + 1 + 1 + blockId.length + error.length + val buf = ctx.alloc().buffer(frameLength) + buf.writeLong(frameLength) + buf.writeByte(in.id) + ProtocolUtils.writeBlockId(buf, blockId) + buf.writeBytes(error.getBytes) + + assert(buf.writableBytes() == 0) + out.add(buf) + } + } +} + + +/** + * Decoder in the client side to decode server responses. + * + * This assumes the inbound messages have been processed by a frame decoder created by + * [[ProtocolUtils.createFrameDecoder()]]. + */ +@Sharable +final class ServerResponseDecoder extends MessageToMessageDecoder[ByteBuf] { + override def decode(ctx: ChannelHandlerContext, in: ByteBuf, out: JList[AnyRef]): Unit = { + val msgId = in.readByte() + val decoded = msgId match { + case 0 => // BlockFetchSuccess + val blockId = ProtocolUtils.readBlockId(in) + in.retain() + new BlockFetchSuccess(blockId, new NettyByteBufManagedBuffer(in)) + + case 1 => // BlockFetchFailure + val blockId = ProtocolUtils.readBlockId(in) + val errorBytes = new Array[Byte](in.readableBytes()) + in.readBytes(errorBytes) + new BlockFetchFailure(blockId, new String(errorBytes)) + } + + assert(decoded.id == msgId) + out.add(decoded) + } +} + + +private[netty] object ProtocolUtils { + + /** LengthFieldBasedFrameDecoder used before all decoders. */ + def createFrameDecoder(): ByteToMessageDecoder = { + // maxFrameLength = 2G + // lengthFieldOffset = 0 + // lengthFieldLength = 8 + // lengthAdjustment = -8, i.e. 
exclude the 8 byte length itself + // initialBytesToStrip = 8, i.e. strip out the length field itself + new LengthFieldBasedFrameDecoder(Int.MaxValue, 0, 8, -8, 8) + } + + def readBlockId(in: ByteBuf): String = { + val numBytesToRead = in.readByte().toInt + val bytes = new Array[Byte](numBytesToRead) + in.readBytes(bytes) + new String(bytes) + } + + def writeBlockId(out: ByteBuf, blockId: String): Unit = { + out.writeByte(blockId.length) + out.writeBytes(blockId.getBytes) + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala deleted file mode 100644 index 162e9cc6828d4..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.server - -/** - * Header describing a block. This is used only in the server pipeline. - * - * [[BlockServerHandler]] creates this, and [[BlockHeaderEncoder]] encodes it. - * - * @param blockSize length of the block content, excluding the length itself. - * If positive, this is the header for a block (not part of the header). - * If negative, this is the header and content for an error message. - * @param blockId block id - * @param error some error message from reading the block - */ -private[server] -class BlockHeader(val blockSize: Int, val blockId: String, val error: Option[String] = None) diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala deleted file mode 100644 index 8e4dda4ef8595..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.netty.server - -import io.netty.buffer.ByteBuf -import io.netty.channel.ChannelHandlerContext -import io.netty.handler.codec.MessageToByteEncoder - -/** - * A simple encoder for BlockHeader. See [[BlockServer]] for the server to client protocol. - */ -private[server] -class BlockHeaderEncoder extends MessageToByteEncoder[BlockHeader] { - override def encode(ctx: ChannelHandlerContext, msg: BlockHeader, out: ByteBuf): Unit = { - // message = message length (4 bytes) + block id length (4 bytes) + block id + block data - // message length = block id length (4 bytes) + size of block id + size of block data - val blockIdBytes = msg.blockId.getBytes - msg.error match { - case Some(errorMsg) => - val errorBytes = errorMsg.getBytes - out.writeInt(4 + blockIdBytes.length + errorBytes.size) - out.writeInt(-blockIdBytes.length) // use negative block id length to represent errors - out.writeBytes(blockIdBytes) // next is blockId itself - out.writeBytes(errorBytes) // error message - case None => - out.writeInt(4 + blockIdBytes.length + msg.blockSize) - out.writeInt(blockIdBytes.length) // First 4 bytes is blockId length - out.writeBytes(blockIdBytes) // next is blockId itself - // msg of size blockSize will be written by ServerHandler - } - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala deleted file mode 100644 index cc70bd0c5c477..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.server - -import io.netty.channel.ChannelInitializer -import io.netty.channel.socket.SocketChannel -import io.netty.handler.codec.LineBasedFrameDecoder -import io.netty.handler.codec.string.StringDecoder -import io.netty.util.CharsetUtil -import org.apache.spark.storage.BlockDataProvider - - -/** Channel initializer that sets up the pipeline for the BlockServer. 
*/ -private[netty] -class BlockServerChannelInitializer(dataProvider: BlockDataProvider) - extends ChannelInitializer[SocketChannel] { - - override def initChannel(ch: SocketChannel): Unit = { - ch.pipeline - .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024 - .addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8)) - .addLast("blockHeaderEncoder", new BlockHeaderEncoder) - .addLast("handler", new BlockServerHandler(dataProvider)) - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala deleted file mode 100644 index 40dd5e5d1a2ac..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.server - -import java.io.FileInputStream -import java.nio.ByteBuffer -import java.nio.channels.FileChannel - -import io.netty.buffer.Unpooled -import io.netty.channel._ - -import org.apache.spark.Logging -import org.apache.spark.storage.{FileSegment, BlockDataProvider} - - -/** - * A handler that processes requests from clients and writes block data back. - * - * The messages should have been processed by a LineBasedFrameDecoder and a StringDecoder first - * so channelRead0 is called once per line (i.e. per block id). - */ -private[server] -class BlockServerHandler(dataProvider: BlockDataProvider) - extends SimpleChannelInboundHandler[String] with Logging { - - override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - logError(s"Exception in connection from ${ctx.channel.remoteAddress}", cause) - ctx.close() - } - - override def channelRead0(ctx: ChannelHandlerContext, blockId: String): Unit = { - def client = ctx.channel.remoteAddress.toString - - // A helper function to send error message back to the client. - def respondWithError(error: String): Unit = { - ctx.writeAndFlush(new BlockHeader(-1, blockId, Some(error))).addListener( - new ChannelFutureListener { - override def operationComplete(future: ChannelFuture) { - if (!future.isSuccess) { - // TODO: Maybe log the success case as well. - logError(s"Error sending error back to $client", future.cause) - ctx.close() - } - } - } - ) - } - - def writeFileSegment(segment: FileSegment): Unit = { - // Send error message back if the block is too large. Even though we are capable of sending - // large (2G+) blocks, the receiving end cannot handle it so let's fail fast. - // Once we fixed the receiving end to be able to process large blocks, this should be removed. - // Also make sure we update BlockHeaderEncoder to support length > 2G. 
- - // See [[BlockHeaderEncoder]] for the way length is encoded. - if (segment.length + blockId.length + 4 > Int.MaxValue) { - respondWithError(s"Block $blockId size ($segment.length) greater than 2G") - return - } - - var fileChannel: FileChannel = null - try { - fileChannel = new FileInputStream(segment.file).getChannel - } catch { - case e: Exception => - logError( - s"Error opening channel for $blockId in ${segment.file} for request from $client", e) - respondWithError(e.getMessage) - } - - // Found the block. Send it back. - if (fileChannel != null) { - // Write the header and block data. In the case of failures, the listener on the block data - // write should close the connection. - ctx.write(new BlockHeader(segment.length.toInt, blockId)) - - val region = new DefaultFileRegion(fileChannel, segment.offset, segment.length) - ctx.writeAndFlush(region).addListener(new ChannelFutureListener { - override def operationComplete(future: ChannelFuture) { - if (future.isSuccess) { - logTrace(s"Sent block $blockId (${segment.length} B) back to $client") - } else { - logError(s"Error sending block $blockId to $client; closing connection", future.cause) - ctx.close() - } - } - }) - } - } - - def writeByteBuffer(buf: ByteBuffer): Unit = { - ctx.write(new BlockHeader(buf.remaining, blockId)) - ctx.writeAndFlush(Unpooled.wrappedBuffer(buf)).addListener(new ChannelFutureListener { - override def operationComplete(future: ChannelFuture) { - if (future.isSuccess) { - logTrace(s"Sent block $blockId (${buf.remaining} B) back to $client") - } else { - logError(s"Error sending block $blockId to $client; closing connection", future.cause) - ctx.close() - } - } - }) - } - - logTrace(s"Received request from $client to fetch block $blockId") - - var blockData: Either[FileSegment, ByteBuffer] = null - - // First make sure we can find the block. If not, send error back to the user. - try { - blockData = dataProvider.getBlockData(blockId) - } catch { - case e: Exception => - logError(s"Error opening block $blockId for request from $client", e) - respondWithError(e.getMessage) - return - } - - blockData match { - case Left(segment) => writeFileSegment(segment) - case Right(buf) => writeByteBuffer(buf) - } - - } // end of channelRead0 -} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala b/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala deleted file mode 100644 index 5b6d086630834..0000000000000 --- a/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.nio.ByteBuffer - - -/** - * An interface for providing data for blocks. 
- * - * getBlockData returns either a FileSegment (for zero-copy send), or a ByteBuffer. - * - * Aside from unit tests, [[BlockManager]] is the main class that implements this. - */ -private[spark] trait BlockDataProvider { - def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] -} diff --git a/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala new file mode 100644 index 0000000000000..1358b2f9c8071 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.nio.ByteBuffer +import java.util.concurrent.atomic.AtomicInteger + +import io.netty.buffer.Unpooled +import io.netty.channel.embedded.EmbeddedChannel + +import org.scalatest.{FunSuite, PrivateMethodTester} + +import org.apache.spark.network._ + + +class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { + + private def sizeOfOutstandingRequests(handler: BlockClientHandler): Int = { + val outstandingRequests = PrivateMethod[java.util.Map[_, _]]('outstandingRequests) + handler.invokePrivate(outstandingRequests()).size + } + + test("handling block data (successful fetch)") { + val blockId = "test_block" + val blockData = "blahblahblahblahblah" + + var parsedBlockId: String = "" + var parsedBlockData: String = "" + val handler = new BlockClientHandler + handler.addRequest(blockId, + new BlockFetchingListener { + override def onBlockFetchFailure(exception: Throwable): Unit = { + throw new UnsupportedOperationException + } + override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { + parsedBlockId = blockId + val bytes = new Array[Byte](data.size.toInt) + data.nioByteBuffer().get(bytes) + parsedBlockData = new String(bytes) + } + } + ) + + val outstandingRequests = PrivateMethod[java.util.Map[_, _]]('outstandingRequests) + assert(handler.invokePrivate(outstandingRequests()).size === 1) + + val channel = new EmbeddedChannel(handler) + val buf = ByteBuffer.allocate(blockData.size) // 4 bytes for the length field itself + buf.put(blockData.getBytes) + buf.flip() + + channel.writeInbound(BlockFetchSuccess(blockId, new NioByteBufferManagedBuffer(buf))) + + assert(parsedBlockId === blockId) + assert(parsedBlockData === blockData) + assert(handler.invokePrivate(outstandingRequests()).size === 0) + assert(channel.finish() === false) + } + + test("handling error message (failed fetch)") { + val blockId = "test_block" + val errorMsg = "error erro5r error err4or error3 error6 error erro1r" + + var parsedErrorMsg: String = "" + val handler = new BlockClientHandler + 
handler.addRequest(blockId, + new BlockFetchingListener { + override def onBlockFetchFailure(exception: Throwable): Unit = { + parsedErrorMsg = exception.getMessage + } + + override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { + throw new UnsupportedOperationException + } + } + ) + + assert(sizeOfOutstandingRequests(handler) === 1) + + val channel = new EmbeddedChannel(handler) + channel.writeInbound(BlockFetchFailure(blockId, errorMsg)) + assert(parsedErrorMsg === errorMsg) + assert(sizeOfOutstandingRequests(handler) === 0) + assert(channel.finish() === false) + } + + ignore("clear all outstanding request upon connection close") { + val errorCount = new AtomicInteger(0) + val successCount = new AtomicInteger(0) + val handler = new BlockClientHandler + + val listener = new BlockFetchingListener { + override def onBlockFetchFailure(exception: Throwable): Unit = { + errorCount.incrementAndGet() + } + override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { + successCount.incrementAndGet() + } + } + + handler.addRequest("b1", listener) + handler.addRequest("b2", listener) + handler.addRequest("b3", listener) + assert(sizeOfOutstandingRequests(handler) === 3) + + val channel = new EmbeddedChannel(handler) + channel.writeInbound(BlockFetchSuccess("b1", new NettyByteBufManagedBuffer(Unpooled.buffer()))) + // Need to figure out a way to generate an exception + assert(successCount.get() === 1) + assert(errorCount.get() === 2) + assert(sizeOfOutstandingRequests(handler) === 0) + assert(channel.finish() === false) + } +} diff --git a/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala new file mode 100644 index 0000000000000..72034634a5bd2 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import io.netty.channel.embedded.EmbeddedChannel + +import org.scalatest.FunSuite + + +/** + * Test client/server encoder/decoder protocol. + */ +class ProtocolSuite extends FunSuite { + + /** + * Helper to test server to client message protocol by encoding a message and decoding it. + */ + private def testServerToClient(msg: ServerResponse) { + val serverChannel = new EmbeddedChannel(new ServerResponseEncoder) + serverChannel.writeOutbound(msg) + + val clientChannel = new EmbeddedChannel( + ProtocolUtils.createFrameDecoder(), + new ServerResponseDecoder) + + // Drain all server outbound messages and write them to the client's server decoder. 
+ while (!serverChannel.outboundMessages().isEmpty) { + clientChannel.writeInbound(serverChannel.readOutbound()) + } + + assert(clientChannel.inboundMessages().size === 1) + // Must put "msg === ..." instead of "... === msg" since only TestManagedBuffer equals is + // overridden. + assert(msg === clientChannel.readInbound()) + } + + /** + * Helper to test client to server message protocol by encoding a message and decoding it. + */ + private def testClientToServer(msg: ClientRequest) { + val clientChannel = new EmbeddedChannel(new ClientRequestEncoder) + clientChannel.writeOutbound(msg) + + val serverChannel = new EmbeddedChannel( + ProtocolUtils.createFrameDecoder(), + new ClientRequestDecoder) + + // Drain all client outbound messages and write them to the server's decoder. + while (!clientChannel.outboundMessages().isEmpty) { + serverChannel.writeInbound(clientChannel.readOutbound()) + } + + assert(serverChannel.inboundMessages().size === 1) + // Must put "msg === ..." instead of "... === msg" since only TestManagedBuffer equals is + // overridden. + assert(msg === serverChannel.readInbound()) + } + + test("server to client protocol") { + testServerToClient(BlockFetchSuccess("a1234", new TestManagedBuffer(10))) + testServerToClient(BlockFetchSuccess("", new TestManagedBuffer(0))) + testServerToClient(BlockFetchFailure("abcd", "this is an error")) + testServerToClient(BlockFetchFailure("", "")) + } + + test("client to server protocol") { + testClientToServer(BlockFetchRequest(Seq.empty[String])) + testClientToServer(BlockFetchRequest(Seq("b1"))) + testClientToServer(BlockFetchRequest(Seq("b1", "b2", "b3"))) + testClientToServer(BlockUploadRequest("", new TestManagedBuffer(0))) + testClientToServer(BlockUploadRequest("b_upload", new TestManagedBuffer(10))) + } +} diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala index 02d0ffc86f58f..a468764fb1848 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala @@ -1,19 +1,19 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ package org.apache.spark.network.netty @@ -24,26 +24,25 @@ import java.util.concurrent.{TimeUnit, Semaphore} import scala.collection.JavaConversions._ -import io.netty.buffer.{ByteBufUtil, Unpooled} +import io.netty.buffer.Unpooled import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.apache.spark.SparkConf -import org.apache.spark.network.netty.client.{BlockClientListener, ReferenceCountedBuffer, BlockFetchingClientFactory} -import org.apache.spark.network.netty.server.BlockServer -import org.apache.spark.storage.{FileSegment, BlockDataProvider} +import org.apache.spark.network._ +import org.apache.spark.storage.StorageLevel /** - * Test suite that makes sure the server and the client implementations share the same protocol. - */ +* Test suite that makes sure the server and the client implementations share the same protocol. +*/ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { val bufSize = 100000 var buf: ByteBuffer = _ var testFile: File = _ var server: BlockServer = _ - var clientFactory: BlockFetchingClientFactory = _ + var clientFactory: BlockClientFactory = _ val bufferBlockId = "buffer_block" val fileBlockId = "file_block" @@ -63,19 +62,24 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { fp.write(fileContent) fp.close() - server = new BlockServer(new SparkConf, new BlockDataProvider { - override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = { + server = new BlockServer(new SparkConf, new BlockDataManager { + override def getBlockData(blockId: String): Option[ManagedBuffer] = { if (blockId == bufferBlockId) { - Right(buf) + Some(new NioByteBufferManagedBuffer(buf)) } else if (blockId == fileBlockId) { - Left(new FileSegment(testFile, 10, testFile.length - 25)) + Some(new FileSegmentManagedBuffer(testFile, 10, testFile.length - 25)) } else { - throw new Exception("Unknown block id " + blockId) + None } } + + /** + * Put the block locally, using the given storage level. + */ + def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit = ??? 
}) - clientFactory = new BlockFetchingClientFactory(new SparkConf) + clientFactory = new BlockClientFactory(new SparkConf) } override def afterAll() = { @@ -89,31 +93,29 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { /** A ByteBuf for file_block */ lazy val fileBlockReference = Unpooled.wrappedBuffer(fileContent, 10, fileContent.length - 25) - def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ReferenceCountedBuffer], Set[String]) = + def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ManagedBuffer], Set[String]) = { val client = clientFactory.createClient(server.hostName, server.port) val sem = new Semaphore(0) val receivedBlockIds = Collections.synchronizedSet(new HashSet[String]) val errorBlockIds = Collections.synchronizedSet(new HashSet[String]) - val receivedBuffers = Collections.synchronizedSet(new HashSet[ReferenceCountedBuffer]) + val receivedBuffers = Collections.synchronizedSet(new HashSet[ManagedBuffer]) client.fetchBlocks( blockIds, - new BlockClientListener { - override def onFetchFailure(blockId: String, errorMsg: String): Unit = { - errorBlockIds.add(blockId) + new BlockFetchingListener { + override def onBlockFetchFailure(exception: Throwable): Unit = { sem.release() } - override def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit = { + override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { receivedBlockIds.add(blockId) - data.retain() receivedBuffers.add(data) sem.release() } } ) - if (!sem.tryAcquire(blockIds.size, 30, TimeUnit.SECONDS)) { + if (!sem.tryAcquire(blockIds.size, 5, TimeUnit.SECONDS)) { fail("Timeout getting response from the server") } client.close() @@ -123,20 +125,18 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { test("fetch a ByteBuffer block") { val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId)) assert(blockIds === Set(bufferBlockId)) - assert(buffers.map(_.underlying) === Set(byteBufferBlockReference)) + assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference)) assert(failBlockIds.isEmpty) - buffers.foreach(_.release()) } test("fetch a FileSegment block via zero-copy send") { val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(fileBlockId)) assert(blockIds === Set(fileBlockId)) - assert(buffers.map(_.underlying) === Set(fileBlockReference)) + assert(buffers.map(_.convertToNetty()) === Set(fileBlockReference)) assert(failBlockIds.isEmpty) - buffers.foreach(_.release()) } - test("fetch a non-existent block") { + ignore("fetch a non-existent block") { val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq("random-block")) assert(blockIds.isEmpty) assert(buffers.isEmpty) @@ -146,16 +146,14 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { test("fetch both ByteBuffer block and FileSegment block") { val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, fileBlockId)) assert(blockIds === Set(bufferBlockId, fileBlockId)) - assert(buffers.map(_.underlying) === Set(byteBufferBlockReference, fileBlockReference)) + assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference, fileBlockReference)) assert(failBlockIds.isEmpty) - buffers.foreach(_.release()) } - test("fetch both ByteBuffer block and a non-existent block") { + ignore("fetch both ByteBuffer block and a non-existent block") { val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, "random-block")) assert(blockIds === Set(bufferBlockId)) - 
assert(buffers.map(_.underlying) === Set(byteBufferBlockReference)) + assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference)) assert(failBlockIds === Set("random-block")) - buffers.foreach(_.release()) } } diff --git a/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala b/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala new file mode 100644 index 0000000000000..6ae2d3b3faf91 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.io.InputStream +import java.nio.ByteBuffer + +import io.netty.buffer.Unpooled + +import org.apache.spark.network.{NettyByteBufManagedBuffer, ManagedBuffer} + + +/** + * A ManagedBuffer implementation that contains 0, 1, 2, 3, ..., (len-1). + * + * Used for testing. + */ +class TestManagedBuffer(len: Int) extends ManagedBuffer { + + require(len <= Byte.MaxValue) + + private val byteArray: Array[Byte] = Array.tabulate[Byte](len)(_.toByte) + + private val underlying = new NettyByteBufManagedBuffer(Unpooled.wrappedBuffer(byteArray)) + + override def size: Long = underlying.size + + override private[network] def convertToNetty(): AnyRef = underlying.convertToNetty() + + override def nioByteBuffer(): ByteBuffer = underlying.nioByteBuffer() + + override def inputStream(): InputStream = underlying.inputStream() + + override def toString: String = s"${getClass.getName}($len)" + + override def equals(other: Any): Boolean = other match { + case otherBuf: ManagedBuffer => + val nioBuf = otherBuf.nioByteBuffer() + if (nioBuf.remaining() != len) { + return false + } else { + var i = 0 + while (i < len) { + if (nioBuf.get() != i) { + return false + } + i += 1 + } + return true + } + case _ => false + } +} diff --git a/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala deleted file mode 100644 index 903ab09ae4322..0000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.client - -import java.nio.ByteBuffer - -import io.netty.buffer.Unpooled -import io.netty.channel.embedded.EmbeddedChannel - -import org.scalatest.{PrivateMethodTester, FunSuite} - - -class BlockFetchingClientHandlerSuite extends FunSuite with PrivateMethodTester { - - test("handling block data (successful fetch)") { - val blockId = "test_block" - val blockData = "blahblahblahblahblah" - val totalLength = 4 + blockId.length + blockData.length - - var parsedBlockId: String = "" - var parsedBlockData: String = "" - val handler = new BlockFetchingClientHandler - handler.addRequest(blockId, - new BlockClientListener { - override def onFetchFailure(blockId: String, errorMsg: String): Unit = ??? - override def onFetchSuccess(bid: String, refCntBuf: ReferenceCountedBuffer): Unit = { - parsedBlockId = bid - val bytes = new Array[Byte](refCntBuf.byteBuffer().remaining) - refCntBuf.byteBuffer().get(bytes) - parsedBlockData = new String(bytes) - } - } - ) - - val outstandingRequests = PrivateMethod[java.util.Map[_, _]]('outstandingRequests) - assert(handler.invokePrivate(outstandingRequests()).size === 1) - - val channel = new EmbeddedChannel(handler) - val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself - buf.putInt(totalLength) - buf.putInt(blockId.length) - buf.put(blockId.getBytes) - buf.put(blockData.getBytes) - buf.flip() - - channel.writeInbound(Unpooled.wrappedBuffer(buf)) - assert(parsedBlockId === blockId) - assert(parsedBlockData === blockData) - - assert(handler.invokePrivate(outstandingRequests()).size === 0) - - channel.close() - } - - test("handling error message (failed fetch)") { - val blockId = "test_block" - val errorMsg = "error erro5r error err4or error3 error6 error erro1r" - val totalLength = 4 + blockId.length + errorMsg.length - - var parsedBlockId: String = "" - var parsedErrorMsg: String = "" - val handler = new BlockFetchingClientHandler - handler.addRequest(blockId, new BlockClientListener { - override def onFetchFailure(bid: String, msg: String) ={ - parsedBlockId = bid - parsedErrorMsg = msg - } - override def onFetchSuccess(bid: String, refCntBuf: ReferenceCountedBuffer) = ??? 
- }) - - val outstandingRequests = PrivateMethod[java.util.Map[_, _]]('outstandingRequests) - assert(handler.invokePrivate(outstandingRequests()).size === 1) - - val channel = new EmbeddedChannel(handler) - val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself - buf.putInt(totalLength) - buf.putInt(-blockId.length) - buf.put(blockId.getBytes) - buf.put(errorMsg.getBytes) - buf.flip() - - channel.writeInbound(Unpooled.wrappedBuffer(buf)) - assert(parsedBlockId === blockId) - assert(parsedErrorMsg === errorMsg) - - assert(handler.invokePrivate(outstandingRequests()).size === 0) - - channel.close() - } -} diff --git a/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala deleted file mode 100644 index 3ee281cb1350b..0000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.netty.server - -import io.netty.buffer.ByteBuf -import io.netty.channel.embedded.EmbeddedChannel - -import org.scalatest.FunSuite - - -class BlockHeaderEncoderSuite extends FunSuite { - - test("encode normal block data") { - val blockId = "test_block" - val channel = new EmbeddedChannel(new BlockHeaderEncoder) - channel.writeOutbound(new BlockHeader(17, blockId, None)) - val out = channel.readOutbound().asInstanceOf[ByteBuf] - assert(out.readInt() === 4 + blockId.length + 17) - assert(out.readInt() === blockId.length) - - val blockIdBytes = new Array[Byte](blockId.length) - out.readBytes(blockIdBytes) - assert(new String(blockIdBytes) === blockId) - assert(out.readableBytes() === 0) - - channel.close() - } - - test("encode error message") { - val blockId = "error_block" - val errorMsg = "error encountered" - val channel = new EmbeddedChannel(new BlockHeaderEncoder) - channel.writeOutbound(new BlockHeader(17, blockId, Some(errorMsg))) - val out = channel.readOutbound().asInstanceOf[ByteBuf] - assert(out.readInt() === 4 + blockId.length + errorMsg.length) - assert(out.readInt() === -blockId.length) - - val blockIdBytes = new Array[Byte](blockId.length) - out.readBytes(blockIdBytes) - assert(new String(blockIdBytes) === blockId) - - val errorMsgBytes = new Array[Byte](errorMsg.length) - out.readBytes(errorMsgBytes) - assert(new String(errorMsgBytes) === errorMsg) - assert(out.readableBytes() === 0) - - channel.close() - } -} diff --git a/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala deleted file mode 100644 index 3239c710f1639..0000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.netty.server - -import java.io.{RandomAccessFile, File} -import java.nio.ByteBuffer - -import io.netty.buffer.{Unpooled, ByteBuf} -import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler, DefaultFileRegion} -import io.netty.channel.embedded.EmbeddedChannel - -import org.scalatest.FunSuite - -import org.apache.spark.storage.{BlockDataProvider, FileSegment} - - -class BlockServerHandlerSuite extends FunSuite { - - test("ByteBuffer block") { - val expectedBlockId = "test_bytebuffer_block" - val buf = ByteBuffer.allocate(10000) - for (i <- 1 to 10000) { - buf.put(i.toByte) - } - buf.flip() - - val channel = new EmbeddedChannel(new BlockServerHandler(new BlockDataProvider { - override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = Right(buf) - })) - - channel.writeInbound(expectedBlockId) - assert(channel.outboundMessages().size === 2) - - val out1 = channel.readOutbound().asInstanceOf[BlockHeader] - val out2 = channel.readOutbound().asInstanceOf[ByteBuf] - - assert(out1.blockId === expectedBlockId) - assert(out1.blockSize === buf.remaining) - assert(out1.error === None) - - assert(out2.equals(Unpooled.wrappedBuffer(buf))) - - channel.close() - } - - test("FileSegment block via zero-copy") { - val expectedBlockId = "test_file_block" - - // Create random file data - val fileContent = new Array[Byte](1024) - scala.util.Random.nextBytes(fileContent) - val testFile = File.createTempFile("netty-test-file", "txt") - val fp = new RandomAccessFile(testFile, "rw") - fp.write(fileContent) - fp.close() - - val channel = new EmbeddedChannel(new BlockServerHandler(new BlockDataProvider { - override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = { - Left(new FileSegment(testFile, 15, testFile.length - 25)) - } - })) - - channel.writeInbound(expectedBlockId) - assert(channel.outboundMessages().size === 2) - - val out1 = channel.readOutbound().asInstanceOf[BlockHeader] - val out2 = channel.readOutbound().asInstanceOf[DefaultFileRegion] - - assert(out1.blockId === expectedBlockId) - assert(out1.blockSize === testFile.length - 25) - assert(out1.error === None) - - assert(out2.count === testFile.length - 25) - assert(out2.position === 15) - } - - test("pipeline exception propagation") { - val blockServerHandler = new BlockServerHandler(new BlockDataProvider { - override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = ??? - }) - val exceptionHandler = new SimpleChannelInboundHandler[String]() { - override def channelRead0(ctx: ChannelHandlerContext, msg: String): Unit = { - throw new Exception("this is an error") - } - } - - val channel = new EmbeddedChannel(exceptionHandler, blockServerHandler) - assert(channel.isOpen) - channel.writeInbound("a message to trigger the error") - assert(!channel.isOpen) - } -} From 1760d3292ecf262e4c77c9e3b28bfd2900d25840 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 9 Sep 2014 00:42:37 -0700 Subject: [PATCH 02/46] Use Epoll.isAvailable in BlockServer as well. 
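Netty exposes Epoll.isAvailable, so the "auto" io mode can check whether the native epoll
transport is usable up front instead of attempting to initialize it and catching the failure.
The selection in both the client factory and the server now amounts to something like the
sketch below (chooseTransport is an illustrative helper, not a method that exists in the code):

    import io.netty.channel.epoll.Epoll

    // Map the configured io mode onto a concrete transport. "auto" prefers the native
    // epoll transport when it is available (Linux with the native library present)
    // and falls back to NIO otherwise.
    def chooseTransport(ioMode: String): String = ioMode.toLowerCase match {
      case "nio" | "oio" | "epoll" => ioMode
      case "auto" => if (Epoll.isAvailable) "epoll" else "nio"
    }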
--- .../apache/spark/network/netty/BlockServer.scala | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala index 76f28aa00112e..3433c5763ab3c 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala @@ -21,14 +21,13 @@ import java.net.InetSocketAddress import io.netty.bootstrap.ServerBootstrap import io.netty.buffer.PooledByteBufAllocator -import io.netty.channel.epoll.{EpollEventLoopGroup, EpollServerSocketChannel} +import io.netty.channel.epoll.{Epoll, EpollEventLoopGroup, EpollServerSocketChannel} import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.oio.OioEventLoopGroup import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.nio.NioServerSocketChannel import io.netty.channel.socket.oio.OioServerSocketChannel import io.netty.channel.{ChannelInitializer, ChannelFuture, ChannelOption} -import io.netty.handler.codec.LengthFieldBasedFrameDecoder import org.apache.spark.{Logging, SparkConf} import org.apache.spark.network.BlockDataManager @@ -85,16 +84,7 @@ class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) extends Log case "nio" => initNio() case "oio" => initOio() case "epoll" => initEpoll() - case "auto" => - // For auto mode, first try epoll (only available on Linux), then nio. - try { - initEpoll() - } catch { - // TODO: Should we log the throwable? But that always happen on non-Linux systems. - // Perhaps the right thing to do is to check whether the system is Linux, and then only - // call initEpoll on Linux. - case e: Throwable => initNio() - } + case "auto" => if (Epoll.isAvailable) initEpoll() else initNio() } // Use pooled buffers to reduce temporary buffer allocation @@ -114,7 +104,7 @@ class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) extends Log bootstrap.childHandler(new ChannelInitializer[SocketChannel] { override def initChannel(ch: SocketChannel): Unit = { - val p = ch.pipeline + ch.pipeline .addLast("frameDecoder", ProtocolUtils.createFrameDecoder()) .addLast("clientRequestDecoder", new ClientRequestDecoder) .addLast("serverResponseEncoder", new ServerResponseEncoder) From 2b44cf1b7547919bbe7386e954fe2f56be046790 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 9 Sep 2014 14:36:31 -0700 Subject: [PATCH 03/46] Added more documentation. 
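The division of labour is now spelled out in the docs: BlockClientFactory owns the event loop
and sets up the channel pipeline, BlockClient sends requests over an already-established
connection, and BlockClientHandler dispatches responses to the registered listeners. A rough
usage sketch based on these APIs (the host, port and block ids below are made up, and callers
normally go through NettyBlockTransferService rather than using the factory directly):

    import org.apache.spark.SparkConf
    import org.apache.spark.network.{BlockFetchingListener, ManagedBuffer}

    // One factory per process; every createClient() call opens a new connection and
    // blocks until it is established (or throws a TimeoutException).
    val factory = new BlockClientFactory(new SparkConf)
    val client = factory.createClient("remote-host", 12345)

    // Fetching is asynchronous; the listener is invoked once per requested block.
    client.fetchBlocks(Seq("shuffle_0_1_2", "shuffle_0_1_3"), new BlockFetchingListener {
      override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit =
        println(s"fetched $blockId (${data.size} bytes)")
      override def onBlockFetchFailure(exception: Throwable): Unit =
        exception.printStackTrace()
    })

    // Once the fetches have completed:
    client.close()
    factory.stop()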
---
 .../spark/network/netty/BlockClient.scala     | 61 +++++--------------
 .../network/netty/BlockClientFactory.scala    | 44 ++++++++++++-
 .../network/netty/BlockClientHandler.scala    |  5 +-
 .../spark/network/netty/BlockServer.scala     |  4 +-
 .../apache/spark/network/netty/protocol.scala | 19 +++++-
 .../netty/ServerClientIntegrationSuite.scala  |  5 +-
 6 files changed, 80 insertions(+), 58 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala
index 95af2565bcc39..9333fefa92957 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala
@@ -19,68 +19,35 @@ package org.apache.spark.network.netty

 import java.util.concurrent.TimeoutException

-import io.netty.bootstrap.Bootstrap
-import io.netty.buffer.PooledByteBufAllocator
-import io.netty.channel.socket.SocketChannel
-import io.netty.channel.{ChannelFuture, ChannelFutureListener, ChannelInitializer, ChannelOption}
+import io.netty.channel.{ChannelFuture, ChannelFutureListener}

 import org.apache.spark.Logging
 import org.apache.spark.network.BlockFetchingListener


 /**
- * Client for [[NettyBlockTransferService]]. Use [[BlockClientFactory]] to
- * instantiate this client.
+ * Client for [[NettyBlockTransferService]]. The connection to the server must have been
+ * established using [[BlockClientFactory]] before instantiating this.
 *
- * The constructor blocks until a connection is successfully established.
+ * This class is used to make requests to the server, while [[BlockClientHandler]] is responsible
+ * for handling responses from the server.
 *
 * Concurrency: thread safe and can be called from multiple threads.
+ *
+ * @param cf the ChannelFuture for the connection.
+ * @param handler [[BlockClientHandler]] for handling outstanding requests.
 */
 @throws[TimeoutException]
 private[netty]
-class BlockClient(factory: BlockClientFactory, hostname: String, port: Int)
-  extends Logging {
-
-  private val handler = new BlockClientHandler
-  private val encoder = new ClientRequestEncoder
-  private val decoder = new ServerResponseDecoder
-
-  /** Netty Bootstrap for creating the TCP connection. */
-  private val bootstrap: Bootstrap = {
-    val b = new Bootstrap
-    b.group(factory.workerGroup)
-      .channel(factory.socketChannelClass)
-      // Use pooled buffers to reduce temporary buffer allocation
-      .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
-      // Disable Nagle's Algorithm since we don't want packets to wait
-      .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE)
-      .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE)
-      .option[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, factory.conf.connectTimeoutMs)
-
-    b.handler(new ChannelInitializer[SocketChannel] {
-      override def initChannel(ch: SocketChannel): Unit = {
-        ch.pipeline
-          .addLast("clientRequestEncoder", encoder)
-          .addLast("frameDecoder", ProtocolUtils.createFrameDecoder())
-          .addLast("serverResponseDecoder", decoder)
-          .addLast("handler", handler)
-      }
-    })
-    b
-  }
+class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Logging {

-  /** Netty ChannelFuture for the connection.
*/ - private val cf: ChannelFuture = bootstrap.connect(hostname, port) - if (!cf.awaitUninterruptibly(factory.conf.connectTimeoutMs)) { - throw new TimeoutException( - s"Connecting to $hostname:$port timed out (${factory.conf.connectTimeoutMs} ms)") - } + private[this] val serverAddr = cf.channel().remoteAddress().toString /** * Ask the remote server for a sequence of blocks, and execute the callback. * * Note that this is asynchronous and returns immediately. Upstream caller should throttle the - * rate of fetching; otherwise we could run out of memory. + * rate of fetching; otherwise we could run out of memory due to large outstanding fetches. * * @param blockIds sequence of block ids to fetch. * @param listener callback to fire on fetch success / failure. @@ -89,7 +56,7 @@ class BlockClient(factory: BlockClientFactory, hostname: String, port: Int) var startTime: Long = 0 logTrace { startTime = System.nanoTime - s"Sending request $blockIds to $hostname:$port" + s"Sending request $blockIds to $serverAddr" } blockIds.foreach { blockId => @@ -101,12 +68,12 @@ class BlockClient(factory: BlockClientFactory, hostname: String, port: Int) if (future.isSuccess) { logTrace { val timeTaken = (System.nanoTime - startTime).toDouble / 1000000 - s"Sending request $blockIds to $hostname:$port took $timeTaken ms" + s"Sending request $blockIds to $serverAddr took $timeTaken ms" } } else { // Fail all blocks. val errorMsg = - s"Failed to send request $blockIds to $hostname:$port: ${future.cause.getMessage}" + s"Failed to send request $blockIds to $serverAddr: ${future.cause.getMessage}" logError(errorMsg, future.cause) blockIds.foreach { blockId => handler.removeRequest(blockId) diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala index 0777275cd4fe3..f05f1419ded14 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala @@ -17,12 +17,17 @@ package org.apache.spark.network.netty +import java.util.concurrent.TimeoutException + +import io.netty.bootstrap.Bootstrap +import io.netty.buffer.PooledByteBufAllocator +import io.netty.channel._ import io.netty.channel.epoll.{Epoll, EpollEventLoopGroup, EpollSocketChannel} import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.oio.OioEventLoopGroup +import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.nio.NioSocketChannel import io.netty.channel.socket.oio.OioSocketChannel -import io.netty.channel.{Channel, EventLoopGroup} import org.apache.spark.SparkConf import org.apache.spark.util.Utils @@ -38,12 +43,16 @@ class BlockClientFactory(val conf: NettyConfig) { def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf)) /** A thread factory so the threads are named (for debugging). */ - private[netty] val threadFactory = Utils.namedThreadFactory("spark-shuffle-client") + private[netty] val threadFactory = Utils.namedThreadFactory("spark-netty-client") /** The following two are instantiated by the [[init]] method, depending ioMode. */ private[netty] var socketChannelClass: Class[_ <: Channel] = _ private[netty] var workerGroup: EventLoopGroup = _ + // The encoders are stateless and can be shared among multiple clients. 
+ private[this] val encoder = new ClientRequestEncoder + private[this] val decoder = new ServerResponseDecoder + init() /** Initialize [[socketChannelClass]] and [[workerGroup]] based on ioMode. */ @@ -78,7 +87,36 @@ class BlockClientFactory(val conf: NettyConfig) { * Concurrency: This method is safe to call from multiple threads. */ def createClient(remoteHost: String, remotePort: Int): BlockClient = { - new BlockClient(this, remoteHost, remotePort) + val handler = new BlockClientHandler + + val bootstrap = new Bootstrap + bootstrap.group(workerGroup) + .channel(socketChannelClass) + // Use pooled buffers to reduce temporary buffer allocation + .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) + // Disable Nagle's Algorithm since we don't want packets to wait + .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE) + .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE) + .option[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectTimeoutMs) + + bootstrap.handler(new ChannelInitializer[SocketChannel] { + override def initChannel(ch: SocketChannel): Unit = { + ch.pipeline + .addLast("clientRequestEncoder", encoder) + .addLast("frameDecoder", ProtocolUtils.createFrameDecoder()) + .addLast("serverResponseDecoder", decoder) + .addLast("handler", handler) + } + }) + + // Connect to the remote server + val cf: ChannelFuture = bootstrap.connect(remoteHost, remotePort) + if (!cf.awaitUninterruptibly(conf.connectTimeoutMs)) { + throw new TimeoutException( + s"Connecting to $remoteHost:$remotePort timed out (${conf.connectTimeoutMs} ms)") + } + + new BlockClient(cf, handler) } def stop(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala index b41c831f3d7e5..2a474cd71eab8 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala @@ -24,7 +24,8 @@ import org.apache.spark.network.BlockFetchingListener /** - * Handler that processes server responses. + * Handler that processes server responses, in response to requests issued from [[BlockClient]]. + * It works by tracking the list of outstanding requests (and their callbacks). * * Concurrency: thread safe and can be called from multiple threads. */ @@ -32,7 +33,7 @@ private[netty] class BlockClientHandler extends SimpleChannelInboundHandler[ServerResponse] with Logging { /** Tracks the list of outstanding requests and their listeners on success/failure. */ - private val outstandingRequests = java.util.Collections.synchronizedMap { + private[this] val outstandingRequests = java.util.Collections.synchronizedMap { new java.util.HashMap[String, BlockFetchingListener] } diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala index 3433c5763ab3c..05443a74094d7 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala @@ -58,8 +58,8 @@ class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) extends Log /** Initialize the server. 
 */
   private def init(): Unit = {
     bootstrap = new ServerBootstrap
-    val bossThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-boss")
-    val workerThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-worker")
+    val bossThreadFactory = Utils.namedThreadFactory("spark-netty-server-boss")
+    val workerThreadFactory = Utils.namedThreadFactory("spark-netty-server-worker")

     // Use only one thread to accept connections, and 2 * num_cores for worker.
     def initNio(): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/network/netty/protocol.scala b/core/src/main/scala/org/apache/spark/network/netty/protocol.scala
index 0159eca1d3b41..ac6a4d00f654f 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/protocol.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/protocol.scala
@@ -28,29 +28,40 @@ import org.apache.spark.Logging
 import org.apache.spark.network.{NettyByteBufManagedBuffer, ManagedBuffer}


+/** Messages from the client to the server. */
 sealed trait ClientRequest {
   def id: Byte
 }

+/**
+ * Request to fetch a sequence of blocks from the server. A single [[BlockFetchRequest]] can
+ * correspond to multiple [[ServerResponse]]s.
+ */
 final case class BlockFetchRequest(blocks: Seq[String]) extends ClientRequest {
   override def id = 0
 }

+/**
+ * Request to upload a block to the server. Currently the server does not ack the upload request.
+ */
 final case class BlockUploadRequest(blockId: String, data: ManagedBuffer) extends ClientRequest {
   require(blockId.length <= Byte.MaxValue)
   override def id = 1
 }

+/** Messages from server to client (usually in response to some [[ClientRequest]]). */
 sealed trait ServerResponse {
   def id: Byte
 }

+/** Response to [[BlockFetchRequest]] when a block exists and has been successfully fetched. */
 final case class BlockFetchSuccess(blockId: String, data: ManagedBuffer) extends ServerResponse {
   require(blockId.length <= Byte.MaxValue)
   override def id = 0
 }

+/** Response to [[BlockFetchRequest]] when there is an error fetching the block. */
 final case class BlockFetchFailure(blockId: String, error: String) extends ServerResponse {
   require(blockId.length <= Byte.MaxValue)
   override def id = 1
@@ -58,7 +69,9 @@ final case class BlockFetchFailure(blockId: String, error: String) extends Serve


 /**
- * Encoder used by the client side to encode client-to-server responses.
+ * Encoder for [[ClientRequest]] used on the client side.
+ *
+ * This encoder is stateless so it is safe to be shared by multiple threads.
 */
 @Sharable
 final class ClientRequestEncoder extends MessageToMessageEncoder[ClientRequest] {
@@ -109,6 +122,7 @@ final class ClientRequestEncoder extends MessageToMessageEncoder[ClientRequest]

 /**
 * Decoder in the server side to decode client requests.
+ * This decoder is stateless so it is safe to be shared by multiple threads.
 *
 * This assumes the inbound messages have been processed by a frame decoder created by
 * [[ProtocolUtils.createFrameDecoder()]].
@@ -138,6 +152,7 @@ final class ClientRequestDecoder extends MessageToMessageDecoder[ByteBuf] {

 /**
 * Encoder used by the server side to encode server-to-client responses.
+ * This encoder is stateless so it is safe to be shared by multiple threads.
 */
 @Sharable
 final class ServerResponseEncoder extends MessageToMessageEncoder[ServerResponse] with Logging {
@@ -190,6 +205,7 @@ final class ServerResponseEncoder extends MessageToMessageEncoder[ServerResponse

 /**
 * Decoder in the client side to decode server responses.
+ * This decoder is stateless so it is safe to be shared by multiple threads. * * This assumes the inbound messages have been processed by a frame decoder created by * [[ProtocolUtils.createFrameDecoder()]]. @@ -229,6 +245,7 @@ private[netty] object ProtocolUtils { new LengthFieldBasedFrameDecoder(Int.MaxValue, 0, 8, -8, 8) } + // TODO(rxin): Make sure these work for all charsets. def readBlockId(in: ByteBuf): String = { val numBytesToRead = in.readByte().toInt val bytes = new Array[Byte](numBytesToRead) diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala index a468764fb1848..178c60a048b9f 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.storage.StorageLevel /** -* Test suite that makes sure the server and the client implementations share the same protocol. +* Test cases that create real clients and servers and connect. */ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { @@ -93,8 +93,7 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { /** A ByteBuf for file_block */ lazy val fileBlockReference = Unpooled.wrappedBuffer(fileContent, 10, fileContent.length - 25) - def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ManagedBuffer], Set[String]) = - { + def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ManagedBuffer], Set[String]) = { val client = clientFactory.createClient(server.hostName, server.port) val sem = new Semaphore(0) val receivedBlockIds = Collections.synchronizedSet(new HashSet[String]) From 064747b50a591acb132b2c750957e79f54dfa88f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 9 Sep 2014 23:38:38 -0700 Subject: [PATCH 04/46] Reference count buffers and clean them up properly. 
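ManagedBuffer now exposes retain() and release() so that Netty-backed buffers can be reference
counted when they are handed from the event loop to another thread (as the shuffle fetch
iterator does when it queues fetched blocks). For file- and NIO-backed buffers both calls are
no-ops. A minimal sketch of the intended usage pattern; the handOff helper is hypothetical and
not part of this patch:

    import org.apache.spark.network.ManagedBuffer

    // Pass a buffer received on the Netty event loop to a consumer running elsewhere.
    // retain() keeps the underlying ByteBuf alive across the handoff; release() drops
    // our reference and frees the buffer once the reference count reaches zero.
    def handOff(buf: ManagedBuffer)(consume: ManagedBuffer => Unit): Unit = {
      buf.retain()
      try {
        consume(buf)
      } finally {
        buf.release()
      }
    }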
--- .../scala/org/apache/spark/SparkEnv.scala | 9 +- .../spark/network/BlockDataManager.scala | 7 +- .../apache/spark/network/ManagedBuffer.scala | 41 ++++++- .../spark/network/netty/BlockServer.scala | 7 +- .../network/netty/BlockServerHandler.scala | 33 +++--- .../network/nio/NioBlockTransferService.scala | 3 +- .../apache/spark/storage/BlockManager.scala | 8 +- .../storage/ShuffleBlockFetcherIterator.scala | 111 ++++++++++++++---- .../netty/ServerClientIntegrationSuite.scala | 12 +- .../network/netty/TestManagedBuffer.scala | 4 + 10 files changed, 164 insertions(+), 71 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index aba713cb4267a..373ce795a309e 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -32,6 +32,7 @@ import org.apache.spark.api.python.PythonWorkerFactory import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.BlockTransferService +import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.Serializer @@ -39,6 +40,7 @@ import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ import org.apache.spark.util.{AkkaUtils, Utils} + /** * :: DeveloperApi :: * Holds all the runtime environment objects for a running Spark instance (either master or worker), @@ -231,7 +233,12 @@ object SparkEnv extends Logging { val shuffleMemoryManager = new ShuffleMemoryManager(conf) - val blockTransferService = new NioBlockTransferService(conf, securityManager) + // TODO(rxin): Config option based on class name, similar to shuffle mgr and compression codec. + val blockTransferService = if (conf.getBoolean("spark.shuffle.use.netty", false)) { + new NettyBlockTransferService(conf) + } else { + new NioBlockTransferService(conf, securityManager) + } val blockManagerMaster = new BlockManagerMaster(registerOrLookup( "BlockManagerMaster", diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index e0e91724271c8..638e05f481f55 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -23,11 +23,10 @@ import org.apache.spark.storage.StorageLevel trait BlockDataManager { /** - * Interface to get local block data. - * - * @return Some(buffer) if the block exists locally, and None if it doesn't. + * Interface to get local block data. Throws an exception if the block cannot be found or + * cannot be read successfully. */ - def getBlockData(blockId: String): Option[ManagedBuffer] + def getBlockData(blockId: String): ManagedBuffer /** * Put the block locally, using the given storage level. diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala index 9c298132fcfba..7f364947dd930 100644 --- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala @@ -35,9 +35,14 @@ import org.apache.spark.util.{ByteBufferInputStream, Utils} * This interface provides an immutable view for data in the form of bytes. 
The implementation * should specify how the data is provided: * - * - FileSegmentManagedBuffer: data backed by part of a file - * - NioByteBufferManagedBuffer: data backed by a NIO ByteBuffer - * - NettyByteBufManagedBuffer: data backed by a Netty ByteBuf + * - [[FileSegmentManagedBuffer]]: data backed by part of a file + * - [[NioByteBufferManagedBuffer]]: data backed by a NIO ByteBuffer + * - [[NettyByteBufManagedBuffer]]: data backed by a Netty ByteBuf + * + * The concrete buffer implementation might be managed outside the JVM garbage collector. + * For example, in the case of [[NettyByteBufManagedBuffer]], the buffers are reference counted. + * In that case, if the buffer is going to be passed around to a different thread, retain/release + * should be called. */ abstract class ManagedBuffer { // Note that all the methods are defined with parenthesis because their implementations can @@ -59,6 +64,17 @@ abstract class ManagedBuffer { */ def inputStream(): InputStream + /** + * Increment the reference count by one if applicable. + */ + def retain(): this.type + + /** + * If applicable, decrement the reference count by one and deallocates the buffer if the + * reference count reaches zero. + */ + def release(): this.type + /** * Convert the buffer into an Netty object, used to write the data out. */ @@ -123,6 +139,10 @@ final class FileSegmentManagedBuffer(val file: File, val offset: Long, val lengt val fileChannel = new FileInputStream(file).getChannel new DefaultFileRegion(fileChannel, offset, length) } + + // Content of file segments are not in-memory, so no need to reference count. + override def retain(): this.type = this + override def release(): this.type = this } @@ -138,6 +158,10 @@ final class NioByteBufferManagedBuffer(buf: ByteBuffer) extends ManagedBuffer { override def inputStream() = new ByteBufferInputStream(buf) private[network] override def convertToNetty(): AnyRef = Unpooled.wrappedBuffer(buf) + + // [[ByteBuffer]] is managed by the JVM garbage collector itself. + override def retain(): this.type = this + override def release(): this.type = this } @@ -154,6 +178,13 @@ final class NettyByteBufManagedBuffer(buf: ByteBuf) extends ManagedBuffer { private[network] override def convertToNetty(): AnyRef = buf - // TODO(rxin): Promote this to top level ManagedBuffer interface and add documentation for it. - def release(): Unit = buf.release() + override def retain(): this.type = { + buf.retain() + this + } + + override def release(): this.type = { + buf.release() + this + } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala index 05443a74094d7..ceae31efac939 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala @@ -40,10 +40,6 @@ import org.apache.spark.util.Utils private[netty] class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) extends Logging { - def this(sparkConf: SparkConf, dataProvider: BlockDataManager) = { - this(new NettyConfig(sparkConf), dataProvider) - } - def port: Int = _port def hostName: String = _hostName @@ -117,7 +113,8 @@ class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) extends Log val addr = channelFuture.channel.localAddress.asInstanceOf[InetSocketAddress] _port = addr.getPort - _hostName = addr.getHostName + //_hostName = addr.getHostName + _hostName = Utils.localHostName() } /** Shutdown the server. 
*/ diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala index 739526a4fc6bc..c3b4d41829f4e 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala @@ -66,9 +66,9 @@ private[netty] class BlockServerHandler(dataProvider: BlockDataManager) logTrace(s"Received request from $client to fetch block $blockId") // First make sure we can find the block. If not, send error back to the user. - var blockData: Option[ManagedBuffer] = null + var buf: ManagedBuffer = null try { - blockData = dataProvider.getBlockData(blockId) + buf = dataProvider.getBlockData(blockId) } catch { case e: Exception => logError(s"Error opening block $blockId for request from $client", e) @@ -76,23 +76,18 @@ private[netty] class BlockServerHandler(dataProvider: BlockDataManager) return } - blockData match { - case Some(buf) => - ctx.writeAndFlush(new BlockFetchSuccess(blockId, buf)).addListener( - new ChannelFutureListener { - override def operationComplete(future: ChannelFuture): Unit = { - if (future.isSuccess) { - logTrace(s"Sent block $blockId (${buf.size} B) back to $client") - } else { - logError( - s"Error sending block $blockId to $client; closing connection", future.cause) - ctx.close() - } - } + ctx.writeAndFlush(new BlockFetchSuccess(blockId, buf)).addListener( + new ChannelFutureListener { + override def operationComplete(future: ChannelFuture): Unit = { + if (future.isSuccess) { + logTrace(s"Sent block $blockId (${buf.size} B) back to $client") + } else { + logError( + s"Error sending block $blockId to $client; closing connection", future.cause) + ctx.close() } - ) - case None => - respondWithError("Block not found") - } + } + } + ) } // end of processBlockRequest } diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index b389b9a2022c6..457ba106ced89 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -197,7 +197,8 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa private def getBlock(blockId: String): ByteBuffer = { val startTimeMs = System.currentTimeMillis() logDebug("GetBlock " + blockId + " started from " + startTimeMs) - val buffer = blockDataManager.getBlockData(blockId).orNull + // TODO(rxin): propagate error back to the client? + val buffer = blockDataManager.getBlockData(blockId) logDebug("GetBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) + " and got buffer " + buffer) if (buffer == null) null else buffer.nioByteBuffer() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 3f5d06e1aeee7..9995a25c224fe 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -216,17 +216,17 @@ private[spark] class BlockManager( * * @return Some(buffer) if the block exists locally, and None if it doesn't. 
*/ - override def getBlockData(blockId: String): Option[ManagedBuffer] = { + override def getBlockData(blockId: String): ManagedBuffer = { val bid = BlockId(blockId) if (bid.isShuffle) { - Some(shuffleManager.shuffleBlockManager.getBlockData(bid.asInstanceOf[ShuffleBlockId])) + shuffleManager.shuffleBlockManager.getBlockData(bid.asInstanceOf[ShuffleBlockId]) } else { val blockBytesOpt = doGetLocal(bid, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] if (blockBytesOpt.isDefined) { val buffer = blockBytesOpt.get - Some(new NioByteBufferManagedBuffer(buffer)) + new NioByteBufferManagedBuffer(buffer) } else { - None + throw new BlockNotFoundException(blockId) } } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 71b276b5f18e4..23f7d56895fe5 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -23,10 +23,10 @@ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashSet import scala.collection.mutable.Queue -import org.apache.spark.{TaskContext, Logging} +import org.apache.spark.{Logging, TaskContext} import org.apache.spark.network.{ManagedBuffer, BlockFetchingListener, BlockTransferService} import org.apache.spark.serializer.Serializer -import org.apache.spark.util.Utils +import org.apache.spark.util.{CompletionIterator, Utils} /** @@ -88,17 +88,49 @@ final class ShuffleBlockFetcherIterator( */ private[this] val results = new LinkedBlockingQueue[FetchResult] - // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that - // the number of bytes in flight is limited to maxBytesInFlight + /** + * Current [[FetchResult]] being processed. We track this so we can release the current buffer + * in case of a runtime exception when processing the current buffer. + */ + private[this] var currentResult: FetchResult = null + + /** + * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that + * the number of bytes in flight is limited to maxBytesInFlight. + */ private[this] val fetchRequests = new Queue[FetchRequest] - // Current bytes in flight from our requests + /** Current bytes in flight from our requests */ private[this] var bytesInFlight = 0L private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() + /** + * Whether the iterator is still active. If isZombie is true, the callback interface will no + * longer place fetched blocks into [[results]]. + */ + @volatile private[this] var isZombie = false + initialize() + /** + * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. 
+ */ + private[this] def cleanup() { + isZombie = true + // Release the current buffer if necessary + if (currentResult != null && currentResult.buf != null) { + currentResult.buf.release() + } + + // Release buffers in the results queue + val iter = results.iterator() + while (iter.hasNext) { + val result = iter.next() + result.buf.release() + } + } + private[this] def sendRequest(req: FetchRequest) { logDebug("Sending request for %d blocks (%s) from %s".format( req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) @@ -110,13 +142,17 @@ final class ShuffleBlockFetcherIterator( blockTransferService.fetchBlocks(req.address.host, req.address.port, blockIds, new BlockFetchingListener { - override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { - results.put(new FetchResult(BlockId(blockId), sizeMap(blockId), - () => serializer.newInstance().deserializeStream( - blockManager.wrapForCompression(BlockId(blockId), data.inputStream())).asIterator - )) - shuffleMetrics.remoteBytesRead += data.size - shuffleMetrics.remoteBlocksFetched += 1 + override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = { + // Only add the buffer to results queue if the iterator is not zombie, + // i.e. cleanup() has not been called yet. + if (!isZombie) { + // Increment the ref count because we need to pass this to a different thread. + // This needs to be released after use. + buf.retain() + results.put(new FetchResult(BlockId(blockId), sizeMap(blockId), buf)) + shuffleMetrics.remoteBytesRead += buf.size + shuffleMetrics.remoteBlocksFetched += 1 + } logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } @@ -138,7 +174,7 @@ final class ShuffleBlockFetcherIterator( // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 // nodes, rather than blocking on reading output from one node. val targetRequestSize = math.max(maxBytesInFlight / 5, 1L) - logInfo("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize) + logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize) // Split local and remote blocks. Remote blocks are further split into FetchRequests of size // at most maxBytesInFlight in order to limit the amount of data in flight. @@ -185,26 +221,34 @@ final class ShuffleBlockFetcherIterator( remoteRequests } + /** + * Fetch the local blocks while we are fetching remote blocks. This is ok because + * [[ManagedBuffer]]'s memory is allocated lazily when we create the input stream, so all we + * track in-memory are the ManagedBuffer references themselves. + */ private[this] def fetchLocalBlocks() { - // Get the local blocks while remote blocks are being fetched. Note that it's okay to do - // these all at once because they will just memory-map some files, so they won't consume - // any memory that might exceed our maxBytesInFlight - for (id <- localBlocks) { + val iter = localBlocks.iterator + while (iter.hasNext) { + val blockId = iter.next() try { + val buf = blockManager.getBlockData(blockId.toString) shuffleMetrics.localBlocksFetched += 1 - results.put(new FetchResult( - id, 0, () => blockManager.getLocalShuffleFromDisk(id, serializer).get)) - logDebug("Got local block " + id) + buf.retain() + results.put(new FetchResult(blockId, 0, buf)) } catch { case e: Exception => + // If we see an exception, stop immediately. 
logError(s"Error occurred while fetching local blocks", e) - results.put(new FetchResult(id, -1, null)) + results.put(new FetchResult(blockId, -1, null)) return } } } private[this] def initialize(): Unit = { + // Add a task completion callback (called in both success case and failure case) to cleanup. + context.addTaskCompletionListener(_ => cleanup()) + // Split local and remote blocks. val remoteRequests = splitLocalRemoteBlocks() // Add the remote requests into our queue in a random order @@ -229,7 +273,8 @@ final class ShuffleBlockFetcherIterator( override def next(): (BlockId, Option[Iterator[Any]]) = { numBlocksProcessed += 1 val startFetchWait = System.currentTimeMillis() - val result = results.take() + currentResult = results.take() + val result = currentResult val stopFetchWait = System.currentTimeMillis() shuffleMetrics.fetchWaitTime += (stopFetchWait - startFetchWait) if (!result.failed) { @@ -240,7 +285,21 @@ final class ShuffleBlockFetcherIterator( (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { sendRequest(fetchRequests.dequeue()) } - (result.blockId, if (result.failed) None else Some(result.deserialize())) + + val iteratorOpt: Option[Iterator[Any]] = if (result.failed) { + None + } else { + val is = blockManager.wrapForCompression(result.blockId, result.buf.inputStream()) + val iter = serializer.newInstance().deserializeStream(is).asIterator + Some(CompletionIterator[Any, Iterator[Any]](iter, { + // Once the iterator is exhausted, release the buffer and set currentResult to null + // so we don't release it again in cleanup. + currentResult = null + result.buf.release() + })) + } + + (result.blockId, iteratorOpt) } } @@ -262,10 +321,10 @@ object ShuffleBlockFetcherIterator { * Result of a fetch from a remote block. A failure is represented as size == -1. * @param blockId block id * @param size estimated size of the block, used to calculate bytesInFlight. - * Note that this is NOT the exact bytes. - * @param deserialize closure to return the result in the form of an Iterator. + * Note that this is NOT the exact bytes. -1 if failure is present. + * @param buf [[ManagedBuffer]] for the content. null is error. 
*/ - class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) { + class FetchResult(val blockId: BlockId, val size: Long, val buf: ManagedBuffer) { def failed: Boolean = size == -1 } } diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala index 178c60a048b9f..72d7c4b531099 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala @@ -30,7 +30,7 @@ import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.apache.spark.SparkConf import org.apache.spark.network._ -import org.apache.spark.storage.StorageLevel +import org.apache.spark.storage.{BlockNotFoundException, StorageLevel} /** @@ -62,14 +62,14 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { fp.write(fileContent) fp.close() - server = new BlockServer(new SparkConf, new BlockDataManager { - override def getBlockData(blockId: String): Option[ManagedBuffer] = { + server = new BlockServer(new NettyConfig(new SparkConf), new BlockDataManager { + override def getBlockData(blockId: String): ManagedBuffer = { if (blockId == bufferBlockId) { - Some(new NioByteBufferManagedBuffer(buf)) + new NioByteBufferManagedBuffer(buf) } else if (blockId == fileBlockId) { - Some(new FileSegmentManagedBuffer(testFile, 10, testFile.length - 25)) + new FileSegmentManagedBuffer(testFile, 10, testFile.length - 25) } else { - None + throw new BlockNotFoundException(blockId) } } diff --git a/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala b/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala index 6ae2d3b3faf91..1d13fd92e1f23 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala @@ -65,4 +65,8 @@ class TestManagedBuffer(len: Int) extends ManagedBuffer { } case _ => false } + + override def retain(): this.type = this + + override def release(): this.type = this } From b5c8d1fca6d3cf5c2b95395310200c8149a7eb16 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 10 Sep 2014 01:09:44 -0700 Subject: [PATCH 05/46] Fixed ShuffleBlockFetcherIteratorSuite. --- .../apache/spark/storage/BlockManager.scala | 5 +- .../netty/ServerClientIntegrationSuite.scala | 30 ++-- .../ShuffleBlockFetcherIteratorSuite.scala | 134 ++++-------------- 3 files changed, 47 insertions(+), 122 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 9995a25c224fe..b803d70afe9d9 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -212,9 +212,8 @@ private[spark] class BlockManager( } /** - * Interface to get local block data. - * - * @return Some(buffer) if the block exists locally, and None if it doesn't. + * Interface to get local block data. Throws an exception if the block cannot be found or + * cannot be read successfully. 
*/ override def getBlockData(blockId: String): ManagedBuffer = { val bid = BlockId(blockId) diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala index 72d7c4b531099..3dacc0fb69be7 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala @@ -1,19 +1,19 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ package org.apache.spark.network.netty diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 809bd70929656..d4c5e7bc39b88 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -17,9 +17,6 @@ package org.apache.spark.storage -import org.apache.spark.TaskContext -import org.apache.spark.network.{BlockFetchingListener, BlockTransferService} - import org.mockito.Mockito._ import org.mockito.Matchers.{any, eq => meq} import org.mockito.invocation.InvocationOnMock @@ -27,126 +24,55 @@ import org.mockito.stubbing.Answer import org.scalatest.FunSuite +import org.apache.spark.{SparkConf, TaskContext} +import org.apache.spark.network._ +import org.apache.spark.serializer.TestSerializer -class ShuffleBlockFetcherIteratorSuite extends FunSuite { - - test("handle local read failures in BlockManager") { - val transfer = mock(classOf[BlockTransferService]) - val blockManager = mock(classOf[BlockManager]) - doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId - val blIds = Array[BlockId]( - ShuffleBlockId(0,0,0), - ShuffleBlockId(0,1,0), - ShuffleBlockId(0,2,0), - ShuffleBlockId(0,3,0), - ShuffleBlockId(0,4,0)) - - val optItr = mock(classOf[Option[Iterator[Any]]]) - val answer = new Answer[Option[Iterator[Any]]] { - override def answer(invocation: InvocationOnMock) = Option[Iterator[Any]] { - throw new Exception - } - } - - // 3rd block is going to fail - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(0)), any()) - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(1)), any()) - doAnswer(answer).when(blockManager).getLocalShuffleFromDisk(meq(blIds(2)), any()) - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(3)), any()) - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(4)), any()) - - val bmId = BlockManagerId("test-client", "test-client", 1) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq) - ) - - val iterator = new ShuffleBlockFetcherIterator( - new TaskContext(0, 0, 0), - transfer, - blockManager, - blocksByAddress, - null, - 48 * 1024 * 1024) +class ShuffleBlockFetcherIteratorSuite extends FunSuite { - // Without exhausting the iterator, the iterator should be lazy and not call - // getLocalShuffleFromDisk. 
- verify(blockManager, times(0)).getLocalShuffleFromDisk(any(), any()) - - assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements") - // the 2nd element of the tuple returned by iterator.next should be defined when - // fetching successfully - assert(iterator.next()._2.isDefined, - "1st element should be defined but is not actually defined") - verify(blockManager, times(1)).getLocalShuffleFromDisk(any(), any()) - - assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element") - assert(iterator.next()._2.isDefined, - "2nd element should be defined but is not actually defined") - verify(blockManager, times(2)).getLocalShuffleFromDisk(any(), any()) - - assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements") - // 3rd fetch should be failed - intercept[Exception] { - iterator.next() - } - verify(blockManager, times(3)).getLocalShuffleFromDisk(any(), any()) - } + val conf = new SparkConf - test("handle local read successes") { - val transfer = mock(classOf[BlockTransferService]) + test("handle successful local reads") { + val buf = mock(classOf[ManagedBuffer]) val blockManager = mock(classOf[BlockManager]) doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId - val blIds = Array[BlockId]( - ShuffleBlockId(0,0,0), - ShuffleBlockId(0,1,0), - ShuffleBlockId(0,2,0), - ShuffleBlockId(0,3,0), - ShuffleBlockId(0,4,0)) - - val optItr = mock(classOf[Option[Iterator[Any]]]) + val blockIds = Array[BlockId]( + ShuffleBlockId(0, 0, 0), + ShuffleBlockId(0, 1, 0), + ShuffleBlockId(0, 2, 0), + ShuffleBlockId(0, 3, 0), + ShuffleBlockId(0, 4, 0)) // All blocks should be fetched successfully - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(0)), any()) - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(1)), any()) - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(2)), any()) - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(3)), any()) - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(4)), any()) + blockIds.foreach { blockId => + doReturn(buf).when(blockManager).getBlockData(meq(blockId.toString)) + } val bmId = BlockManagerId("test-client", "test-client", 1) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq) + (bmId, blockIds.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq) ) val iterator = new ShuffleBlockFetcherIterator( new TaskContext(0, 0, 0), - transfer, + mock(classOf[BlockTransferService]), blockManager, blocksByAddress, - null, + new TestSerializer, 48 * 1024 * 1024) - // Without exhausting the iterator, the iterator should be lazy and not call getLocalShuffleFromDisk. 
- verify(blockManager, times(0)).getLocalShuffleFromDisk(any(), any()) - - assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements") - assert(iterator.next()._2.isDefined, - "All elements should be defined but 1st element is not actually defined") - assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element") - assert(iterator.next()._2.isDefined, - "All elements should be defined but 2nd element is not actually defined") - assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements") - assert(iterator.next()._2.isDefined, - "All elements should be defined but 3rd element is not actually defined") - assert(iterator.hasNext, "iterator should have 5 elements but actually has 3 elements") - assert(iterator.next()._2.isDefined, - "All elements should be defined but 4th element is not actually defined") - assert(iterator.hasNext, "iterator should have 5 elements but actually has 4 elements") - assert(iterator.next()._2.isDefined, - "All elements should be defined but 5th element is not actually defined") - - verify(blockManager, times(5)).getLocalShuffleFromDisk(any(), any()) + // Local blocks are fetched immediately. + verify(blockManager, times(5)).getBlockData(any()) + + for (i <- 0 until 5) { + assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements") + assert(iterator.next()._2.isDefined, + s"iterator should have 5 elements defined but actually has $i elements") + } + // No more fetching of local blocks. + verify(blockManager, times(5)).getBlockData(any()) } test("handle remote fetch failures in BlockTransferService") { @@ -173,7 +99,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { transfer, blockManager, blocksByAddress, - null, + new TestSerializer, 48 * 1024 * 1024) iterator.foreach { case (_, iterOption) => From 108c9edaed06c5e046a21c9a8e54c50390da9a0b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 10 Sep 2014 01:10:04 -0700 Subject: [PATCH 06/46] Forgot to add TestSerializer to the commit list. --- .../spark/serializer/TestSerializer.scala | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala diff --git a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala new file mode 100644 index 0000000000000..0ade1bab18d7e --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.serializer + +import java.io.{EOFException, OutputStream, InputStream} +import java.nio.ByteBuffer + +import scala.reflect.ClassTag + + +/** + * A serializer implementation that always return a single element in a deserialization stream. + */ +class TestSerializer extends Serializer { + override def newInstance() = new TestSerializerInstance +} + + +class TestSerializerInstance extends SerializerInstance { + override def serialize[T: ClassTag](t: T): ByteBuffer = ??? + + override def serializeStream(s: OutputStream): SerializationStream = ??? + + override def deserializeStream(s: InputStream) = new TestDeserializationStream + + override def deserialize[T: ClassTag](bytes: ByteBuffer): T = ??? + + override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = ??? +} + + +class TestDeserializationStream extends DeserializationStream { + + private var count = 0 + + override def readObject[T: ClassTag](): T = { + count += 1 + if (count == 2) { + throw new EOFException + } + new Object().asInstanceOf[T] + } + + override def close(): Unit = {} +} From 1be4e8ee7d932821c789cb974310e5d59df4ff84 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 10 Sep 2014 01:11:40 -0700 Subject: [PATCH 07/46] Shorten NioManagedBuffer and NettyManagedBuffer class names. --- .../scala/org/apache/spark/network/ManagedBuffer.scala | 10 +++++----- .../org/apache/spark/network/netty/protocol.scala | 6 +++--- .../spark/network/nio/NioBlockTransferService.scala | 4 ++-- .../scala/org/apache/spark/storage/BlockManager.scala | 4 ++-- .../spark/network/netty/BlockClientHandlerSuite.scala | 4 ++-- .../network/netty/ServerClientIntegrationSuite.scala | 2 +- .../apache/spark/network/netty/TestManagedBuffer.scala | 4 ++-- 7 files changed, 17 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala index 7f364947dd930..86a0d653b341b 100644 --- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala @@ -36,11 +36,11 @@ import org.apache.spark.util.{ByteBufferInputStream, Utils} * should specify how the data is provided: * * - [[FileSegmentManagedBuffer]]: data backed by part of a file - * - [[NioByteBufferManagedBuffer]]: data backed by a NIO ByteBuffer - * - [[NettyByteBufManagedBuffer]]: data backed by a Netty ByteBuf + * - [[NioManagedBuffer]]: data backed by a NIO ByteBuffer + * - [[NettyManagedBuffer]]: data backed by a Netty ByteBuf * * The concrete buffer implementation might be managed outside the JVM garbage collector. - * For example, in the case of [[NettyByteBufManagedBuffer]], the buffers are reference counted. + * For example, in the case of [[NettyManagedBuffer]], the buffers are reference counted. * In that case, if the buffer is going to be passed around to a different thread, retain/release * should be called. */ @@ -149,7 +149,7 @@ final class FileSegmentManagedBuffer(val file: File, val offset: Long, val lengt /** * A [[ManagedBuffer]] backed by [[java.nio.ByteBuffer]]. */ -final class NioByteBufferManagedBuffer(buf: ByteBuffer) extends ManagedBuffer { +final class NioManagedBuffer(buf: ByteBuffer) extends ManagedBuffer { override def size: Long = buf.remaining() @@ -168,7 +168,7 @@ final class NioByteBufferManagedBuffer(buf: ByteBuffer) extends ManagedBuffer { /** * A [[ManagedBuffer]] backed by a Netty [[ByteBuf]]. 
*/ -final class NettyByteBufManagedBuffer(buf: ByteBuf) extends ManagedBuffer { +final class NettyManagedBuffer(buf: ByteBuf) extends ManagedBuffer { override def size: Long = buf.readableBytes() diff --git a/core/src/main/scala/org/apache/spark/network/netty/protocol.scala b/core/src/main/scala/org/apache/spark/network/netty/protocol.scala index ac6a4d00f654f..ac9d2097c93e2 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/protocol.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/protocol.scala @@ -25,7 +25,7 @@ import io.netty.channel.ChannelHandler.Sharable import io.netty.handler.codec._ import org.apache.spark.Logging -import org.apache.spark.network.{NettyByteBufManagedBuffer, ManagedBuffer} +import org.apache.spark.network.{NettyManagedBuffer, ManagedBuffer} /** Messages from the client to the server. */ @@ -141,7 +141,7 @@ final class ClientRequestDecoder extends MessageToMessageDecoder[ByteBuf] { case 1 => // BlockUploadRequest val blockId = ProtocolUtils.readBlockId(in) in.retain() // retain the bytebuf so we don't recycle it immediately. - BlockUploadRequest(blockId, new NettyByteBufManagedBuffer(in)) + BlockUploadRequest(blockId, new NettyManagedBuffer(in)) } assert(decoded.id == msgTypeId) @@ -218,7 +218,7 @@ final class ServerResponseDecoder extends MessageToMessageDecoder[ByteBuf] { case 0 => // BlockFetchSuccess val blockId = ProtocolUtils.readBlockId(in) in.retain() - new BlockFetchSuccess(blockId, new NettyByteBufManagedBuffer(in)) + new BlockFetchSuccess(blockId, new NettyManagedBuffer(in)) case 1 => // BlockFetchFailure val blockId = ProtocolUtils.readBlockId(in) diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index 457ba106ced89..91ebb3fe0e0f3 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -104,7 +104,7 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa val blockId = blockMessage.getId val networkSize = blockMessage.getData.limit() listener.onBlockFetchSuccess( - blockId.toString, new NioByteBufferManagedBuffer(blockMessage.getData)) + blockId.toString, new NioManagedBuffer(blockMessage.getData)) } } }(cm.futureExecContext) @@ -189,7 +189,7 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa private def putBlock(blockId: String, bytes: ByteBuffer, level: StorageLevel) { val startTimeMs = System.currentTimeMillis() logDebug("PutBlock " + blockId + " started from " + startTimeMs + " with data: " + bytes) - blockDataManager.putBlockData(blockId, new NioByteBufferManagedBuffer(bytes), level) + blockDataManager.putBlockData(blockId, new NioManagedBuffer(bytes), level) logDebug("PutBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) + " with data size: " + bytes.limit) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index b803d70afe9d9..663c5e83ac4c9 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -223,7 +223,7 @@ private[spark] class BlockManager( val blockBytesOpt = doGetLocal(bid, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] if (blockBytesOpt.isDefined) { val buffer = blockBytesOpt.get - new 
NioByteBufferManagedBuffer(buffer) + new NioManagedBuffer(buffer) } else { throw new BlockNotFoundException(blockId) } @@ -868,7 +868,7 @@ private[spark] class BlockManager( data.rewind() logTrace(s"Trying to replicate $blockId of ${data.limit()} bytes to $peer") blockTransferService.uploadBlockSync( - peer.host, peer.port, blockId.toString, new NioByteBufferManagedBuffer(data), tLevel) + peer.host, peer.port, blockId.toString, new NioManagedBuffer(data), tLevel) logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %f ms" .format((System.currentTimeMillis - onePeerStartTime))) peersReplicatedTo += peer diff --git a/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala index 1358b2f9c8071..f2ed404ed8d4c 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala @@ -64,7 +64,7 @@ class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { buf.put(blockData.getBytes) buf.flip() - channel.writeInbound(BlockFetchSuccess(blockId, new NioByteBufferManagedBuffer(buf))) + channel.writeInbound(BlockFetchSuccess(blockId, new NioManagedBuffer(buf))) assert(parsedBlockId === blockId) assert(parsedBlockData === blockData) @@ -119,7 +119,7 @@ class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { assert(sizeOfOutstandingRequests(handler) === 3) val channel = new EmbeddedChannel(handler) - channel.writeInbound(BlockFetchSuccess("b1", new NettyByteBufManagedBuffer(Unpooled.buffer()))) + channel.writeInbound(BlockFetchSuccess("b1", new NettyManagedBuffer(Unpooled.buffer()))) // Need to figure out a way to generate an exception assert(successCount.get() === 1) assert(errorCount.get() === 2) diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala index 3dacc0fb69be7..fa3512768d9a8 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala @@ -65,7 +65,7 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { server = new BlockServer(new NettyConfig(new SparkConf), new BlockDataManager { override def getBlockData(blockId: String): ManagedBuffer = { if (blockId == bufferBlockId) { - new NioByteBufferManagedBuffer(buf) + new NioManagedBuffer(buf) } else if (blockId == fileBlockId) { new FileSegmentManagedBuffer(testFile, 10, testFile.length - 25) } else { diff --git a/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala b/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala index 1d13fd92e1f23..e47e4d03fa898 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import io.netty.buffer.Unpooled -import org.apache.spark.network.{NettyByteBufManagedBuffer, ManagedBuffer} +import org.apache.spark.network.{NettyManagedBuffer, ManagedBuffer} /** @@ -36,7 +36,7 @@ class TestManagedBuffer(len: Int) extends ManagedBuffer { private val byteArray: Array[Byte] = Array.tabulate[Byte](len)(_.toByte) - private val underlying = new 
NettyByteBufManagedBuffer(Unpooled.wrappedBuffer(byteArray)) + private val underlying = new NettyManagedBuffer(Unpooled.wrappedBuffer(byteArray)) override def size: Long = underlying.size From cb589ec7b6d3758498249b63b395634efb83d8ba Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 10 Sep 2014 19:01:23 -0700 Subject: [PATCH 08/46] Added more test cases covering cleanup when fault happens in ShuffleBlockFetcherIteratorSuite --- .../storage/ShuffleBlockFetcherIterator.scala | 11 +- .../ShuffleBlockFetcherIteratorSuite.scala | 189 +++++++++++++++--- 2 files changed, 164 insertions(+), 36 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 23f7d56895fe5..61fe0aaf0d44f 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -119,7 +119,7 @@ final class ShuffleBlockFetcherIterator( private[this] def cleanup() { isZombie = true // Release the current buffer if necessary - if (currentResult != null && currentResult.buf != null) { + if (currentResult != null && !currentResult.failed) { currentResult.buf.release() } @@ -127,7 +127,9 @@ final class ShuffleBlockFetcherIterator( val iter = results.iterator() while (iter.hasNext) { val result = iter.next() - result.buf.release() + if (!result.failed) { + result.buf.release() + } } } @@ -313,7 +315,7 @@ object ShuffleBlockFetcherIterator { * @param blocks Sequence of tuple, where the first element is the block id, * and the second element is the estimated size, used to calculate bytesInFlight. */ - class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) { + case class FetchRequest(address: BlockManagerId, blocks: Seq[(BlockId, Long)]) { val size = blocks.map(_._2).sum } @@ -324,7 +326,8 @@ object ShuffleBlockFetcherIterator { * Note that this is NOT the exact bytes. -1 if failure is present. * @param buf [[ManagedBuffer]] for the content. null is error. */ - class FetchResult(val blockId: BlockId, val size: Long, val buf: ManagedBuffer) { + case class FetchResult(blockId: BlockId, size: Long, buf: ManagedBuffer) { def failed: Boolean = size == -1 + if (failed) assert(buf == null) else assert(buf != null) } } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index d4c5e7bc39b88..b4700f38a6781 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -17,6 +17,11 @@ package org.apache.spark.storage +import java.util.concurrent.Semaphore + +import scala.concurrent.future +import scala.concurrent.ExecutionContext.Implicits.global + import org.mockito.Mockito._ import org.mockito.Matchers.{any, eq => meq} import org.mockito.invocation.InvocationOnMock @@ -30,80 +35,200 @@ import org.apache.spark.serializer.TestSerializer class ShuffleBlockFetcherIteratorSuite extends FunSuite { + // Some of the tests are quite tricky because we are testing the cleanup behavior + // in the presence of faults. - val conf = new SparkConf + /** Creates a mock [[BlockTransferService]] that returns data from the given map. 
*/ + private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = { + val transfer = mock(classOf[BlockTransferService]) + when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + override def answer(invocation: InvocationOnMock): Unit = { + val blocks = invocation.getArguments()(2).asInstanceOf[Seq[String]] + val listener = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener] - test("handle successful local reads") { - val buf = mock(classOf[ManagedBuffer]) - val blockManager = mock(classOf[BlockManager]) - doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId + for (blockId <- blocks) { + if (data.contains(BlockId(blockId))) { + listener.onBlockFetchSuccess(blockId, data(BlockId(blockId))) + } else { + listener.onBlockFetchFailure(new BlockNotFoundException(blockId)) + } + } + } + }) + transfer + } - val blockIds = Array[BlockId]( - ShuffleBlockId(0, 0, 0), - ShuffleBlockId(0, 1, 0), - ShuffleBlockId(0, 2, 0), - ShuffleBlockId(0, 3, 0), - ShuffleBlockId(0, 4, 0)) + private val conf = new SparkConf - // All blocks should be fetched successfully - blockIds.foreach { blockId => + test("successful 3 local reads + 2 remote reads") { + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + + // Make sure blockManager.getBlockData would return the blocks + val localBlocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]), + ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), + ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer])) + localBlocks.foreach { case (blockId, buf) => doReturn(buf).when(blockManager).getBlockData(meq(blockId.toString)) } - val bmId = BlockManagerId("test-client", "test-client", 1) + // Make sure remote blocks would return + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val remoteBlocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 3, 0) -> mock(classOf[ManagedBuffer]), + ShuffleBlockId(0, 4, 0) -> mock(classOf[ManagedBuffer]) + ) + + val transfer = createMockTransfer(remoteBlocks) + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (bmId, blockIds.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq) + (localBmId, localBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq), + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq) ) val iterator = new ShuffleBlockFetcherIterator( new TaskContext(0, 0, 0), - mock(classOf[BlockTransferService]), + transfer, blockManager, blocksByAddress, new TestSerializer, 48 * 1024 * 1024) - // Local blocks are fetched immediately. - verify(blockManager, times(5)).getBlockData(any()) + // 3 local blocks fetched in initialization + verify(blockManager, times(3)).getBlockData(any()) for (i <- 0 until 5) { assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements") - assert(iterator.next()._2.isDefined, + val (blockId, subIterator) = iterator.next() + assert(subIterator.isDefined, s"iterator should have 5 elements defined but actually has $i elements") + + // Make sure we release the buffer once the iterator is exhausted. 
+ val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId)) + verify(mockBuf, times(0)).release() + subIterator.get.foreach(_ => Unit) // exhaust the iterator + verify(mockBuf, times(1)).release() } - // No more fetching of local blocks. - verify(blockManager, times(5)).getBlockData(any()) + + // 3 local blocks, and 2 remote blocks + // (but from the same block manager so one call to fetchBlocks) + verify(blockManager, times(3)).getBlockData(any()) + verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any()) } - test("handle remote fetch failures in BlockTransferService") { + test("release current unexhausted buffer in case the task completes early") { + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + + // Make sure remote blocks would return + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val blocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]), + ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), + ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer]) + ) + + // Semaphore to coordinate event sequence in two different threads. + val sem = new Semaphore(0) + val transfer = mock(classOf[BlockTransferService]) when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener] - listener.onBlockFetchFailure(new Exception("blah")) + future { + // Return the first two blocks, and wait till task completion before returning the 3rd one + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1, 0))) + sem.acquire() + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 0))) + } } }) + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + + val taskContext = new TaskContext(0, 0, 0) + val iterator = new ShuffleBlockFetcherIterator( + taskContext, + transfer, + blockManager, + blocksByAddress, + new TestSerializer, + 48 * 1024 * 1024) + + // Exhaust the first block, and then it should be released. 
+ iterator.next()._2.get.foreach(_ => Unit) + verify(blocks(ShuffleBlockId(0, 0, 0)), times(1)).release() + + // Get the 2nd block but do not exhaust the iterator + val subIter = iterator.next()._2.get + + // Complete the task; then the 2nd block buffer should be exhausted + verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release() + taskContext.markTaskCompleted() + verify(blocks(ShuffleBlockId(0, 1, 0)), times(1)).release() + + // The 3rd block should not be retained because the iterator is already in zombie state + sem.release() + verify(blocks(ShuffleBlockId(0, 2, 0)), times(0)).retain() + verify(blocks(ShuffleBlockId(0, 2, 0)), times(0)).release() + } + + test("fail all blocks if any of the remote request fails") { val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + + // Make sure remote blocks would return + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val blocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]), + ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), + ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer]) + ) - when(blockManager.blockManagerId).thenReturn(BlockManagerId("test-client", "test-client", 1)) + // Semaphore to coordinate event sequence in two different threads. + val sem = new Semaphore(0) + + val transfer = mock(classOf[BlockTransferService]) + when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + override def answer(invocation: InvocationOnMock): Unit = { + val listener = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener] + future { + // Return the first block, and then fail. + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) + listener.onBlockFetchFailure(new BlockNotFoundException("blah")) + sem.release() + } + } + }) - val blId1 = ShuffleBlockId(0, 0, 0) - val blId2 = ShuffleBlockId(0, 1, 0) - val bmId = BlockManagerId("test-server", "test-server", 1) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (bmId, Seq((blId1, 1L), (blId2, 1L)))) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + val taskContext = new TaskContext(0, 0, 0) val iterator = new ShuffleBlockFetcherIterator( - new TaskContext(0, 0, 0), + taskContext, transfer, blockManager, blocksByAddress, new TestSerializer, 48 * 1024 * 1024) - iterator.foreach { case (_, iterOption) => - assert(!iterOption.isDefined) - } + // Continue only after the mock calls onBlockFetchFailure + sem.acquire() + + // The first block should be defined, and the last two are not defined (due to failure) + assert(iterator.next()._2.isDefined === true) + assert(iterator.next()._2.isDefined === false) + assert(iterator.next()._2.isDefined === false) } } From 5cd33d7798ae742e76107bb976d8478ab9476ae7 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 10 Sep 2014 19:55:54 -0700 Subject: [PATCH 09/46] Fixed style violation. 
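For context on the retain()/release() contract that the preceding test changes exercise: a ManagedBuffer handed across threads is retained before it is queued and released exactly once after use (or during cleanup). Below is a minimal, self-contained Scala sketch of that reference-counting idea; RefCountedBuffer and RefCountDemo are illustrative names only and do not appear anywhere in this patch series.

import java.util.concurrent.atomic.AtomicInteger

class RefCountedBuffer {
  private val refCount = new AtomicInteger(1)

  // Increment the reference count before handing the buffer to another owner.
  def retain(): this.type = {
    refCount.incrementAndGet()
    this
  }

  // Decrement the reference count; free the underlying storage when it reaches zero.
  def release(): this.type = {
    if (refCount.decrementAndGet() == 0) {
      // The real deallocation (e.g. releasing a Netty ByteBuf) would happen here.
    }
    this
  }
}

object RefCountDemo extends App {
  val buf = new RefCountedBuffer            // one reference held by the producer
  buf.retain()                              // +1 before passing the buffer to a consumer thread
  val consumer = new Thread(new Runnable {
    override def run(): Unit = buf.release()   // consumer releases after it is done
  })
  consumer.start()
  consumer.join()
  buf.release()                             // producer drops its reference; count reaches zero
}

In the patches themselves, NettyManagedBuffer delegates retain/release to Netty's ByteBuf reference counting, while the file-segment and NIO-backed buffers implement them as no-ops, since their contents are either not held in memory or are managed by the JVM garbage collector.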
--- .../main/scala/org/apache/spark/network/netty/BlockServer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala index ceae31efac939..d95ab8dd8496d 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala @@ -113,7 +113,7 @@ class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) extends Log val addr = channelFuture.channel.localAddress.asInstanceOf[InetSocketAddress] _port = addr.getPort - //_hostName = addr.getHostName + // _hostName = addr.getHostName _hostName = Utils.localHostName() } From 9e0cb8736be6d38e3f30766271d28875ceca1ae8 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 10 Sep 2014 21:04:56 -0700 Subject: [PATCH 10/46] Fixed BlockClientHandlerSuite --- .../spark/network/netty/BlockClientHandlerSuite.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala index f2ed404ed8d4c..7ed3dc915bb7c 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala @@ -31,8 +31,9 @@ import org.apache.spark.network._ class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { private def sizeOfOutstandingRequests(handler: BlockClientHandler): Int = { - val outstandingRequests = PrivateMethod[java.util.Map[_, _]]('outstandingRequests) - handler.invokePrivate(outstandingRequests()).size + val f = handler.getClass.getDeclaredField("outstandingRequests") + f.setAccessible(true) + f.get(handler).asInstanceOf[java.util.Map[_, _]].size } test("handling block data (successful fetch)") { @@ -56,8 +57,7 @@ class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { } ) - val outstandingRequests = PrivateMethod[java.util.Map[_, _]]('outstandingRequests) - assert(handler.invokePrivate(outstandingRequests()).size === 1) + assert(sizeOfOutstandingRequests(handler) === 1) val channel = new EmbeddedChannel(handler) val buf = ByteBuffer.allocate(blockData.size) // 4 bytes for the length field itself @@ -68,7 +68,7 @@ class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { assert(parsedBlockId === blockId) assert(parsedBlockData === blockData) - assert(handler.invokePrivate(outstandingRequests()).size === 0) + assert(sizeOfOutstandingRequests(handler) === 0) assert(channel.finish() === false) } From d23ed7bfd912770ace7eed7cd0dff2db6ac826e3 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 11 Sep 2014 18:28:45 -0700 Subject: [PATCH 11/46] Incorporated feedback from Norman: - use same pool for boss and worker - remove ioratio - disable caching of byte buf allocator - childoption sendbuf/receivebuf - fire exception through pipeline In addition: - fire failure handler BlockFetchingListener at least once per block. 
- enabled a bunch of ignored tests --- .../spark/network/BlockFetchingListener.scala | 4 +- .../spark/network/BlockTransferService.scala | 2 +- .../spark/network/netty/BlockClient.scala | 2 +- .../network/netty/BlockClientFactory.scala | 30 +++++- .../network/netty/BlockClientHandler.scala | 47 +++++++--- .../spark/network/netty/BlockServer.scala | 21 ++--- .../network/nio/NioBlockTransferService.scala | 12 ++- .../storage/ShuffleBlockFetcherIterator.scala | 9 +- .../netty/BlockClientHandlerSuite.scala | 94 ++++++++----------- .../netty/ServerClientIntegrationSuite.scala | 7 +- .../ShuffleBlockFetcherIteratorSuite.scala | 5 +- 11 files changed, 129 insertions(+), 104 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala index 34acaa563ca58..83fe497ad7448 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala @@ -31,7 +31,7 @@ trait BlockFetchingListener extends EventListener { def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit /** - * Called upon failures. For each failure, this is called only once (i.e. not once per block). + * Called at least once per block upon failures. */ - def onBlockFetchFailure(exception: Throwable): Unit + def onBlockFetchFailure(blockId: String, exception: Throwable): Unit } diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index 84d991fa6808c..4833b8a6abf32 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -83,7 +83,7 @@ abstract class BlockTransferService { val lock = new Object @volatile var result: Either[ManagedBuffer, Throwable] = null fetchBlocks(hostName, port, Seq(blockId), new BlockFetchingListener { - override def onBlockFetchFailure(exception: Throwable): Unit = { + override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = { lock.synchronized { result = Right(exception) lock.notify() diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala index 9333fefa92957..6f67187adcb37 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala @@ -77,8 +77,8 @@ class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Loggin logError(errorMsg, future.cause) blockIds.foreach { blockId => handler.removeRequest(blockId) + listener.onBlockFetchFailure(blockId, new RuntimeException(errorMsg)) } - listener.onBlockFetchFailure(new RuntimeException(errorMsg)) } } }) diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala index f05f1419ded14..1414d0966e3dd 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala @@ -28,6 +28,7 @@ import io.netty.channel.oio.OioEventLoopGroup import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.nio.NioSocketChannel import io.netty.channel.socket.oio.OioSocketChannel +import 
io.netty.util.internal.PlatformDependent import org.apache.spark.SparkConf import org.apache.spark.util.Utils @@ -92,13 +93,14 @@ class BlockClientFactory(val conf: NettyConfig) { val bootstrap = new Bootstrap bootstrap.group(workerGroup) .channel(socketChannelClass) - // Use pooled buffers to reduce temporary buffer allocation - .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) // Disable Nagle's Algorithm since we don't want packets to wait .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE) .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE) .option[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectTimeoutMs) + // Use pooled buffers to reduce temporary buffer allocation + bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator()) + bootstrap.handler(new ChannelInitializer[SocketChannel] { override def initChannel(ch: SocketChannel): Unit = { ch.pipeline @@ -124,4 +126,28 @@ class BlockClientFactory(val conf: NettyConfig) { workerGroup.shutdownGracefully() } } + + /** + * Create a pooled ByteBuf allocator but disables the thread-local cache. Thread-local caches + * are disabled because the ByteBufs are allocated by the event loop thread, but released by the + * executor thread rather than the event loop thread. Those thread-local caches actually delay + * the recycling of buffers, leading to larger memory usage. + */ + private def createPooledByteBufAllocator(): PooledByteBufAllocator = { + def getPrivateStaticField(name: String): Int = { + val f = PooledByteBufAllocator.DEFAULT.getClass.getDeclaredField(name) + f.setAccessible(true) + f.getInt(null) + } + new PooledByteBufAllocator( + PlatformDependent.directBufferPreferred(), + getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"), + getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"), + getPrivateStaticField("DEFAULT_PAGE_SIZE"), + getPrivateStaticField("DEFAULT_MAX_ORDER"), + 0, // tinyCacheSize + 0, // smallCacheSize + 0 // normalCacheSize + ) + } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala index 2a474cd71eab8..1a74c6649f28a 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala @@ -17,6 +17,8 @@ package org.apache.spark.network.netty +import java.util.concurrent.ConcurrentHashMap + import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} import org.apache.spark.Logging @@ -33,9 +35,8 @@ private[netty] class BlockClientHandler extends SimpleChannelInboundHandler[ServerResponse] with Logging { /** Tracks the list of outstanding requests and their listeners on success/failure. 
*/ - private[this] val outstandingRequests = java.util.Collections.synchronizedMap { - new java.util.HashMap[String, BlockFetchingListener] - } + private[this] val outstandingRequests: java.util.Map[String, BlockFetchingListener] = + new ConcurrentHashMap[String, BlockFetchingListener] def addRequest(blockId: String, listener: BlockFetchingListener): Unit = { outstandingRequests.put(blockId, listener) @@ -45,20 +46,36 @@ class BlockClientHandler extends SimpleChannelInboundHandler[ServerResponse] wit outstandingRequests.remove(blockId) } - override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - val errorMsg = s"Exception in connection from ${ctx.channel.remoteAddress}: ${cause.getMessage}" - logError(errorMsg, cause) + /** + * Fire the failure callback for all outstanding requests. This is called when we have an + * uncaught exception or pre-mature connection termination. + */ + private def failOutstandingRequests(cause: Throwable): Unit = { + val iter = outstandingRequests.entrySet().iterator() + while (iter.hasNext) { + val entry = iter.next() + entry.getValue.onBlockFetchFailure(entry.getKey, cause) + } + // TODO(rxin): Maybe we need to synchronize the access? Otherwise we could clear new requests + // as well. But I guess that is ok given the caller will fail as soon as any requests fail. + outstandingRequests.clear() + } - // Fire the failure callback for all outstanding blocks - outstandingRequests.synchronized { - val iter = outstandingRequests.entrySet().iterator() - while (iter.hasNext) { - val entry = iter.next() - entry.getValue.onBlockFetchFailure(cause) - } - outstandingRequests.clear() + override def channelUnregistered(ctx: ChannelHandlerContext): Unit = { + if (outstandingRequests.size() > 0) { + logError("Still have " + outstandingRequests.size() + " requests outstanding " + + s"when connection from ${ctx.channel.remoteAddress} is closed") + failOutstandingRequests(new RuntimeException( + s"Connection from ${ctx.channel.remoteAddress} closed")) } + } + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + if (outstandingRequests.size() > 0) { + logError( + s"Exception in connection from ${ctx.channel.remoteAddress}: ${cause.getMessage}", cause) + failOutstandingRequests(cause) + } ctx.close() } @@ -80,7 +97,7 @@ class BlockClientHandler extends SimpleChannelInboundHandler[ServerResponse] wit s"Got a response for block $blockId from $server ($errorMsg) but it is not outstanding") } else { outstandingRequests.remove(blockId) - listener.onBlockFetchFailure(new RuntimeException(errorMsg)) + listener.onBlockFetchFailure(blockId, new RuntimeException(errorMsg)) } } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala index d95ab8dd8496d..bd28d48c1a5ed 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala @@ -54,25 +54,22 @@ class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) extends Log /** Initialize the server. */ private def init(): Unit = { bootstrap = new ServerBootstrap - val bossThreadFactory = Utils.namedThreadFactory("spark-netty-server-boss") - val workerThreadFactory = Utils.namedThreadFactory("spark-netty-server-worker") + val threadFactory = Utils.namedThreadFactory("spark-netty-server") // Use only one thread to accept connections, and 2 * num_cores for worker. 
def initNio(): Unit = { - val bossGroup = new NioEventLoopGroup(1, bossThreadFactory) - val workerGroup = new NioEventLoopGroup(0, workerThreadFactory) - workerGroup.setIoRatio(conf.ioRatio) + val bossGroup = new NioEventLoopGroup(0, threadFactory) + val workerGroup = bossGroup bootstrap.group(bossGroup, workerGroup).channel(classOf[NioServerSocketChannel]) } def initOio(): Unit = { - val bossGroup = new OioEventLoopGroup(1, bossThreadFactory) - val workerGroup = new OioEventLoopGroup(0, workerThreadFactory) + val bossGroup = new OioEventLoopGroup(0, threadFactory) + val workerGroup = bossGroup bootstrap.group(bossGroup, workerGroup).channel(classOf[OioServerSocketChannel]) } def initEpoll(): Unit = { - val bossGroup = new EpollEventLoopGroup(1, bossThreadFactory) - val workerGroup = new EpollEventLoopGroup(0, workerThreadFactory) - workerGroup.setIoRatio(conf.ioRatio) + val bossGroup = new EpollEventLoopGroup(0, threadFactory) + val workerGroup = bossGroup bootstrap.group(bossGroup, workerGroup).channel(classOf[EpollServerSocketChannel]) } @@ -92,10 +89,10 @@ class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) extends Log bootstrap.option[java.lang.Integer](ChannelOption.SO_BACKLOG, backLog) } conf.receiveBuf.foreach { receiveBuf => - bootstrap.option[java.lang.Integer](ChannelOption.SO_RCVBUF, receiveBuf) + bootstrap.childOption[java.lang.Integer](ChannelOption.SO_RCVBUF, receiveBuf) } conf.sendBuf.foreach { sendBuf => - bootstrap.option[java.lang.Integer](ChannelOption.SO_SNDBUF, sendBuf) + bootstrap.childOption[java.lang.Integer](ChannelOption.SO_SNDBUF, sendBuf) } bootstrap.childHandler(new ChannelInitializer[SocketChannel] { diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index 91ebb3fe0e0f3..e9f67163a5e2f 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -96,10 +96,12 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa val bufferMessage = message.asInstanceOf[BufferMessage] val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) - for (blockMessage <- blockMessageArray) { + for (blockMessage: BlockMessage <- blockMessageArray) { if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { - listener.onBlockFetchFailure( - new SparkException(s"Unexpected message ${blockMessage.getType} received from $cmId")) + if (blockMessage.getId != null) { + listener.onBlockFetchFailure(blockMessage.getId.toString, + new SparkException(s"Unexpected message ${blockMessage.getType} received from $cmId")) + } } else { val blockId = blockMessage.getId val networkSize = blockMessage.getData.limit() @@ -110,7 +112,9 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa }(cm.futureExecContext) future.onFailure { case exception => - listener.onBlockFetchFailure(exception) + blockIds.foreach { blockId => + listener.onBlockFetchFailure(blockId, exception) + } }(cm.futureExecContext) } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 61fe0aaf0d44f..38486c1ded9ea 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala 
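The BlockServer hunk above also switches SO_RCVBUF and SO_SNDBUF from option to childOption. A small sketch of the distinction on a standard Netty 4 ServerBootstrap (hypothetical object name and values): option configures the listening (parent) socket, while childOption configures every accepted connection, which is where per-connection socket buffers belong.

import io.netty.bootstrap.ServerBootstrap
import io.netty.channel.{ChannelInitializer, ChannelOption}
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioServerSocketChannel

object ServerBootstrapOptionsSketch {
  def main(args: Array[String]): Unit = {
    val group = new NioEventLoopGroup()
    val bootstrap = new ServerBootstrap()
    bootstrap.group(group)
      .channel(classOf[NioServerSocketChannel])
      // Applies to the listening (parent) socket only.
      .option[java.lang.Integer](ChannelOption.SO_BACKLOG, 1024)
      // Applies to each accepted (child) socket.
      .childOption[java.lang.Integer](ChannelOption.SO_RCVBUF, 1024 * 1024)
      .childOption[java.lang.Integer](ChannelOption.SO_SNDBUF, 1024 * 1024)
      .childHandler(new ChannelInitializer[SocketChannel] {
        override def initChannel(ch: SocketChannel): Unit = { /* install codec and handlers here */ }
      })
    val channel = bootstrap.bind(0).sync().channel()
    channel.close().sync()
    group.shutdownGracefully()
  }
}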
@@ -158,14 +158,9 @@ final class ShuffleBlockFetcherIterator( logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } - override def onBlockFetchFailure(e: Throwable): Unit = { + override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) - // Note that there is a chance that some blocks have been fetched successfully, but we - // still add them to the failed queue. This is fine because when the caller see a - // FetchFailedException, it is going to fail the entire task anyway. - for ((blockId, size) <- req.blocks) { - results.put(new FetchResult(blockId, -1, null)) - } + results.put(new FetchResult(BlockId(blockId), -1, null)) } } ) diff --git a/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala index 7ed3dc915bb7c..c470bff825ba8 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala @@ -18,11 +18,13 @@ package org.apache.spark.network.netty import java.nio.ByteBuffer -import java.util.concurrent.atomic.AtomicInteger import io.netty.buffer.Unpooled import io.netty.channel.embedded.EmbeddedChannel +import org.mockito.Mockito._ +import org.mockito.Matchers.{any, eq => meq} + import org.scalatest.{FunSuite, PrivateMethodTester} import org.apache.spark.network._ @@ -31,7 +33,8 @@ import org.apache.spark.network._ class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { private def sizeOfOutstandingRequests(handler: BlockClientHandler): Int = { - val f = handler.getClass.getDeclaredField("outstandingRequests") + val f = handler.getClass.getDeclaredField( + "org$apache$spark$network$netty$BlockClientHandler$$outstandingRequests") f.setAccessible(true) f.get(handler).asInstanceOf[java.util.Map[_, _]].size } @@ -39,24 +42,9 @@ class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { test("handling block data (successful fetch)") { val blockId = "test_block" val blockData = "blahblahblahblahblah" - - var parsedBlockId: String = "" - var parsedBlockData: String = "" val handler = new BlockClientHandler - handler.addRequest(blockId, - new BlockFetchingListener { - override def onBlockFetchFailure(exception: Throwable): Unit = { - throw new UnsupportedOperationException - } - override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { - parsedBlockId = blockId - val bytes = new Array[Byte](data.size.toInt) - data.nioByteBuffer().get(bytes) - parsedBlockData = new String(bytes) - } - } - ) - + val listener = mock(classOf[BlockFetchingListener]) + handler.addRequest(blockId, listener) assert(sizeOfOutstandingRequests(handler) === 1) val channel = new EmbeddedChannel(handler) @@ -65,54 +53,29 @@ class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { buf.flip() channel.writeInbound(BlockFetchSuccess(blockId, new NioManagedBuffer(buf))) - - assert(parsedBlockId === blockId) - assert(parsedBlockData === blockData) + verify(listener, times(1)).onBlockFetchSuccess(meq(blockId), any()) assert(sizeOfOutstandingRequests(handler) === 0) assert(channel.finish() === false) } test("handling error message (failed fetch)") { val blockId = "test_block" - val errorMsg = "error erro5r error err4or error3 error6 error erro1r" - - var parsedErrorMsg: String = "" val handler = new 
BlockClientHandler - handler.addRequest(blockId, - new BlockFetchingListener { - override def onBlockFetchFailure(exception: Throwable): Unit = { - parsedErrorMsg = exception.getMessage - } - - override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { - throw new UnsupportedOperationException - } - } - ) - + val listener = mock(classOf[BlockFetchingListener]) + handler.addRequest(blockId, listener) assert(sizeOfOutstandingRequests(handler) === 1) val channel = new EmbeddedChannel(handler) - channel.writeInbound(BlockFetchFailure(blockId, errorMsg)) - assert(parsedErrorMsg === errorMsg) + channel.writeInbound(BlockFetchFailure(blockId, "some error msg")) + verify(listener, times(0)).onBlockFetchSuccess(any(), any()) + verify(listener, times(1)).onBlockFetchFailure(meq(blockId), any()) assert(sizeOfOutstandingRequests(handler) === 0) assert(channel.finish() === false) } - ignore("clear all outstanding request upon connection close") { - val errorCount = new AtomicInteger(0) - val successCount = new AtomicInteger(0) + test("clear all outstanding request upon uncaught exception") { val handler = new BlockClientHandler - - val listener = new BlockFetchingListener { - override def onBlockFetchFailure(exception: Throwable): Unit = { - errorCount.incrementAndGet() - } - override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { - successCount.incrementAndGet() - } - } - + val listener = mock(classOf[BlockFetchingListener]) handler.addRequest("b1", listener) handler.addRequest("b2", listener) handler.addRequest("b3", listener) @@ -120,9 +83,30 @@ class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { val channel = new EmbeddedChannel(handler) channel.writeInbound(BlockFetchSuccess("b1", new NettyManagedBuffer(Unpooled.buffer()))) - // Need to figure out a way to generate an exception - assert(successCount.get() === 1) - assert(errorCount.get() === 2) + channel.pipeline().fireExceptionCaught(new Exception("duh duh duh")) + + // should fail both b2 and b3 + verify(listener, times(1)).onBlockFetchSuccess(any(), any()) + verify(listener, times(2)).onBlockFetchFailure(any(), any()) + assert(sizeOfOutstandingRequests(handler) === 0) + assert(channel.finish() === false) + } + + test("clear all outstanding request upon connection close") { + val handler = new BlockClientHandler + val listener = mock(classOf[BlockFetchingListener]) + handler.addRequest("c1", listener) + handler.addRequest("c2", listener) + handler.addRequest("c3", listener) + assert(sizeOfOutstandingRequests(handler) === 3) + + val channel = new EmbeddedChannel(handler) + channel.writeInbound(BlockFetchSuccess("c1", new NettyManagedBuffer(Unpooled.buffer()))) + channel.finish() + + // should fail both b2 and b3 + verify(listener, times(1)).onBlockFetchSuccess(any(), any()) + verify(listener, times(2)).onBlockFetchFailure(any(), any()) assert(sizeOfOutstandingRequests(handler) === 0) assert(channel.finish() === false) } diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala index fa3512768d9a8..e3f98ff173ad0 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala @@ -103,7 +103,8 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { client.fetchBlocks( blockIds, new BlockFetchingListener { - 
override def onBlockFetchFailure(exception: Throwable): Unit = { + override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = { + errorBlockIds.add(blockId) sem.release() } @@ -135,7 +136,7 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { assert(failBlockIds.isEmpty) } - ignore("fetch a non-existent block") { + test("fetch a non-existent block") { val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq("random-block")) assert(blockIds.isEmpty) assert(buffers.isEmpty) @@ -149,7 +150,7 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { assert(failBlockIds.isEmpty) } - ignore("fetch both ByteBuffer block and a non-existent block") { + test("fetch both ByteBuffer block and a non-existent block") { val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, "random-block")) assert(blockIds === Set(bufferBlockId)) assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference)) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index b4700f38a6781..5a36614d1f59e 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -50,7 +50,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { if (data.contains(BlockId(blockId))) { listener.onBlockFetchSuccess(blockId, data(BlockId(blockId))) } else { - listener.onBlockFetchFailure(new BlockNotFoundException(blockId)) + listener.onBlockFetchFailure(blockId, new BlockNotFoundException(blockId)) } } } @@ -205,7 +205,8 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { // Return the first block, and then fail. listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) - listener.onBlockFetchFailure(new BlockNotFoundException("blah")) + listener.onBlockFetchFailure( + ShuffleBlockId(0, 1, 0).toString, new BlockNotFoundException("blah")) sem.release() } } From b2f3281d0de540d38ea5b4c7bf576b775405d56d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 11 Sep 2014 22:12:08 -0700 Subject: [PATCH 12/46] Added connection pooling. --- .../spark/network/netty/BlockClient.scala | 11 +-- .../network/netty/BlockClientFactory.scala | 42 +++++++-- .../netty/BlockClientFactorySuite.scala | 91 +++++++++++++++++++ .../netty/BlockClientHandlerSuite.scala | 1 + .../netty/ServerClientIntegrationSuite.scala | 9 ++ 5 files changed, 140 insertions(+), 14 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala index 6f67187adcb37..2768f98e9c1fd 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala @@ -43,6 +43,8 @@ class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Loggin private[this] val serverAddr = cf.channel().remoteAddress().toString + def isActive: Boolean = cf.channel().isActive + /** * Ask the remote server for a sequence of blocks, and execute the callback. 
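The rewritten BlockClientHandlerSuite earlier in this patch exercises the handler through Netty's EmbeddedChannel, pushing responses in with writeInbound and simulating transport errors with fireExceptionCaught, while Mockito mocks record the listener calls. A self-contained illustration of the EmbeddedChannel part of that technique, using a hypothetical counting handler instead of mocks:

import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
import io.netty.channel.embedded.EmbeddedChannel

class CountingHandler extends SimpleChannelInboundHandler[String] {
  var received = 0
  var failed = 0
  override protected def channelRead0(ctx: ChannelHandlerContext, msg: String): Unit = {
    received += 1
  }
  override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
    failed += 1
    ctx.close()
  }
}

object EmbeddedChannelSketch {
  def main(args: Array[String]): Unit = {
    val handler = new CountingHandler
    val channel = new EmbeddedChannel(handler)
    channel.writeInbound("hello")                                  // delivered to channelRead0
    channel.pipeline().fireExceptionCaught(new Exception("boom"))  // delivered to exceptionCaught
    println((handler.received, handler.failed))                    // prints (1,1)
    channel.finish()
  }
}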
* @@ -55,7 +57,7 @@ class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Loggin def fetchBlocks(blockIds: Seq[String], listener: BlockFetchingListener): Unit = { var startTime: Long = 0 logTrace { - startTime = System.nanoTime + startTime = System.nanoTime() s"Sending request $blockIds to $serverAddr" } @@ -67,7 +69,7 @@ class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Loggin override def operationComplete(future: ChannelFuture): Unit = { if (future.isSuccess) { logTrace { - val timeTaken = (System.nanoTime - startTime).toDouble / 1000000 + val timeTaken = (System.nanoTime() - startTime).toDouble / 1000000 s"Sending request $blockIds to $serverAddr took $timeTaken ms" } } else { @@ -84,9 +86,6 @@ class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Loggin }) } - def waitForClose(): Unit = { - cf.channel().closeFuture().sync() - } - + /** Close the connection. This does NOT block till the connection is closed. */ def close(): Unit = cf.channel().close() } diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala index 1414d0966e3dd..01fc73fe728a7 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala @@ -17,7 +17,7 @@ package org.apache.spark.network.netty -import java.util.concurrent.TimeoutException +import java.util.concurrent.{ConcurrentHashMap, TimeoutException} import io.netty.bootstrap.Bootstrap import io.netty.buffer.PooledByteBufAllocator @@ -35,8 +35,10 @@ import org.apache.spark.util.Utils /** - * Factory for creating [[BlockClient]] by using createClient. This factory reuses - * the worker thread pool for Netty. + * Factory for creating [[BlockClient]] by using createClient. + * + * The factory maintains a connection pool to other hosts and should return the same [[BlockClient]] + * for the same remote host. It also shares a single worker thread pool for all [[BlockClient]]s. */ private[netty] class BlockClientFactory(val conf: NettyConfig) { @@ -44,11 +46,15 @@ class BlockClientFactory(val conf: NettyConfig) { def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf)) /** A thread factory so the threads are named (for debugging). */ - private[netty] val threadFactory = Utils.namedThreadFactory("spark-netty-client") + private[this] val threadFactory = Utils.namedThreadFactory("spark-netty-client") + + /** Socket channel type, initialized by [[init]] depending ioMode. */ + private[this] var socketChannelClass: Class[_ <: Channel] = _ - /** The following two are instantiated by the [[init]] method, depending ioMode. */ - private[netty] var socketChannelClass: Class[_ <: Channel] = _ - private[netty] var workerGroup: EventLoopGroup = _ + /** Thread pool shared by all clients. */ + private[this] var workerGroup: EventLoopGroup = _ + + private[this] val connectionPool = new ConcurrentHashMap[(String, Int), BlockClient] // The encoders are stateless and can be shared among multiple clients. private[this] val encoder = new ClientRequestEncoder @@ -88,6 +94,16 @@ class BlockClientFactory(val conf: NettyConfig) { * Concurrency: This method is safe to call from multiple threads. */ def createClient(remoteHost: String, remotePort: Int): BlockClient = { + // Get connection from the connection pool first. + // If it is not found or not active, create a new one. 
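createClient consults a ConcurrentHashMap keyed by (host, port) before opening a new connection, and accepts that two racing threads may briefly create two clients for the same host. A standalone sketch of the same get-or-create pattern, with a hypothetical Client trait standing in for BlockClient; this variant uses putIfAbsent so the loser of a connect race closes its extra connection:

import java.util.concurrent.ConcurrentHashMap

object ClientPoolSketch {
  trait Client { def isActive: Boolean; def close(): Unit }

  class ClientPool(connect: (String, Int) => Client) {
    private val pool = new ConcurrentHashMap[(String, Int), Client]

    def getOrCreate(host: String, port: Int): Client = {
      val key = (host, port)
      val cached = pool.get(key)
      if (cached != null && cached.isActive) {
        cached
      } else {
        if (cached != null) pool.remove(key, cached) // evict the dead entry
        val fresh = connect(host, port)
        val winner = pool.putIfAbsent(key, fresh)
        if (winner == null || !winner.isActive) {
          fresh // we installed it (a stale winner, if any, is evicted on the next lookup)
        } else {
          fresh.close() // a concurrent caller won the race; reuse its client
          winner
        }
      }
    }

    def closeAll(): Unit = {
      val iter = pool.values().iterator()
      while (iter.hasNext) {
        iter.next().close()
        iter.remove()
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val pool = new ClientPool((_, _) => new Client {
      @volatile private var open = true
      def isActive: Boolean = open
      def close(): Unit = { open = false }
    })
    val a = pool.getOrCreate("remote-host", 7077)
    val b = pool.getOrCreate("remote-host", 7077)
    println(a eq b) // prints true: the second call reuses the pooled client
    pool.closeAll()
  }
}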
+ val cachedClient = connectionPool.get((remoteHost, remotePort)) + if (cachedClient != null && cachedClient.isActive) { + return cachedClient + } + + // There is a chance two threads are creating two different clients connecting to the same host. + // But that's probably ok ... + val handler = new BlockClientHandler val bootstrap = new Bootstrap @@ -118,10 +134,20 @@ class BlockClientFactory(val conf: NettyConfig) { s"Connecting to $remoteHost:$remotePort timed out (${conf.connectTimeoutMs} ms)") } - new BlockClient(cf, handler) + val client = new BlockClient(cf, handler) + connectionPool.put((remoteHost, remotePort), client) + client } + /** Close all connections in the connection pool, and shutdown the worker thread pool. */ def stop(): Unit = { + val iter = connectionPool.entrySet().iterator() + while (iter.hasNext) { + val entry = iter.next() + entry.getValue.close() + connectionPool.remove(entry.getKey) + } + if (workerGroup != null) { workerGroup.shutdownGracefully() } diff --git a/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala new file mode 100644 index 0000000000000..b2dcebfc8ceee --- /dev/null +++ b/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.netty + +import scala.concurrent.{Await, future} +import scala.concurrent.duration._ +import scala.concurrent.ExecutionContext.Implicits.global + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.SparkConf + + +class BlockClientFactorySuite extends FunSuite with BeforeAndAfterAll { + + private val conf = new SparkConf + private var server1: BlockServer = _ + private var server2: BlockServer = _ + + override def beforeAll() { + server1 = new BlockServer(new NettyConfig(conf), null) + server2 = new BlockServer(new NettyConfig(conf), null) + } + + override def afterAll() { + if (server1 != null) { + server1.stop() + } + if (server2 != null) { + server2.stop() + } + } + + test("BlockClients created are active and reused") { + val factory = new BlockClientFactory(conf) + val c1 = factory.createClient(server1.hostName, server1.port) + val c2 = factory.createClient(server1.hostName, server1.port) + val c3 = factory.createClient(server2.hostName, server2.port) + assert(c1.isActive) + assert(c3.isActive) + assert(c1 === c2) + assert(c1 !== c3) + factory.stop() + } + + test("never return inactive clients") { + val factory = new BlockClientFactory(conf) + val c1 = factory.createClient(server1.hostName, server1.port) + c1.close() + + // Block until c1 is no longer active + val f = future { + while (c1.isActive) { + Thread.sleep(10) + } + } + Await.result(f, 3 seconds) + assert(!c1.isActive) + + // Create c2, which should be different from c1 + val c2 = factory.createClient(server1.hostName, server1.port) + assert(c1 !== c2) + factory.stop() + } + + test("BlockClients are close when BlockClientFactory is stopped") { + val factory = new BlockClientFactory(conf) + val c1 = factory.createClient(server1.hostName, server1.port) + val c2 = factory.createClient(server2.hostName, server2.port) + assert(c1.isActive) + assert(c2.isActive) + factory.stop() + assert(!c1.isActive) + assert(!c2.isActive) + } +} diff --git a/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala index c470bff825ba8..7b80fe6aa364a 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.network._ class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { + /** Helper method to get num. outstanding requests from a private field using reflection. 
*/ private def sizeOfOutstandingRequests(handler: BlockClientHandler): Int = { val f = handler.getClass.getDeclaredField( "org$apache$spark$network$netty$BlockClientHandler$$outstandingRequests") diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala index e3f98ff173ad0..789df1f70dcd0 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala @@ -27,6 +27,9 @@ import scala.collection.JavaConversions._ import io.netty.buffer.Unpooled import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.Span +import org.scalatest.time.Seconds import org.apache.spark.SparkConf import org.apache.spark.network._ @@ -156,4 +159,10 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference)) assert(failBlockIds === Set("random-block")) } + + test("shutting down server should also close client") { + val client = clientFactory.createClient(server.hostName, server.port) + server.stop() + eventually(timeout(Span(5, Seconds))) { assert(!client.isActive) } + } } From 14323a55ebfa7ccc684c2ae78eac299a4426b353 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 11 Sep 2014 22:13:02 -0700 Subject: [PATCH 13/46] Removed BlockManager.getLocalShuffleFromDisk. --- .../scala/org/apache/spark/storage/BlockManager.scala | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 663c5e83ac4c9..9be4e80ad56b9 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -339,17 +339,6 @@ private[spark] class BlockManager( locations } - /** - * A short-circuited method to get blocks directly from disk. This is used for getting - * shuffle blocks. It is safe to do so without a lock on block info since disk store - * never deletes (recent) items. - */ - def getLocalShuffleFromDisk(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = { - val buf = shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) - val is = wrapForCompression(blockId, buf.inputStream()) - Some(serializer.newInstance().deserializeStream(is).asIterator) - } - /** * Get block from local block manager. */ From f0a16e9ec7d5c811dff3cd5219548e05077099c8 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 12 Sep 2014 00:40:53 -0700 Subject: [PATCH 14/46] Fixed test hanging. 
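The sizeOfOutstandingRequests helper above has to spell out the compiler-expanded field name (org$apache$spark$network$netty$BlockClientHandler$$outstandingRequests), presumably because scalac expands the private field's name in the emitted bytecode. A hedged alternative sketch (hypothetical helper, not part of the patch) that avoids hard-coding the expanded form by matching on a suffix, which survives the expansion:

import java.lang.reflect.Field

object MangledFieldSketch {
  // Finds a declared field whose bytecode name ends with the given suffix.
  def findField(obj: AnyRef, suffix: String): Field = {
    val field = obj.getClass.getDeclaredFields
      .find(_.getName.endsWith(suffix))
      .getOrElse(throw new NoSuchFieldException(suffix))
    field.setAccessible(true)
    field
  }

  def main(args: Array[String]): Unit = {
    class Holder {
      private[this] val count: Int = 42
      def doubled: Int = count * 2
    }
    val holder = new Holder
    println(findField(holder, "count").getInt(holder)) // prints 42
  }
}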
--- .../apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 5a36614d1f59e..7d4086313fcc1 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -207,6 +207,8 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) listener.onBlockFetchFailure( ShuffleBlockId(0, 1, 0).toString, new BlockNotFoundException("blah")) + listener.onBlockFetchFailure( + ShuffleBlockId(0, 2, 0).toString, new BlockNotFoundException("blah")) sem.release() } } From 519d64dcb7768b3657438a4cfc85ee8065f56c2a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 12 Sep 2014 14:18:58 -0700 Subject: [PATCH 15/46] Mark private package visibility and MimaExcludes. --- .../org/apache/spark/network/BlockDataManager.scala | 1 + .../apache/spark/network/BlockFetchingListener.scala | 1 + .../apache/spark/network/BlockTransferService.scala | 1 + .../scala/org/apache/spark/network/ManagedBuffer.scala | 4 ++++ .../network/netty/NettyBlockTransferService.scala | 1 + .../org/apache/spark/network/netty/protocol.scala | 10 ++++++++++ project/MimaExcludes.scala | 8 +++++++- 7 files changed, 25 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index 638e05f481f55..0eeffe0e7c5e6 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -20,6 +20,7 @@ package org.apache.spark.network import org.apache.spark.storage.StorageLevel +private[spark] trait BlockDataManager { /** diff --git a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala index 83fe497ad7448..dd70e26647939 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala @@ -23,6 +23,7 @@ import java.util.EventListener /** * Listener callback interface for [[BlockTransferService.fetchBlocks]]. 
*/ +private[spark] trait BlockFetchingListener extends EventListener { /** diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index 4833b8a6abf32..d894eac374b7f 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -23,6 +23,7 @@ import scala.concurrent.duration.Duration import org.apache.spark.storage.StorageLevel +private[spark] abstract class BlockTransferService { /** diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala index 86a0d653b341b..1611a44079570 100644 --- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala @@ -44,6 +44,7 @@ import org.apache.spark.util.{ByteBufferInputStream, Utils} * In that case, if the buffer is going to be passed around to a different thread, retain/release * should be called. */ +private[spark] abstract class ManagedBuffer { // Note that all the methods are defined with parenthesis because their implementations can // have side effects (io operations). @@ -85,6 +86,7 @@ abstract class ManagedBuffer { /** * A [[ManagedBuffer]] backed by a segment in a file */ +private[spark] final class FileSegmentManagedBuffer(val file: File, val offset: Long, val length: Long) extends ManagedBuffer { @@ -149,6 +151,7 @@ final class FileSegmentManagedBuffer(val file: File, val offset: Long, val lengt /** * A [[ManagedBuffer]] backed by [[java.nio.ByteBuffer]]. */ +private[spark] final class NioManagedBuffer(buf: ByteBuffer) extends ManagedBuffer { override def size: Long = buf.remaining() @@ -168,6 +171,7 @@ final class NioManagedBuffer(buf: ByteBuffer) extends ManagedBuffer { /** * A [[ManagedBuffer]] backed by a Netty [[ByteBuf]]. */ +private[spark] final class NettyManagedBuffer(buf: ByteBuf) extends ManagedBuffer { override def size: Long = buf.readableBytes() diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index fa8bdfc96e8b8..30dc812c4e7de 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -29,6 +29,7 @@ import org.apache.spark.storage.StorageLevel * * See protocol.scala for the communication protocol between server and client */ +private[spark] final class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService { private[this] val nettyConf: NettyConfig = new NettyConfig(conf) diff --git a/core/src/main/scala/org/apache/spark/network/netty/protocol.scala b/core/src/main/scala/org/apache/spark/network/netty/protocol.scala index ac9d2097c93e2..6a14ad26dbf21 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/protocol.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/protocol.scala @@ -29,6 +29,7 @@ import org.apache.spark.network.{NettyManagedBuffer, ManagedBuffer} /** Messages from the client to the server. */ +private[netty] sealed trait ClientRequest { def id: Byte } @@ -37,6 +38,7 @@ sealed trait ClientRequest { * Request to fetch a sequence of blocks from the server. A single [[BlockFetchRequest]] can * correspond to multiple [[ServerResponse]]s. 
*/ +private[netty] final case class BlockFetchRequest(blocks: Seq[String]) extends ClientRequest { override def id = 0 } @@ -44,6 +46,7 @@ final case class BlockFetchRequest(blocks: Seq[String]) extends ClientRequest { /** * Request to upload a block to the server. Currently the server does not ack the upload request. */ +private[netty] final case class BlockUploadRequest(blockId: String, data: ManagedBuffer) extends ClientRequest { require(blockId.length <= Byte.MaxValue) override def id = 1 @@ -51,17 +54,20 @@ final case class BlockUploadRequest(blockId: String, data: ManagedBuffer) extend /** Messages from server to client (usually in response to some [[ClientRequest]]. */ +private[netty] sealed trait ServerResponse { def id: Byte } /** Response to [[BlockFetchRequest]] when a block exists and has been successfully fetched. */ +private[netty] final case class BlockFetchSuccess(blockId: String, data: ManagedBuffer) extends ServerResponse { require(blockId.length <= Byte.MaxValue) override def id = 0 } /** Response to [[BlockFetchRequest]] when there is an error fetching the block. */ +private[netty] final case class BlockFetchFailure(blockId: String, error: String) extends ServerResponse { require(blockId.length <= Byte.MaxValue) override def id = 1 @@ -74,6 +80,7 @@ final case class BlockFetchFailure(blockId: String, error: String) extends Serve * This encoder is stateless so it is safe to be shared by multiple threads. */ @Sharable +private[netty] final class ClientRequestEncoder extends MessageToMessageEncoder[ClientRequest] { override def encode(ctx: ChannelHandlerContext, in: ClientRequest, out: JList[Object]): Unit = { in match { @@ -128,6 +135,7 @@ final class ClientRequestEncoder extends MessageToMessageEncoder[ClientRequest] * [[ProtocolUtils.createFrameDecoder()]]. */ @Sharable +private[netty] final class ClientRequestDecoder extends MessageToMessageDecoder[ByteBuf] { override protected def decode(ctx: ChannelHandlerContext, in: ByteBuf, out: JList[AnyRef]): Unit = { @@ -155,6 +163,7 @@ final class ClientRequestDecoder extends MessageToMessageDecoder[ByteBuf] { * This encoder is stateless so it is safe to be shared by multiple threads. */ @Sharable +private[netty] final class ServerResponseEncoder extends MessageToMessageEncoder[ServerResponse] with Logging { override def encode(ctx: ChannelHandlerContext, in: ServerResponse, out: JList[Object]): Unit = { in match { @@ -211,6 +220,7 @@ final class ServerResponseEncoder extends MessageToMessageEncoder[ServerResponse * [[ProtocolUtils.createFrameDecoder()]]. 
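The encoder and decoder classes in this file turn these case classes into length-prefixed frames on the wire. A much-reduced sketch of that style of framing (an id byte plus a UTF-8 block id; this illustrates the approach only and is not the exact layout protocol.scala uses):

import io.netty.buffer.{ByteBuf, Unpooled}
import io.netty.util.CharsetUtil

object FrameSketch {
  def encode(id: Byte, blockId: String): ByteBuf = {
    val bytes = blockId.getBytes(CharsetUtil.UTF_8)
    val buf = Unpooled.buffer(1 + 4 + bytes.length)
    buf.writeByte(id)          // message type
    buf.writeInt(bytes.length) // length prefix for the string
    buf.writeBytes(bytes)
    buf
  }

  def decode(buf: ByteBuf): (Byte, String) = {
    val id = buf.readByte()
    val len = buf.readInt()
    val bytes = new Array[Byte](len)
    buf.readBytes(bytes)
    (id, new String(bytes, CharsetUtil.UTF_8))
  }

  def main(args: Array[String]): Unit = {
    val buf = encode(0.toByte, "shuffle_0_1_2")
    println(decode(buf)) // prints (0,shuffle_0_1_2)
    buf.release()
  }
}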
*/ @Sharable +private[netty] final class ServerResponseDecoder extends MessageToMessageDecoder[ByteBuf] { override def decode(ctx: ChannelHandlerContext, in: ByteBuf, out: JList[AnyRef]): Unit = { val msgId = in.readByte() diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index d499302124461..8a1b2d3b91327 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -50,7 +50,13 @@ object MimaExcludes { "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL2"), // MapStatus should be private[spark] ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.scheduler.MapStatus") + "org.apache.spark.scheduler.MapStatus"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.foreachAsync"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.network.netty.PathResolver"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.network.netty.client.BlockClientListener") ) case v if v.startsWith("1.1") => From c066309afbb0e248a8b2b808d997e6b37a2bff1e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 12 Sep 2014 22:42:32 -0700 Subject: [PATCH 16/46] Implement java.io.Closeable interface. --- .../apache/spark/network/BlockTransferService.scala | 6 ++++-- .../org/apache/spark/network/netty/BlockClient.scala | 3 ++- .../spark/network/netty/BlockClientFactory.scala | 5 +++-- .../org/apache/spark/network/netty/BlockServer.scala | 8 +++++--- .../network/netty/NettyBlockTransferService.scala | 6 +++--- .../spark/network/nio/NioBlockTransferService.scala | 2 +- .../scala/org/apache/spark/storage/BlockManager.scala | 2 +- .../spark/network/netty/BlockClientFactorySuite.scala | 10 +++++----- .../network/netty/ServerClientIntegrationSuite.scala | 6 +++--- 9 files changed, 27 insertions(+), 21 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index d894eac374b7f..a8379a207a18b 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -17,6 +17,8 @@ package org.apache.spark.network +import java.io.Closeable + import scala.concurrent.{Await, Future} import scala.concurrent.duration.Duration @@ -24,7 +26,7 @@ import org.apache.spark.storage.StorageLevel private[spark] -abstract class BlockTransferService { +abstract class BlockTransferService extends Closeable { /** * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch @@ -35,7 +37,7 @@ abstract class BlockTransferService { /** * Tear down the transfer service. */ - def stop(): Unit + def close(): Unit /** * Port number the service is listening on, available only after [[init]] is invoked. 
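This commit moves the transfer service (and, in the following hunks, the client, client factory, and server) over to java.io.Closeable. Besides the stop()-to-close() rename, that makes these classes usable with generic resource-management helpers; a small sketch of such a loan-pattern helper (hypothetical, not part of the patch):

import java.io.Closeable

object LoanSketch {
  // Runs `body` with the resource and closes the resource afterwards, even if `body` throws.
  def withResource[C <: Closeable, T](resource: C)(body: C => T): T = {
    try body(resource) finally resource.close()
  }

  def main(args: Array[String]): Unit = {
    val bytesWritten = withResource(new java.io.ByteArrayOutputStream()) { out =>
      out.write(Array[Byte](1, 2, 3))
      out.size()
    }
    println(bytesWritten) // prints 3
  }
}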
diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala index 2768f98e9c1fd..fb50b15474292 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala @@ -17,6 +17,7 @@ package org.apache.spark.network.netty +import java.io.Closeable import java.util.concurrent.TimeoutException import io.netty.channel.{ChannelFuture, ChannelFutureListener} @@ -39,7 +40,7 @@ import org.apache.spark.network.BlockFetchingListener */ @throws[TimeoutException] private[netty] -class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Logging { +class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Closeable with Logging { private[this] val serverAddr = cf.channel().remoteAddress().toString diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala index 01fc73fe728a7..e264f91142ec1 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala @@ -17,6 +17,7 @@ package org.apache.spark.network.netty +import java.io.Closeable import java.util.concurrent.{ConcurrentHashMap, TimeoutException} import io.netty.bootstrap.Bootstrap @@ -41,7 +42,7 @@ import org.apache.spark.util.Utils * for the same remote host. It also shares a single worker thread pool for all [[BlockClient]]s. */ private[netty] -class BlockClientFactory(val conf: NettyConfig) { +class BlockClientFactory(val conf: NettyConfig) extends Closeable { def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf)) @@ -140,7 +141,7 @@ class BlockClientFactory(val conf: NettyConfig) { } /** Close all connections in the connection pool, and shutdown the worker thread pool. */ - def stop(): Unit = { + override def close(): Unit = { val iter = connectionPool.entrySet().iterator() while (iter.hasNext) { val entry = iter.next() diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala index bd28d48c1a5ed..9a8ffabd04c84 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala @@ -17,6 +17,7 @@ package org.apache.spark.network.netty +import java.io.Closeable import java.net.InetSocketAddress import io.netty.bootstrap.ServerBootstrap @@ -29,7 +30,7 @@ import io.netty.channel.socket.nio.NioServerSocketChannel import io.netty.channel.socket.oio.OioServerSocketChannel import io.netty.channel.{ChannelInitializer, ChannelFuture, ChannelOption} -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.Logging import org.apache.spark.network.BlockDataManager import org.apache.spark.util.Utils @@ -38,7 +39,8 @@ import org.apache.spark.util.Utils * Server for the [[NettyBlockTransferService]]. */ private[netty] -class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) extends Logging { +class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) + extends Closeable with Logging { def port: Int = _port @@ -115,7 +117,7 @@ class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) extends Log } /** Shutdown the server. 
*/ - def stop(): Unit = { + def close(): Unit = { if (channelFuture != null) { channelFuture.channel().close().awaitUninterruptibly() channelFuture = null diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 30dc812c4e7de..14df5161cb0f3 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -42,12 +42,12 @@ final class NettyBlockTransferService(conf: SparkConf) extends BlockTransferServ clientFactory = new BlockClientFactory(nettyConf) } - override def stop(): Unit = { + override def close(): Unit = { if (server != null) { - server.stop() + server.close() } if (clientFactory != null) { - clientFactory.stop() + clientFactory.close() } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index e9f67163a5e2f..3d72155f8db8d 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -71,7 +71,7 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa /** * Tear down the transfer service. */ - override def stop(): Unit = { + override def close(): Unit = { if (cm != null) { cm.stop() } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 9be4e80ad56b9..ac0599f30ef22 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -1113,7 +1113,7 @@ private[spark] class BlockManager( } def stop(): Unit = { - blockTransferService.stop() + blockTransferService.close() diskBlockManager.stop() actorSystem.stop(slaveActor) blockInfo.clear() diff --git a/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala index b2dcebfc8ceee..5075688b1b27c 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala @@ -39,10 +39,10 @@ class BlockClientFactorySuite extends FunSuite with BeforeAndAfterAll { override def afterAll() { if (server1 != null) { - server1.stop() + server1.close() } if (server2 != null) { - server2.stop() + server2.close() } } @@ -55,7 +55,7 @@ class BlockClientFactorySuite extends FunSuite with BeforeAndAfterAll { assert(c3.isActive) assert(c1 === c2) assert(c1 !== c3) - factory.stop() + factory.close() } test("never return inactive clients") { @@ -75,7 +75,7 @@ class BlockClientFactorySuite extends FunSuite with BeforeAndAfterAll { // Create c2, which should be different from c1 val c2 = factory.createClient(server1.hostName, server1.port) assert(c1 !== c2) - factory.stop() + factory.close() } test("BlockClients are close when BlockClientFactory is stopped") { @@ -84,7 +84,7 @@ class BlockClientFactorySuite extends FunSuite with BeforeAndAfterAll { val c2 = factory.createClient(server2.hostName, server2.port) assert(c1.isActive) assert(c2.isActive) - factory.stop() + factory.close() assert(!c1.isActive) assert(!c2.isActive) } diff --git 
a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala index 789df1f70dcd0..98e896221f910 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala @@ -86,8 +86,8 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { } override def afterAll() = { - server.stop() - clientFactory.stop() + server.close() + clientFactory.close() } /** A ByteBuf for buffer_block */ @@ -162,7 +162,7 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { test("shutting down server should also close client") { val client = clientFactory.createClient(server.hostName, server.port) - server.stop() + server.close() eventually(timeout(Span(5, Seconds))) { assert(!client.isActive) } } } From 6afc435037a0448d6eb243bd18411ef25e3a2cf7 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 16 Sep 2014 22:51:11 -0700 Subject: [PATCH 17/46] Added logging. --- .../org/apache/spark/network/netty/BlockClientFactory.scala | 6 ++++-- .../scala/org/apache/spark/network/netty/BlockServer.scala | 2 ++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala index e264f91142ec1..6278e69c2200b 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala @@ -31,7 +31,7 @@ import io.netty.channel.socket.nio.NioSocketChannel import io.netty.channel.socket.oio.OioSocketChannel import io.netty.util.internal.PlatformDependent -import org.apache.spark.SparkConf +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.util.Utils @@ -42,7 +42,7 @@ import org.apache.spark.util.Utils * for the same remote host. It also shares a single worker thread pool for all [[BlockClient]]s. */ private[netty] -class BlockClientFactory(val conf: NettyConfig) extends Closeable { +class BlockClientFactory(val conf: NettyConfig) extends Logging with Closeable { def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf)) @@ -102,6 +102,8 @@ class BlockClientFactory(val conf: NettyConfig) extends Closeable { return cachedClient } + logInfo(s"Creating new connection to $remoteHost:$remotePort") + // There is a chance two threads are creating two different clients connecting to the same host. // But that's probably ok ... diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala index 9a8ffabd04c84..2611f2eacdb36 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala @@ -114,6 +114,8 @@ class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) _port = addr.getPort // _hostName = addr.getHostName _hostName = Utils.localHostName() + + logInfo(s"Server started ${_hostName}:${_port}") } /** Shutdown the server. */ From f63fb4c1976e503238b7d7151f8f45f40ced36e9 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 29 Sep 2014 11:13:44 -0700 Subject: [PATCH 18/46] Add more debug message. 
--- .../org/apache/spark/network/ManagedBuffer.scala | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala index 1611a44079570..9a56ca157223b 100644 --- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala @@ -121,7 +121,7 @@ final class FileSegmentManagedBuffer(val file: File, val offset: Long, val lengt } catch { case e: IOException => if (is != null) { - Utils.tryLog(is.close()) + is.close() } Try(file.length).toOption match { case Some(fileLen) => @@ -131,20 +131,13 @@ final class FileSegmentManagedBuffer(val file: File, val offset: Long, val lengt } case e: Throwable => if (is != null) { - Utils.tryLog(is.close()) + is.close() } throw e } } - private[network] override def convertToNetty(): AnyRef = { - val fileChannel = new FileInputStream(file).getChannel - new DefaultFileRegion(fileChannel, offset, length) - } - - // Content of file segments are not in-memory, so no need to reference count. - override def retain(): this.type = this - override def release(): this.type = this + override def toString: String = s"${getClass.getName}($file, $offset, $length)" } From d68f3286a4a9795dfb61a8a63b8a20b3eafb4821 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 29 Sep 2014 11:30:13 -0700 Subject: [PATCH 19/46] Logging close() in case close() fails. --- .../main/scala/org/apache/spark/network/ManagedBuffer.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala index 9a56ca157223b..dd808d2500fbc 100644 --- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala @@ -121,7 +121,7 @@ final class FileSegmentManagedBuffer(val file: File, val offset: Long, val lengt } catch { case e: IOException => if (is != null) { - is.close() + Utils.tryLog(is.close()) } Try(file.length).toOption match { case Some(fileLen) => @@ -131,7 +131,7 @@ final class FileSegmentManagedBuffer(val file: File, val offset: Long, val lengt } case e: Throwable => if (is != null) { - is.close() + Utils.tryLog(is.close()) } throw e } From 1bdd7eec5d9ddb5a9eb33c9733878aea3ca26ba6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 29 Sep 2014 12:07:53 -0700 Subject: [PATCH 20/46] Fixed tests. 
--- .../network/netty/BlockClientHandler.scala | 2 ++ .../storage/ShuffleBlockFetcherIterator.scala | 2 +- .../spark/network/netty/ProtocolSuite.scala | 25 +++++++++++++++++-- .../netty/ServerClientIntegrationSuite.scala | 6 +++++ 4 files changed, 32 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala index 1a74c6649f28a..466ece99b9b96 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala @@ -86,9 +86,11 @@ class BlockClientHandler extends SimpleChannelInboundHandler[ServerResponse] wit val listener = outstandingRequests.get(blockId) if (listener == null) { logWarning(s"Got a response for block $blockId from $server but it is not outstanding") + buf.release() } else { outstandingRequests.remove(blockId) listener.onBlockFetchSuccess(blockId, buf) + buf.release() } case BlockFetchFailure(blockId, errorMsg) => val listener = outstandingRequests.get(blockId) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 38486c1ded9ea..d095452a261db 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -155,7 +155,7 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.remoteBytesRead += buf.size shuffleMetrics.remoteBlocksFetched += 1 } - logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) + logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { diff --git a/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala index 72034634a5bd2..46604ea1fb624 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala @@ -71,18 +71,39 @@ class ProtocolSuite extends FunSuite { assert(msg === serverChannel.readInbound()) } - test("server to client protocol") { + test("server to client protocol - BlockFetchSuccess(\"a1234\", new TestManagedBuffer(10))") { testServerToClient(BlockFetchSuccess("a1234", new TestManagedBuffer(10))) + } + + test("server to client protocol - BlockFetchSuccess(\"\", new TestManagedBuffer(0))") { testServerToClient(BlockFetchSuccess("", new TestManagedBuffer(0))) + } + + test("server to client protocol - BlockFetchFailure(\"abcd\", \"this is an error\")") { testServerToClient(BlockFetchFailure("abcd", "this is an error")) + } + + test("server to client protocol - BlockFetchFailure(\"\", \"\")") { testServerToClient(BlockFetchFailure("", "")) } - test("client to server protocol") { + test("client to server protocol - BlockFetchRequest(Seq.empty[String])") { testClientToServer(BlockFetchRequest(Seq.empty[String])) + } + + test("client to server protocol - BlockFetchRequest(Seq(\"b1\"))") { testClientToServer(BlockFetchRequest(Seq("b1"))) + } + + test("client to server protocol - BlockFetchRequest(Seq(\"b1\", \"b2\", \"b3\"))") { testClientToServer(BlockFetchRequest(Seq("b1", "b2", "b3"))) + } + + ignore("client to server protocol - BlockUploadRequest(\"\", new 
TestManagedBuffer(0))") { testClientToServer(BlockUploadRequest("", new TestManagedBuffer(0))) + } + + ignore("client to server protocol - BlockUploadRequest(\"b_upload\", new TestManagedBuffer(10))") { testClientToServer(BlockUploadRequest("b_upload", new TestManagedBuffer(10))) } } diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala index 98e896221f910..35ff90a2dabc5 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala @@ -112,6 +112,7 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { } override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { + data.retain() receivedBlockIds.add(blockId) receivedBuffers.add(data) sem.release() @@ -130,6 +131,7 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { assert(blockIds === Set(bufferBlockId)) assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference)) assert(failBlockIds.isEmpty) + buffers.foreach(_.release()) } test("fetch a FileSegment block via zero-copy send") { @@ -137,6 +139,7 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { assert(blockIds === Set(fileBlockId)) assert(buffers.map(_.convertToNetty()) === Set(fileBlockReference)) assert(failBlockIds.isEmpty) + buffers.foreach(_.release()) } test("fetch a non-existent block") { @@ -144,6 +147,7 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { assert(blockIds.isEmpty) assert(buffers.isEmpty) assert(failBlockIds === Set("random-block")) + buffers.foreach(_.release()) } test("fetch both ByteBuffer block and FileSegment block") { @@ -151,6 +155,7 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { assert(blockIds === Set(bufferBlockId, fileBlockId)) assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference, fileBlockReference)) assert(failBlockIds.isEmpty) + buffers.foreach(_.release()) } test("fetch both ByteBuffer block and a non-existent block") { @@ -158,6 +163,7 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { assert(blockIds === Set(bufferBlockId)) assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference)) assert(failBlockIds === Set("random-block")) + buffers.foreach(_.release()) } test("shutting down server should also close client") { From bec4ea2b54659cfed6f54e527aa878dfbff829c7 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 29 Sep 2014 12:22:01 -0700 Subject: [PATCH 21/46] Removed OIO and added num threads settings. 
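With the fixes above, buffer lifetimes are managed explicitly: BlockClientHandler releases each response buffer once the listener returns, and the integration suite retains any buffer it keeps for later assertions and releases it at the end of the test. A minimal sketch of the Netty reference-counting contract this relies on (plain Netty, no Spark types):

import io.netty.buffer.Unpooled

object RefCountSketch {
  def main(args: Array[String]): Unit = {
    val buf = Unpooled.buffer(16).writeInt(42)
    println(buf.refCnt()) // prints 1: the creator owns the only reference
    buf.retain()          // a second owner wants to keep the data
    println(buf.refCnt()) // prints 2
    buf.release()         // first owner is done
    buf.release()         // second owner is done; the buffer is deallocated
    println(buf.refCnt()) // prints 0
  }
}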
--- .../network/netty/BlockClientFactory.scala | 13 +++---------- .../spark/network/netty/BlockServer.scala | 12 ++---------- .../spark/network/netty/NettyConfig.scala | 18 ++++++++++-------- 3 files changed, 15 insertions(+), 28 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala index 6278e69c2200b..8021cfdf42d1a 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala @@ -25,10 +25,8 @@ import io.netty.buffer.PooledByteBufAllocator import io.netty.channel._ import io.netty.channel.epoll.{Epoll, EpollEventLoopGroup, EpollSocketChannel} import io.netty.channel.nio.NioEventLoopGroup -import io.netty.channel.oio.OioEventLoopGroup import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.nio.NioSocketChannel -import io.netty.channel.socket.oio.OioSocketChannel import io.netty.util.internal.PlatformDependent import org.apache.spark.{Logging, SparkConf} @@ -65,23 +63,18 @@ class BlockClientFactory(val conf: NettyConfig) extends Logging with Closeable { /** Initialize [[socketChannelClass]] and [[workerGroup]] based on ioMode. */ private def init(): Unit = { - def initOio(): Unit = { - socketChannelClass = classOf[OioSocketChannel] - workerGroup = new OioEventLoopGroup(0, threadFactory) - } def initNio(): Unit = { socketChannelClass = classOf[NioSocketChannel] - workerGroup = new NioEventLoopGroup(0, threadFactory) + workerGroup = new NioEventLoopGroup(conf.clientThreads, threadFactory) } def initEpoll(): Unit = { socketChannelClass = classOf[EpollSocketChannel] - workerGroup = new EpollEventLoopGroup(0, threadFactory) + workerGroup = new EpollEventLoopGroup(conf.clientThreads, threadFactory) } // For auto mode, first try epoll (only available on Linux), then nio. conf.ioMode match { case "nio" => initNio() - case "oio" => initOio() case "epoll" => initEpoll() case "auto" => if (Epoll.isAvailable) initEpoll() else initNio() } @@ -102,7 +95,7 @@ class BlockClientFactory(val conf: NettyConfig) extends Logging with Closeable { return cachedClient } - logInfo(s"Creating new connection to $remoteHost:$remotePort") + logDebug(s"Creating new connection to $remoteHost:$remotePort") // There is a chance two threads are creating two different clients connecting to the same host. // But that's probably ok ... 
diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala index 2611f2eacdb36..e2eb7c379f14d 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala @@ -24,10 +24,8 @@ import io.netty.bootstrap.ServerBootstrap import io.netty.buffer.PooledByteBufAllocator import io.netty.channel.epoll.{Epoll, EpollEventLoopGroup, EpollServerSocketChannel} import io.netty.channel.nio.NioEventLoopGroup -import io.netty.channel.oio.OioEventLoopGroup import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.nio.NioServerSocketChannel -import io.netty.channel.socket.oio.OioServerSocketChannel import io.netty.channel.{ChannelInitializer, ChannelFuture, ChannelOption} import org.apache.spark.Logging @@ -60,24 +58,18 @@ class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) // Use only one thread to accept connections, and 2 * num_cores for worker. def initNio(): Unit = { - val bossGroup = new NioEventLoopGroup(0, threadFactory) + val bossGroup = new NioEventLoopGroup(conf.serverThreads, threadFactory) val workerGroup = bossGroup bootstrap.group(bossGroup, workerGroup).channel(classOf[NioServerSocketChannel]) } - def initOio(): Unit = { - val bossGroup = new OioEventLoopGroup(0, threadFactory) - val workerGroup = bossGroup - bootstrap.group(bossGroup, workerGroup).channel(classOf[OioServerSocketChannel]) - } def initEpoll(): Unit = { - val bossGroup = new EpollEventLoopGroup(0, threadFactory) + val bossGroup = new EpollEventLoopGroup(conf.serverThreads, threadFactory) val workerGroup = bossGroup bootstrap.group(bossGroup, workerGroup).channel(classOf[EpollServerSocketChannel]) } conf.ioMode match { case "nio" => initNio() - case "oio" => initOio() case "epoll" => initEpoll() case "auto" => if (Epoll.isAvailable) initEpoll() else initNio() } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala index b5870152c5a64..d5078e417d6d2 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala @@ -31,18 +31,20 @@ class NettyConfig(conf: SparkConf) { /** IO mode: nio, oio, epoll, or auto (try epoll first and then nio). */ private[netty] val ioMode = conf.get("spark.shuffle.io.mode", "nio").toLowerCase - /** Connect timeout in secs. Default 60 secs. */ - private[netty] val connectTimeoutMs = conf.getInt("spark.shuffle.io.connectionTimeout", 60) * 1000 - - /** - * Percentage of the desired amount of time spent for I/O in the child event loops. - * Only applicable in nio and epoll. - */ - private[netty] val ioRatio = conf.getInt("spark.shuffle.io.netty.ioRatio", 80) + /** Connect timeout in secs. Default 120 secs. */ + private[netty] val connectTimeoutMs = { + conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000 + } /** Requested maximum length of the queue of incoming connections. */ private[netty] val backLog: Option[Int] = conf.getOption("spark.shuffle.io.backLog").map(_.toInt) + /** Number of threads used in the server thread pool. Default to 0, which is 2x#cores. */ + private[netty] val serverThreads: Int = conf.getInt("spark.shuffle.io.serverThreads", 0) + + /** Number of threads used in the client thread pool. Default to 0, which is 2x#cores. 
*/ + private[netty] val clientThreads: Int = conf.getInt("spark.shuffle.io.clientThreads", 0) + /** * Receive buffer size (SO_RCVBUF). * Note: the optimal size for receive buffer and send buffer should be From 4b18db29edcdb87577fd033835275fd1c2957dcd Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 29 Sep 2014 15:45:05 -0700 Subject: [PATCH 22/46] Copy the buffer in fetchBlockSync. --- .../org/apache/spark/network/BlockTransferService.scala | 5 ++++- .../apache/spark/network/netty/BlockClientFactorySuite.scala | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index a8379a207a18b..c874bddcf4a6c 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -18,6 +18,7 @@ package org.apache.spark.network import java.io.Closeable +import java.nio.ByteBuffer import scala.concurrent.{Await, Future} import scala.concurrent.duration.Duration @@ -94,7 +95,9 @@ abstract class BlockTransferService extends Closeable { } override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { lock.synchronized { - result = Left(data) + val ret = ByteBuffer.allocate(data.size.toInt) + ret.put(data.nioByteBuffer()) + result = Left(new NioManagedBuffer(ret)) lock.notify() } } diff --git a/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala index 5075688b1b27c..2d4baafcf03d0 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala @@ -69,7 +69,7 @@ class BlockClientFactorySuite extends FunSuite with BeforeAndAfterAll { Thread.sleep(10) } } - Await.result(f, 3 seconds) + Await.result(f, 3.seconds) assert(!c1.isActive) // Create c2, which should be different from c1 From a0518c766f0f4eba24459ffac61dce789fc14092 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 29 Sep 2014 19:22:34 -0700 Subject: [PATCH 23/46] Implemented block uploads. --- .../spark/network/BlockTransferService.scala | 3 - .../org/apache/spark/network/exceptions.scala | 31 ++++++++ .../spark/network/netty/BlockClient.scala | 42 +++++++++-- .../network/netty/BlockClientHandler.scala | 68 +++++++++++++----- .../network/netty/BlockServerHandler.scala | 40 +++++++++-- .../netty/NettyBlockTransferService.scala | 6 +- .../apache/spark/network/netty/protocol.scala | 72 ++++++++++++++++--- .../apache/spark/storage/StorageLevel.scala | 3 +- .../netty/BlockClientHandlerSuite.scala | 18 ++--- .../spark/network/netty/ProtocolSuite.scala | 12 ++-- 10 files changed, 240 insertions(+), 55 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/network/exceptions.scala diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index c874bddcf4a6c..2a0a1a0bc0a14 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -54,9 +54,6 @@ abstract class BlockTransferService extends Closeable { * Fetch a sequence of blocks from a remote node asynchronously, * available only after [[init]] is invoked. 
* - * Note that [[BlockFetchingListener.onBlockFetchSuccess]] is called once per block, - * while [[BlockFetchingListener.onBlockFetchFailure]] is called once per failure (not per block). - * * Note that this API takes a sequence so the implementation can batch requests, and does not * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as * the data of a block is fetched, rather than waiting for all blocks to be fetched. diff --git a/core/src/main/scala/org/apache/spark/network/exceptions.scala b/core/src/main/scala/org/apache/spark/network/exceptions.scala new file mode 100644 index 0000000000000..d918d358c4adb --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/exceptions.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network + +class BlockFetchFailureException(blockId: String, errorMsg: String, cause: Throwable) + extends Exception(errorMsg, cause) { + + def this(blockId: String, errorMsg: String) = this(blockId, errorMsg, null) +} + + +class BlockUploadFailureException(blockId: String, cause: Throwable) + extends Exception(s"Failed to fetch block $blockId", cause) { + + def this(blockId: String) = this(blockId, null) +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala index fb50b15474292..c77a7ae1ccb0f 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala @@ -20,10 +20,13 @@ package org.apache.spark.network.netty import java.io.Closeable import java.util.concurrent.TimeoutException +import scala.concurrent.{Future, promise} + import io.netty.channel.{ChannelFuture, ChannelFutureListener} import org.apache.spark.Logging -import org.apache.spark.network.BlockFetchingListener +import org.apache.spark.network.{ManagedBuffer, BlockFetchingListener} +import org.apache.spark.storage.StorageLevel /** @@ -58,19 +61,19 @@ class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Closea def fetchBlocks(blockIds: Seq[String], listener: BlockFetchingListener): Unit = { var startTime: Long = 0 logTrace { - startTime = System.nanoTime() + startTime = System.currentTimeMillis() s"Sending request $blockIds to $serverAddr" } blockIds.foreach { blockId => - handler.addRequest(blockId, listener) + handler.addFetchRequest(blockId, listener) } cf.channel().writeAndFlush(BlockFetchRequest(blockIds)).addListener(new ChannelFutureListener { override def operationComplete(future: ChannelFuture): Unit = { if (future.isSuccess) { logTrace { - val timeTaken = (System.nanoTime() - startTime).toDouble / 1000000 + val timeTaken = 
System.currentTimeMillis() - startTime s"Sending request $blockIds to $serverAddr took $timeTaken ms" } } else { @@ -79,7 +82,7 @@ class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Closea s"Failed to send request $blockIds to $serverAddr: ${future.cause.getMessage}" logError(errorMsg, future.cause) blockIds.foreach { blockId => - handler.removeRequest(blockId) + handler.removeFetchRequest(blockId) listener.onBlockFetchFailure(blockId, new RuntimeException(errorMsg)) } } @@ -87,6 +90,35 @@ class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Closea }) } + def uploadBlock(blockId: String, data: ManagedBuffer, storageLevel: StorageLevel): Future[Unit] = { + var startTime: Long = 0 + logTrace { + startTime = System.currentTimeMillis() + s"Uploading block ($blockId) to $serverAddr" + } + val f = cf.channel().writeAndFlush(new BlockUploadRequest(blockId, data, storageLevel)) + + val p = promise[Unit]() + handler.addUploadRequest(blockId, p) + f.addListener(new ChannelFutureListener { + override def operationComplete(future: ChannelFuture): Unit = { + if (future.isSuccess) { + logTrace { + val timeTaken = System.currentTimeMillis() - startTime + s"Uploading block ($blockId) to $serverAddr took $timeTaken ms" + } + } else { + // Fail all blocks. + val errorMsg = + s"Failed to upload block $blockId to $serverAddr: ${future.cause.getMessage}" + logError(errorMsg, future.cause) + } + } + }) + + p.future + } + /** Close the connection. This does NOT block till the connection is closed. */ def close(): Unit = cf.channel().close() } diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala index 466ece99b9b96..5e28a07a461fa 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala @@ -19,10 +19,12 @@ package org.apache.spark.network.netty import java.util.concurrent.ConcurrentHashMap +import scala.concurrent.Promise + import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} import org.apache.spark.Logging -import org.apache.spark.network.BlockFetchingListener +import org.apache.spark.network.{BlockFetchFailureException, BlockUploadFailureException, BlockFetchingListener} /** @@ -35,15 +37,22 @@ private[netty] class BlockClientHandler extends SimpleChannelInboundHandler[ServerResponse] with Logging { /** Tracks the list of outstanding requests and their listeners on success/failure. 
*/ - private[this] val outstandingRequests: java.util.Map[String, BlockFetchingListener] = + private[this] val outstandingFetches: java.util.Map[String, BlockFetchingListener] = new ConcurrentHashMap[String, BlockFetchingListener] - def addRequest(blockId: String, listener: BlockFetchingListener): Unit = { - outstandingRequests.put(blockId, listener) + private[this] val outstandingUploads: java.util.Map[String, Promise[Unit]] = + new ConcurrentHashMap[String, Promise[Unit]] + + def addFetchRequest(blockId: String, listener: BlockFetchingListener): Unit = { + outstandingFetches.put(blockId, listener) } - def removeRequest(blockId: String): Unit = { - outstandingRequests.remove(blockId) + def removeFetchRequest(blockId: String): Unit = { + outstandingFetches.remove(blockId) + } + + def addUploadRequest(blockId: String, promise: Promise[Unit]): Unit = { + outstandingUploads.put(blockId, promise) } /** @@ -51,19 +60,26 @@ class BlockClientHandler extends SimpleChannelInboundHandler[ServerResponse] wit * uncaught exception or pre-mature connection termination. */ private def failOutstandingRequests(cause: Throwable): Unit = { - val iter = outstandingRequests.entrySet().iterator() - while (iter.hasNext) { - val entry = iter.next() + val iter1 = outstandingFetches.entrySet().iterator() + while (iter1.hasNext) { + val entry = iter1.next() entry.getValue.onBlockFetchFailure(entry.getKey, cause) } // TODO(rxin): Maybe we need to synchronize the access? Otherwise we could clear new requests // as well. But I guess that is ok given the caller will fail as soon as any requests fail. - outstandingRequests.clear() + outstandingFetches.clear() + + val iter2 = outstandingUploads.entrySet().iterator() + while (iter2.hasNext) { + val entry = iter2.next() + entry.getValue.failure(new RuntimeException(s"Failed to upload block ${entry.getKey}")) + } + outstandingUploads.clear() } override def channelUnregistered(ctx: ChannelHandlerContext): Unit = { - if (outstandingRequests.size() > 0) { - logError("Still have " + outstandingRequests.size() + " requests outstanding " + + if (outstandingFetches.size() > 0) { + logError("Still have " + outstandingFetches.size() + " requests outstanding " + s"when connection from ${ctx.channel.remoteAddress} is closed") failOutstandingRequests(new RuntimeException( s"Connection from ${ctx.channel.remoteAddress} closed")) @@ -71,7 +87,7 @@ class BlockClientHandler extends SimpleChannelInboundHandler[ServerResponse] wit } override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - if (outstandingRequests.size() > 0) { + if (outstandingFetches.size() > 0) { logError( s"Exception in connection from ${ctx.channel.remoteAddress}: ${cause.getMessage}", cause) failOutstandingRequests(cause) @@ -83,23 +99,39 @@ class BlockClientHandler extends SimpleChannelInboundHandler[ServerResponse] wit val server = ctx.channel.remoteAddress.toString response match { case BlockFetchSuccess(blockId, buf) => - val listener = outstandingRequests.get(blockId) + val listener = outstandingFetches.get(blockId) if (listener == null) { logWarning(s"Got a response for block $blockId from $server but it is not outstanding") buf.release() } else { - outstandingRequests.remove(blockId) + outstandingFetches.remove(blockId) listener.onBlockFetchSuccess(blockId, buf) buf.release() } case BlockFetchFailure(blockId, errorMsg) => - val listener = outstandingRequests.get(blockId) + val listener = outstandingFetches.get(blockId) if (listener == null) { logWarning( s"Got a response for block 
$blockId from $server ($errorMsg) but it is not outstanding") } else { - outstandingRequests.remove(blockId) - listener.onBlockFetchFailure(blockId, new RuntimeException(errorMsg)) + outstandingFetches.remove(blockId) + listener.onBlockFetchFailure(blockId, new BlockFetchFailureException(blockId, errorMsg)) + } + case BlockUploadSuccess(blockId) => + val p = outstandingUploads.get(blockId) + if (p == null) { + logWarning(s"Got a response for upload $blockId from $server but it is not outstanding") + } else { + outstandingUploads.remove(blockId) + p.success(Unit) + } + case BlockUploadFailure(blockId, error) => + val p = outstandingUploads.get(blockId) + if (p == null) { + logWarning(s"Got a response for upload $blockId from $server but it is not outstanding") + } else { + outstandingUploads.remove(blockId) + p.failure(new BlockUploadFailureException(blockId)) } } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala index c3b4d41829f4e..44687f0b770e9 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala @@ -21,6 +21,7 @@ import io.netty.channel._ import org.apache.spark.Logging import org.apache.spark.network.{ManagedBuffer, BlockDataManager} +import org.apache.spark.storage.StorageLevel /** @@ -39,13 +40,13 @@ private[netty] class BlockServerHandler(dataProvider: BlockDataManager) override def channelRead0(ctx: ChannelHandlerContext, request: ClientRequest): Unit = { request match { case BlockFetchRequest(blockIds) => - blockIds.foreach(processBlockRequest(ctx, _)) - case BlockUploadRequest(blockId, data) => - // TODO(rxin): handle upload. + blockIds.foreach(processFetchRequest(ctx, _)) + case BlockUploadRequest(blockId, data, level) => + processUploadRequest(ctx, blockId, data, level) } } // end of channelRead0 - private def processBlockRequest(ctx: ChannelHandlerContext, blockId: String): Unit = { + private def processFetchRequest(ctx: ChannelHandlerContext, blockId: String): Unit = { // A helper function to send error message back to the client. def client = ctx.channel.remoteAddress.toString @@ -90,4 +91,35 @@ private[netty] class BlockServerHandler(dataProvider: BlockDataManager) } ) } // end of processBlockRequest + + private def processUploadRequest( + ctx: ChannelHandlerContext, + blockId: String, + data: ManagedBuffer, + level: StorageLevel): Unit = { + // A helper function to send error message back to the client. 
+ def client = ctx.channel.remoteAddress.toString + + try { + dataProvider.putBlockData(blockId, data, level) + ctx.writeAndFlush(BlockUploadSuccess(blockId)).addListener(new ChannelFutureListener { + override def operationComplete(future: ChannelFuture): Unit = { + if (!future.isSuccess) { + logError(s"Error sending an ACK back to client $client") + } + } + }) + } catch { + case e: Throwable => + logError(s"Error processing uploaded block $blockId", e) + ctx.writeAndFlush(BlockUploadFailure(blockId, e.getMessage)).addListener( + new ChannelFutureListener { + override def operationComplete(future: ChannelFuture): Unit = { + if (!future.isSuccess) { + logError(s"Error sending an ACK back to client $client") + } + } + }) + } + } // end of processUploadRequest } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 14df5161cb0f3..b7f979dccd0f5 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -63,9 +63,9 @@ final class NettyBlockTransferService(conf: SparkConf) extends BlockTransferServ hostname: String, port: Int, blockId: String, - blockData: ManagedBuffer, level: StorageLevel): Future[Unit] = { - // TODO(rxin): Implement uploadBlock. - ??? + blockData: ManagedBuffer, + level: StorageLevel): Future[Unit] = { + clientFactory.createClient(hostName, port).uploadBlock(blockId, blockData, level) } override def hostName: String = { diff --git a/core/src/main/scala/org/apache/spark/network/netty/protocol.scala b/core/src/main/scala/org/apache/spark/network/netty/protocol.scala index 6a14ad26dbf21..13942f3d0adcd 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/protocol.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/protocol.scala @@ -17,6 +17,7 @@ package org.apache.spark.network.netty +import java.nio.ByteBuffer import java.util.{List => JList} import io.netty.buffer.ByteBuf @@ -25,7 +26,8 @@ import io.netty.channel.ChannelHandler.Sharable import io.netty.handler.codec._ import org.apache.spark.Logging -import org.apache.spark.network.{NettyManagedBuffer, ManagedBuffer} +import org.apache.spark.network.{NioManagedBuffer, NettyManagedBuffer, ManagedBuffer} +import org.apache.spark.storage.StorageLevel /** Messages from the client to the server. */ @@ -47,7 +49,11 @@ final case class BlockFetchRequest(blocks: Seq[String]) extends ClientRequest { * Request to upload a block to the server. Currently the server does not ack the upload request. */ private[netty] -final case class BlockUploadRequest(blockId: String, data: ManagedBuffer) extends ClientRequest { +final case class BlockUploadRequest( + blockId: String, + data: ManagedBuffer, + level: StorageLevel) + extends ClientRequest { require(blockId.length <= Byte.MaxValue) override def id = 1 } @@ -73,6 +79,20 @@ final case class BlockFetchFailure(blockId: String, error: String) extends Serve override def id = 1 } +/** Response to [[BlockUploadRequest]] when a block is successfully uploaded. */ +private[netty] +final case class BlockUploadSuccess(blockId: String) extends ServerResponse { + require(blockId.length <= Byte.MaxValue) + override def id = 2 +} + +/** Response to [[BlockUploadRequest]] when there is an error uploading the block. 
*/ +private[netty] +final case class BlockUploadFailure(blockId: String, error: String) extends ServerResponse { + require(blockId.length <= Byte.MaxValue) + override def id = 3 +} + /** * Encoder for [[ClientRequest]] used in client side. @@ -102,12 +122,12 @@ final class ClientRequestEncoder extends MessageToMessageEncoder[ClientRequest] assert(buf.writableBytes() == 0) out.add(buf) - case BlockUploadRequest(blockId, data) => + case BlockUploadRequest(blockId, data, level) => // 8 bytes: frame size // 1 byte: msg id (BlockFetchRequest vs BlockUploadRequest) // 1 byte: blockId.length // data itself (length can be derived from: frame size - 1 - blockId.length) - val headerLength = 8 + 1 + 1 + blockId.length + val headerLength = 8 + 1 + 1 + blockId.length + 5 val frameLength = headerLength + data.size val header = ctx.alloc().buffer(headerLength) @@ -118,6 +138,8 @@ final class ClientRequestEncoder extends MessageToMessageEncoder[ClientRequest] header.writeLong(frameLength) header.writeByte(in.id) ProtocolUtils.writeBlockId(header, blockId) + header.writeInt(level.toInt) + header.writeByte(level.replication) assert(header.writableBytes() == 0) out.add(header) @@ -148,8 +170,12 @@ final class ClientRequestDecoder extends MessageToMessageDecoder[ByteBuf] { case 1 => // BlockUploadRequest val blockId = ProtocolUtils.readBlockId(in) - in.retain() // retain the bytebuf so we don't recycle it immediately. - BlockUploadRequest(blockId, new NettyManagedBuffer(in)) + val level = new StorageLevel(in.readInt(), in.readByte()) + + val ret = ByteBuffer.allocate(in.readableBytes()) + ret.put(in.nioBuffer()) + ret.flip() + BlockUploadRequest(blockId, new NioManagedBuffer(ret), level) } assert(decoded.id == msgTypeId) @@ -205,6 +231,27 @@ final class ServerResponseEncoder extends MessageToMessageEncoder[ServerResponse ProtocolUtils.writeBlockId(buf, blockId) buf.writeBytes(error.getBytes) + assert(buf.writableBytes() == 0) + out.add(buf) + + case BlockUploadSuccess(blockId) => + val frameLength = 8 + 1 + 1 + blockId.length + val buf = ctx.alloc().buffer(frameLength) + buf.writeLong(frameLength) + buf.writeByte(in.id) + ProtocolUtils.writeBlockId(buf, blockId) + + assert(buf.writableBytes() == 0) + out.add(buf) + + case BlockUploadFailure(blockId, error) => + val frameLength = 8 + 1 + 1 + blockId.length + + error.length + val buf = ctx.alloc().buffer(frameLength) + buf.writeLong(frameLength) + buf.writeByte(in.id) + ProtocolUtils.writeBlockId(buf, blockId) + buf.writeBytes(error.getBytes) + assert(buf.writableBytes() == 0) out.add(buf) } @@ -228,13 +275,22 @@ final class ServerResponseDecoder extends MessageToMessageDecoder[ByteBuf] { case 0 => // BlockFetchSuccess val blockId = ProtocolUtils.readBlockId(in) in.retain() - new BlockFetchSuccess(blockId, new NettyManagedBuffer(in)) + BlockFetchSuccess(blockId, new NettyManagedBuffer(in)) case 1 => // BlockFetchFailure val blockId = ProtocolUtils.readBlockId(in) val errorBytes = new Array[Byte](in.readableBytes()) in.readBytes(errorBytes) - new BlockFetchFailure(blockId, new String(errorBytes)) + BlockFetchFailure(blockId, new String(errorBytes)) + + case 2 => // BlockUploadSuccess + BlockUploadSuccess(ProtocolUtils.readBlockId(in)) + + case 3 => // BlockUploadFailure + val blockId = ProtocolUtils.readBlockId(in) + val errorBytes = new Array[Byte](in.readableBytes()) + in.readBytes(errorBytes) + BlockUploadFailure(blockId, new String(errorBytes)) } assert(decoded.id == msgId) diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala 
b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala index 1e35abaab5353..2fc7c7d9b8312 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala @@ -42,7 +42,7 @@ class StorageLevel private( extends Externalizable { // TODO: Also add fields for caching priority, dataset ID, and flushing. - private def this(flags: Int, replication: Int) { + private[spark] def this(flags: Int, replication: Int) { this((flags & 8) != 0, (flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication) } @@ -98,6 +98,7 @@ class StorageLevel private( } override def writeExternal(out: ObjectOutput) { + /* If the wire protocol changes, please also update [[ClientRequestEncoder]] */ out.writeByte(toInt) out.writeByte(_replication) } diff --git a/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala index 7b80fe6aa364a..4c3a649081574 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala @@ -35,7 +35,7 @@ class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { /** Helper method to get num. outstanding requests from a private field using reflection. */ private def sizeOfOutstandingRequests(handler: BlockClientHandler): Int = { val f = handler.getClass.getDeclaredField( - "org$apache$spark$network$netty$BlockClientHandler$$outstandingRequests") + "org$apache$spark$network$netty$BlockClientHandler$$outstandingFetches") f.setAccessible(true) f.get(handler).asInstanceOf[java.util.Map[_, _]].size } @@ -45,7 +45,7 @@ class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { val blockData = "blahblahblahblahblah" val handler = new BlockClientHandler val listener = mock(classOf[BlockFetchingListener]) - handler.addRequest(blockId, listener) + handler.addFetchRequest(blockId, listener) assert(sizeOfOutstandingRequests(handler) === 1) val channel = new EmbeddedChannel(handler) @@ -63,7 +63,7 @@ class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { val blockId = "test_block" val handler = new BlockClientHandler val listener = mock(classOf[BlockFetchingListener]) - handler.addRequest(blockId, listener) + handler.addFetchRequest(blockId, listener) assert(sizeOfOutstandingRequests(handler) === 1) val channel = new EmbeddedChannel(handler) @@ -77,9 +77,9 @@ class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { test("clear all outstanding request upon uncaught exception") { val handler = new BlockClientHandler val listener = mock(classOf[BlockFetchingListener]) - handler.addRequest("b1", listener) - handler.addRequest("b2", listener) - handler.addRequest("b3", listener) + handler.addFetchRequest("b1", listener) + handler.addFetchRequest("b2", listener) + handler.addFetchRequest("b3", listener) assert(sizeOfOutstandingRequests(handler) === 3) val channel = new EmbeddedChannel(handler) @@ -96,9 +96,9 @@ class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { test("clear all outstanding request upon connection close") { val handler = new BlockClientHandler val listener = mock(classOf[BlockFetchingListener]) - handler.addRequest("c1", listener) - handler.addRequest("c2", listener) - handler.addRequest("c3", listener) + handler.addFetchRequest("c1", listener) + handler.addFetchRequest("c2", listener) + 
handler.addFetchRequest("c3", listener) assert(sizeOfOutstandingRequests(handler) === 3) val channel = new EmbeddedChannel(handler) diff --git a/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala index 46604ea1fb624..8d1b7276f4082 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala @@ -21,6 +21,8 @@ import io.netty.channel.embedded.EmbeddedChannel import org.scalatest.FunSuite +import org.apache.spark.api.java.StorageLevels + /** * Test client/server encoder/decoder protocol. @@ -99,11 +101,13 @@ class ProtocolSuite extends FunSuite { testClientToServer(BlockFetchRequest(Seq("b1", "b2", "b3"))) } - ignore("client to server protocol - BlockUploadRequest(\"\", new TestManagedBuffer(0))") { - testClientToServer(BlockUploadRequest("", new TestManagedBuffer(0))) + test("client to server protocol - BlockUploadRequest(\"\", new TestManagedBuffer(0))") { + testClientToServer( + BlockUploadRequest("", new TestManagedBuffer(0), StorageLevels.MEMORY_AND_DISK)) } - ignore("client to server protocol - BlockUploadRequest(\"b_upload\", new TestManagedBuffer(10))") { - testClientToServer(BlockUploadRequest("b_upload", new TestManagedBuffer(10))) + test("client to server protocol - BlockUploadRequest(\"b_upload\", new TestManagedBuffer(10))") { + testClientToServer( + BlockUploadRequest("b_upload", new TestManagedBuffer(10), StorageLevels.MEMORY_AND_DISK_2)) } } From 407e59afd3cb7385af9f63dc2263a40c7c21d783 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 29 Sep 2014 19:37:28 -0700 Subject: [PATCH 24/46] Fix style violation. --- .../scala/org/apache/spark/network/netty/BlockClient.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala index c77a7ae1ccb0f..6bdbf88d337ce 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala @@ -90,7 +90,8 @@ class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Closea }) } - def uploadBlock(blockId: String, data: ManagedBuffer, storageLevel: StorageLevel): Future[Unit] = { + def uploadBlock(blockId: String, data: ManagedBuffer, storageLevel: StorageLevel): Future[Unit] = + { var startTime: Long = 0 logTrace { startTime = System.currentTimeMillis() From f6c220df8406be14fbdb7270682727e1085518a4 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 29 Sep 2014 23:30:17 -0700 Subject: [PATCH 25/46] Merge with latest master. 
--- .../org/apache/spark/network/nio/NioBlockTransferService.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index 3d72155f8db8d..e942b43d9cc4a 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -201,10 +201,9 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa private def getBlock(blockId: String): ByteBuffer = { val startTimeMs = System.currentTimeMillis() logDebug("GetBlock " + blockId + " started from " + startTimeMs) - // TODO(rxin): propagate error back to the client? val buffer = blockDataManager.getBlockData(blockId) logDebug("GetBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) + " and got buffer " + buffer) - if (buffer == null) null else buffer.nioByteBuffer() + buffer.nioByteBuffer() } } From 5d98ce3de1deeeb7fbdc26b9303a591c46f1892b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 30 Sep 2014 00:56:32 -0700 Subject: [PATCH 26/46] Flip buffer. --- .../scala/org/apache/spark/network/BlockTransferService.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index 2a0a1a0bc0a14..d3ed683c7e880 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -94,6 +94,7 @@ abstract class BlockTransferService extends Closeable { lock.synchronized { val ret = ByteBuffer.allocate(data.size.toInt) ret.put(data.nioByteBuffer()) + ret.flip() result = Left(new NioManagedBuffer(ret)) lock.notify() } From f7e7568414692989215d97abce9dda2fe172abb4 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 30 Sep 2014 12:28:21 -0700 Subject: [PATCH 27/46] Fixed spark.shuffle.io.receiveBuffer setting. --- .../main/scala/org/apache/spark/network/netty/NettyConfig.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala index d5078e417d6d2..7c3074e939794 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala @@ -53,7 +53,7 @@ class NettyConfig(conf: SparkConf) { * buffer size should be ~ 1.25MB */ private[netty] val receiveBuf: Option[Int] = - conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt) + conf.getOption("spark.shuffle.io.receiveBuffer").map(_.toInt) /** Send buffer size (SO_SNDBUF). */ private[netty] val sendBuf: Option[Int] = From 29c6dcfaacb2e8b1f0582c6d5e435349c52e29af Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Sun, 5 Oct 2014 17:58:43 -0700 Subject: [PATCH 28/46] [SPARK-3453] Netty-based BlockTransferService, extracted from Spark core This PR encapsulates #2330, which is itself a continuation of #2240. The first goal of this PR is to provide an alternate, simpler implementation of the ConnectionManager which is based on Netty. 
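For reference, a minimal sketch of how the new service would be enabled (illustrative only;
the configuration key and the "netty"/"nio" values come from the SparkEnv change further
down in this patch):

    import org.apache.spark.SparkConf

    // "netty" selects the new transfer service; "nio" keeps the existing
    // ConnectionManager-based path.
    val conf = new SparkConf()
      .set("spark.shuffle.blockTransferService", "netty")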
In addition to this goal, however, we want to resolve [SPARK-3796](https://issues.apache.org/jira/browse/SPARK-3796), which calls for a standalone shuffle service which can be integrated into the YARN NodeManager, Standalone Worker, or on its own. This PR makes the first step in this direction by ensuring that the actual Netty service is as small as possible and extracted from Spark core. Given this, we should be able to construct this standalone jar which can be included in other JVMs without incurring significant dependency or runtime issues. The actual work to ensure that such a standalone shuffle service would work in Spark will be left for a future PR, however. In order to minimize dependencies and allow for the service to be long-running (possibly much longer-running than Spark, and possibly having to support multiple version of Spark simultaneously), the entire service has been ported to Java, where we have full control over the binary compatibility of the components and do not depend on the Scala runtime or version. These PRs have been addressed by folding in #2330: SPARK-3453: Refactor Netty module to use BlockTransferService interface SPARK-3018: Release all buffers upon task completion/failure SPARK-3002: Create a connection pool and reuse clients across different threads SPARK-3017: Integration tests and unit tests for connection failures SPARK-3049: Make sure client doesn't block when server/connection has error(s) SPARK-3502: SO_RCVBUF and SO_SNDBUF should be bootstrap childOption, not option SPARK-3503: Disable thread local cache in PooledByteBufAllocator TODO before mergeable: [ ] Implement uploadBlock() [ ] Unit tests for RPC side of code [ ] Performance testing [ ] Turn OFF by default (currently on for unit testing) --- core/pom.xml | 5 + .../scala/org/apache/spark/SparkEnv.scala | 17 +- .../spark/network/BlockDataManager.scala | 8 +- .../spark/network/BlockFetchingListener.scala | 2 + .../spark/network/BlockTransferService.scala | 13 +- .../apache/spark/network/ManagedBuffer.scala | 187 ---------- .../spark/network/netty/BlockClient.scala | 125 ------- .../network/netty/BlockClientFactory.scala | 175 ---------- .../network/netty/BlockClientHandler.scala | 138 -------- .../spark/network/netty/BlockServer.scala | 127 ------- .../network/netty/BlockServerHandler.scala | 125 ------- .../network/netty/NettyBlockFetcher.scala | 92 +++++ .../network/netty/NettyBlockRpcServer.scala | 59 ++++ .../netty/NettyBlockTransferService.scala | 69 ++-- .../apache/spark/network/netty/protocol.scala | 326 ------------------ .../network/nio/NioBlockTransferService.scala | 18 +- .../shuffle/FileShuffleBlockManager.scala | 6 +- .../shuffle/IndexShuffleBlockManager.scala | 2 +- .../spark/shuffle/ShuffleBlockManager.scala | 3 +- .../apache/spark/storage/BlockManager.scala | 26 +- .../storage/BlockNotFoundException.scala | 1 - .../storage/ShuffleBlockFetcherIterator.scala | 11 +- .../apache/spark/storage/StorageLevel.scala | 3 +- .../netty/BlockClientFactorySuite.scala | 91 ----- .../netty/BlockClientHandlerSuite.scala | 114 ------ .../spark/network/netty/ProtocolSuite.scala | 113 ------ .../netty/ServerClientIntegrationSuite.scala | 174 ---------- .../network/netty/TestManagedBuffer.scala | 72 ---- .../hash/HashShuffleManagerSuite.scala | 8 +- .../ShuffleBlockFetcherIteratorSuite.scala | 3 +- network/common/pom.xml | 94 +++++ .../buffer/FileSegmentManagedBuffer.java | 146 ++++++++ .../spark/network/buffer/ManagedBuffer.java | 70 ++++ .../network/buffer/NettyManagedBuffer.java | 76 ++++ 
.../network/buffer/NioManagedBuffer.java | 75 ++++ .../client/ChunkFetchFailureException.java | 37 ++ .../network/client/ChunkReceivedCallback.java | 47 +++ .../network/client/RpcResponseCallback.java | 30 ++ .../spark/network/client/SluiceClient.java | 161 +++++++++ .../network/client/SluiceClientFactory.java | 173 ++++++++++ .../network/client/SluiceClientHandler.java | 155 +++++++++ .../spark/network/protocol/Encodable.java | 35 ++ .../spark/network/protocol/StreamChunkId.java | 73 ++++ .../protocol/request/ChunkFetchRequest.java | 68 ++++ .../protocol/request/ClientRequest.java | 58 ++++ .../request/ClientRequestDecoder.java | 57 +++ .../request/ClientRequestEncoder.java | 46 +++ .../network/protocol/request/RpcRequest.java | 81 +++++ .../protocol/response/ChunkFetchFailure.java | 78 +++++ .../protocol/response/ChunkFetchSuccess.java | 82 +++++ .../network/protocol/response/RpcFailure.java | 73 ++++ .../protocol/response/RpcResponse.java | 72 ++++ .../protocol/response/ServerResponse.java | 63 ++++ .../response/ServerResponseDecoder.java | 60 ++++ .../response/ServerResponseEncoder.java | 74 ++++ .../network/server/DefaultStreamManager.java | 87 +++++ .../spark/network/server/RpcHandler.java | 31 ++ .../spark/network/server/SluiceServer.java | 124 +++++++ .../network/server/SluiceServerHandler.java | 153 ++++++++ .../spark/network/server/StreamManager.java | 52 +++ .../spark/network/util/ConfigProvider.java | 52 +++ .../network/util/DefaultConfigProvider.java | 32 ++ .../org/apache/spark/network/util/IOMode.java | 27 ++ .../apache/spark/network/util/JavaUtils.java | 19 +- .../apache/spark/network/util/NettyUtils.java | 109 ++++++ .../spark/network/util/SluiceConfig.java | 38 +- .../spark/network/IntegrationSuite.java | 217 ++++++++++++ .../apache/spark/network/NoOpRpcHandler.java | 26 ++ .../apache/spark/network/ProtocolSuite.java | 84 +++++ .../network/SluiceClientFactorySuite.java | 101 ++++++ .../network/SluiceClientHandlerSuite.java | 90 +++++ .../spark/network/TestManagedBuffer.java | 104 ++++++ .../org/apache/spark/network/TestUtils.java | 30 ++ pom.xml | 1 + project/MimaExcludes.scala | 2 - .../streaming/scheduler/ReceiverTracker.scala | 2 +- 76 files changed, 3579 insertions(+), 1899 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/protocol.scala delete mode 100644 core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala delete mode 100644 
core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala create mode 100644 network/common/pom.xml create mode 100644 network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java create mode 100644 network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java create mode 100644 network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java create mode 100644 network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java create mode 100644 network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java create mode 100644 network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java create mode 100644 network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java create mode 100644 network/common/src/main/java/org/apache/spark/network/client/SluiceClient.java create mode 100644 network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java create mode 100644 network/common/src/main/java/org/apache/spark/network/client/SluiceClientHandler.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/request/ChunkFetchRequest.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequest.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestDecoder.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestEncoder.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/request/RpcRequest.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchFailure.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchSuccess.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/response/RpcFailure.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/response/RpcResponse.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponse.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseDecoder.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseEncoder.java create mode 100644 network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java create mode 100644 network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java create mode 100644 network/common/src/main/java/org/apache/spark/network/server/SluiceServer.java create mode 100644 network/common/src/main/java/org/apache/spark/network/server/SluiceServerHandler.java create mode 100644 network/common/src/main/java/org/apache/spark/network/server/StreamManager.java create mode 100644 network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java create mode 100644 network/common/src/main/java/org/apache/spark/network/util/DefaultConfigProvider.java create mode 100644 network/common/src/main/java/org/apache/spark/network/util/IOMode.java rename 
core/src/main/scala/org/apache/spark/network/exceptions.scala => network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java (65%) create mode 100644 network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java rename core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala => network/common/src/main/java/org/apache/spark/network/util/SluiceConfig.java (58%) create mode 100644 network/common/src/test/java/org/apache/spark/network/IntegrationSuite.java create mode 100644 network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java create mode 100644 network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java create mode 100644 network/common/src/test/java/org/apache/spark/network/SluiceClientFactorySuite.java create mode 100644 network/common/src/test/java/org/apache/spark/network/SluiceClientHandlerSuite.java create mode 100644 network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java create mode 100644 network/common/src/test/java/org/apache/spark/network/TestUtils.java diff --git a/core/pom.xml b/core/pom.xml index a5a178079bc57..aff0d989d01bb 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -44,6 +44,11 @@ + + org.apache.spark + network + ${project.version} + net.java.dev.jets3t jets3t diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 373ce795a309e..867173e04714e 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -32,7 +32,7 @@ import org.apache.spark.api.python.PythonWorkerFactory import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.BlockTransferService -import org.apache.spark.network.netty.NettyBlockTransferService +import org.apache.spark.network.netty.{NettyBlockTransferService} import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.Serializer @@ -40,7 +40,6 @@ import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ import org.apache.spark.util.{AkkaUtils, Utils} - /** * :: DeveloperApi :: * Holds all the runtime environment objects for a running Spark instance (either master or worker), @@ -233,12 +232,14 @@ object SparkEnv extends Logging { val shuffleMemoryManager = new ShuffleMemoryManager(conf) - // TODO(rxin): Config option based on class name, similar to shuffle mgr and compression codec. - val blockTransferService = if (conf.getBoolean("spark.shuffle.use.netty", false)) { - new NettyBlockTransferService(conf) - } else { - new NioBlockTransferService(conf, securityManager) - } + // TODO: This is only netty by default for initial testing -- it should not be merged as such!!! 
+ val blockTransferService = + conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match { + case "netty" => + new NettyBlockTransferService(conf) + case "nio" => + new NioBlockTransferService(conf, securityManager) + } val blockManagerMaster = new BlockManagerMaster(registerOrLookup( "BlockManagerMaster", diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index 0eeffe0e7c5e6..1745d52c81923 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -17,8 +17,8 @@ package org.apache.spark.network -import org.apache.spark.storage.StorageLevel - +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.storage.{BlockId, StorageLevel} private[spark] trait BlockDataManager { @@ -27,10 +27,10 @@ trait BlockDataManager { * Interface to get local block data. Throws an exception if the block cannot be found or * cannot be read successfully. */ - def getBlockData(blockId: String): ManagedBuffer + def getBlockData(blockId: BlockId): ManagedBuffer /** * Put the block locally, using the given storage level. */ - def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit + def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit } diff --git a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala index dd70e26647939..e35fdb4e95899 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala @@ -19,6 +19,8 @@ package org.apache.spark.network import java.util.EventListener +import org.apache.spark.network.buffer.ManagedBuffer + /** * Listener callback interface for [[BlockTransferService.fetchBlocks]]. 
diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index d3ed683c7e880..8287a0fc81cfe 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -18,16 +18,18 @@ package org.apache.spark.network import java.io.Closeable -import java.nio.ByteBuffer + +import org.apache.spark.network.buffer.ManagedBuffer import scala.concurrent.{Await, Future} import scala.concurrent.duration.Duration +import org.apache.spark.Logging import org.apache.spark.storage.StorageLevel - +import org.apache.spark.util.Utils private[spark] -abstract class BlockTransferService extends Closeable { +abstract class BlockTransferService extends Closeable with Logging { /** * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch @@ -92,10 +94,7 @@ abstract class BlockTransferService extends Closeable { } override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { lock.synchronized { - val ret = ByteBuffer.allocate(data.size.toInt) - ret.put(data.nioByteBuffer()) - ret.flip() - result = Left(new NioManagedBuffer(ret)) + result = Left(data) lock.notify() } } diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala deleted file mode 100644 index dd808d2500fbc..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network - -import java.io._ -import java.nio.ByteBuffer -import java.nio.channels.FileChannel -import java.nio.channels.FileChannel.MapMode - -import scala.util.Try - -import com.google.common.io.ByteStreams -import io.netty.buffer.{Unpooled, ByteBufInputStream, ByteBuf} -import io.netty.channel.DefaultFileRegion - -import org.apache.spark.util.{ByteBufferInputStream, Utils} - - -/** - * This interface provides an immutable view for data in the form of bytes. The implementation - * should specify how the data is provided: - * - * - [[FileSegmentManagedBuffer]]: data backed by part of a file - * - [[NioManagedBuffer]]: data backed by a NIO ByteBuffer - * - [[NettyManagedBuffer]]: data backed by a Netty ByteBuf - * - * The concrete buffer implementation might be managed outside the JVM garbage collector. - * For example, in the case of [[NettyManagedBuffer]], the buffers are reference counted. - * In that case, if the buffer is going to be passed around to a different thread, retain/release - * should be called. 
- */ -private[spark] -abstract class ManagedBuffer { - // Note that all the methods are defined with parenthesis because their implementations can - // have side effects (io operations). - - /** Number of bytes of the data. */ - def size: Long - - /** - * Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the - * returned ByteBuffer should not affect the content of this buffer. - */ - def nioByteBuffer(): ByteBuffer - - /** - * Exposes this buffer's data as an InputStream. The underlying implementation does not - * necessarily check for the length of bytes read, so the caller is responsible for making sure - * it does not go over the limit. - */ - def inputStream(): InputStream - - /** - * Increment the reference count by one if applicable. - */ - def retain(): this.type - - /** - * If applicable, decrement the reference count by one and deallocates the buffer if the - * reference count reaches zero. - */ - def release(): this.type - - /** - * Convert the buffer into an Netty object, used to write the data out. - */ - private[network] def convertToNetty(): AnyRef -} - - -/** - * A [[ManagedBuffer]] backed by a segment in a file - */ -private[spark] -final class FileSegmentManagedBuffer(val file: File, val offset: Long, val length: Long) - extends ManagedBuffer { - - override def size: Long = length - - override def nioByteBuffer(): ByteBuffer = { - var channel: FileChannel = null - try { - channel = new RandomAccessFile(file, "r").getChannel - channel.map(MapMode.READ_ONLY, offset, length) - } catch { - case e: IOException => - Try(channel.size).toOption match { - case Some(fileLen) => - throw new IOException(s"Error in reading $this (actual file length $fileLen)", e) - case None => - throw new IOException(s"Error in opening $this", e) - } - } finally { - if (channel != null) { - Utils.tryLog(channel.close()) - } - } - } - - override def inputStream(): InputStream = { - var is: FileInputStream = null - try { - is = new FileInputStream(file) - is.skip(offset) - ByteStreams.limit(is, length) - } catch { - case e: IOException => - if (is != null) { - Utils.tryLog(is.close()) - } - Try(file.length).toOption match { - case Some(fileLen) => - throw new IOException(s"Error in reading $this (actual file length $fileLen)", e) - case None => - throw new IOException(s"Error in opening $this", e) - } - case e: Throwable => - if (is != null) { - Utils.tryLog(is.close()) - } - throw e - } - } - - override def toString: String = s"${getClass.getName}($file, $offset, $length)" -} - - -/** - * A [[ManagedBuffer]] backed by [[java.nio.ByteBuffer]]. - */ -private[spark] -final class NioManagedBuffer(buf: ByteBuffer) extends ManagedBuffer { - - override def size: Long = buf.remaining() - - override def nioByteBuffer() = buf.duplicate() - - override def inputStream() = new ByteBufferInputStream(buf) - - private[network] override def convertToNetty(): AnyRef = Unpooled.wrappedBuffer(buf) - - // [[ByteBuffer]] is managed by the JVM garbage collector itself. - override def retain(): this.type = this - override def release(): this.type = this -} - - -/** - * A [[ManagedBuffer]] backed by a Netty [[ByteBuf]]. 
- */ -private[spark] -final class NettyManagedBuffer(buf: ByteBuf) extends ManagedBuffer { - - override def size: Long = buf.readableBytes() - - override def nioByteBuffer() = buf.nioBuffer() - - override def inputStream() = new ByteBufInputStream(buf) - - private[network] override def convertToNetty(): AnyRef = buf - - override def retain(): this.type = { - buf.retain() - this - } - - override def release(): this.type = { - buf.release() - this - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala deleted file mode 100644 index 6bdbf88d337ce..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import java.io.Closeable -import java.util.concurrent.TimeoutException - -import scala.concurrent.{Future, promise} - -import io.netty.channel.{ChannelFuture, ChannelFutureListener} - -import org.apache.spark.Logging -import org.apache.spark.network.{ManagedBuffer, BlockFetchingListener} -import org.apache.spark.storage.StorageLevel - - -/** - * Client for [[NettyBlockTransferService]]. The connection to server must have been established - * using [[BlockClientFactory]] before instantiating this. - * - * This class is used to make requests to the server , while [[BlockClientHandler]] is responsible - * for handling responses from the server. - * - * Concurrency: thread safe and can be called from multiple threads. - * - * @param cf the ChannelFuture for the connection. - * @param handler [[BlockClientHandler]] for handling outstanding requests. - */ -@throws[TimeoutException] -private[netty] -class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Closeable with Logging { - - private[this] val serverAddr = cf.channel().remoteAddress().toString - - def isActive: Boolean = cf.channel().isActive - - /** - * Ask the remote server for a sequence of blocks, and execute the callback. - * - * Note that this is asynchronous and returns immediately. Upstream caller should throttle the - * rate of fetching; otherwise we could run out of memory due to large outstanding fetches. - * - * @param blockIds sequence of block ids to fetch. - * @param listener callback to fire on fetch success / failure. 
- */ - def fetchBlocks(blockIds: Seq[String], listener: BlockFetchingListener): Unit = { - var startTime: Long = 0 - logTrace { - startTime = System.currentTimeMillis() - s"Sending request $blockIds to $serverAddr" - } - - blockIds.foreach { blockId => - handler.addFetchRequest(blockId, listener) - } - - cf.channel().writeAndFlush(BlockFetchRequest(blockIds)).addListener(new ChannelFutureListener { - override def operationComplete(future: ChannelFuture): Unit = { - if (future.isSuccess) { - logTrace { - val timeTaken = System.currentTimeMillis() - startTime - s"Sending request $blockIds to $serverAddr took $timeTaken ms" - } - } else { - // Fail all blocks. - val errorMsg = - s"Failed to send request $blockIds to $serverAddr: ${future.cause.getMessage}" - logError(errorMsg, future.cause) - blockIds.foreach { blockId => - handler.removeFetchRequest(blockId) - listener.onBlockFetchFailure(blockId, new RuntimeException(errorMsg)) - } - } - } - }) - } - - def uploadBlock(blockId: String, data: ManagedBuffer, storageLevel: StorageLevel): Future[Unit] = - { - var startTime: Long = 0 - logTrace { - startTime = System.currentTimeMillis() - s"Uploading block ($blockId) to $serverAddr" - } - val f = cf.channel().writeAndFlush(new BlockUploadRequest(blockId, data, storageLevel)) - - val p = promise[Unit]() - handler.addUploadRequest(blockId, p) - f.addListener(new ChannelFutureListener { - override def operationComplete(future: ChannelFuture): Unit = { - if (future.isSuccess) { - logTrace { - val timeTaken = System.currentTimeMillis() - startTime - s"Uploading block ($blockId) to $serverAddr took $timeTaken ms" - } - } else { - // Fail all blocks. - val errorMsg = - s"Failed to upload block $blockId to $serverAddr: ${future.cause.getMessage}" - logError(errorMsg, future.cause) - } - } - }) - - p.future - } - - /** Close the connection. This does NOT block till the connection is closed. */ - def close(): Unit = cf.channel().close() -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala deleted file mode 100644 index 8021cfdf42d1a..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.netty - -import java.io.Closeable -import java.util.concurrent.{ConcurrentHashMap, TimeoutException} - -import io.netty.bootstrap.Bootstrap -import io.netty.buffer.PooledByteBufAllocator -import io.netty.channel._ -import io.netty.channel.epoll.{Epoll, EpollEventLoopGroup, EpollSocketChannel} -import io.netty.channel.nio.NioEventLoopGroup -import io.netty.channel.socket.SocketChannel -import io.netty.channel.socket.nio.NioSocketChannel -import io.netty.util.internal.PlatformDependent - -import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.util.Utils - - -/** - * Factory for creating [[BlockClient]] by using createClient. - * - * The factory maintains a connection pool to other hosts and should return the same [[BlockClient]] - * for the same remote host. It also shares a single worker thread pool for all [[BlockClient]]s. - */ -private[netty] -class BlockClientFactory(val conf: NettyConfig) extends Logging with Closeable { - - def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf)) - - /** A thread factory so the threads are named (for debugging). */ - private[this] val threadFactory = Utils.namedThreadFactory("spark-netty-client") - - /** Socket channel type, initialized by [[init]] depending ioMode. */ - private[this] var socketChannelClass: Class[_ <: Channel] = _ - - /** Thread pool shared by all clients. */ - private[this] var workerGroup: EventLoopGroup = _ - - private[this] val connectionPool = new ConcurrentHashMap[(String, Int), BlockClient] - - // The encoders are stateless and can be shared among multiple clients. - private[this] val encoder = new ClientRequestEncoder - private[this] val decoder = new ServerResponseDecoder - - init() - - /** Initialize [[socketChannelClass]] and [[workerGroup]] based on ioMode. */ - private def init(): Unit = { - def initNio(): Unit = { - socketChannelClass = classOf[NioSocketChannel] - workerGroup = new NioEventLoopGroup(conf.clientThreads, threadFactory) - } - def initEpoll(): Unit = { - socketChannelClass = classOf[EpollSocketChannel] - workerGroup = new EpollEventLoopGroup(conf.clientThreads, threadFactory) - } - - // For auto mode, first try epoll (only available on Linux), then nio. - conf.ioMode match { - case "nio" => initNio() - case "epoll" => initEpoll() - case "auto" => if (Epoll.isAvailable) initEpoll() else initNio() - } - } - - /** - * Create a new BlockFetchingClient connecting to the given remote host / port. - * - * This blocks until a connection is successfully established. - * - * Concurrency: This method is safe to call from multiple threads. - */ - def createClient(remoteHost: String, remotePort: Int): BlockClient = { - // Get connection from the connection pool first. - // If it is not found or not active, create a new one. - val cachedClient = connectionPool.get((remoteHost, remotePort)) - if (cachedClient != null && cachedClient.isActive) { - return cachedClient - } - - logDebug(s"Creating new connection to $remoteHost:$remotePort") - - // There is a chance two threads are creating two different clients connecting to the same host. - // But that's probably ok ... 
- - val handler = new BlockClientHandler - - val bootstrap = new Bootstrap - bootstrap.group(workerGroup) - .channel(socketChannelClass) - // Disable Nagle's Algorithm since we don't want packets to wait - .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE) - .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE) - .option[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectTimeoutMs) - - // Use pooled buffers to reduce temporary buffer allocation - bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator()) - - bootstrap.handler(new ChannelInitializer[SocketChannel] { - override def initChannel(ch: SocketChannel): Unit = { - ch.pipeline - .addLast("clientRequestEncoder", encoder) - .addLast("frameDecoder", ProtocolUtils.createFrameDecoder()) - .addLast("serverResponseDecoder", decoder) - .addLast("handler", handler) - } - }) - - // Connect to the remote server - val cf: ChannelFuture = bootstrap.connect(remoteHost, remotePort) - if (!cf.awaitUninterruptibly(conf.connectTimeoutMs)) { - throw new TimeoutException( - s"Connecting to $remoteHost:$remotePort timed out (${conf.connectTimeoutMs} ms)") - } - - val client = new BlockClient(cf, handler) - connectionPool.put((remoteHost, remotePort), client) - client - } - - /** Close all connections in the connection pool, and shutdown the worker thread pool. */ - override def close(): Unit = { - val iter = connectionPool.entrySet().iterator() - while (iter.hasNext) { - val entry = iter.next() - entry.getValue.close() - connectionPool.remove(entry.getKey) - } - - if (workerGroup != null) { - workerGroup.shutdownGracefully() - } - } - - /** - * Create a pooled ByteBuf allocator but disables the thread-local cache. Thread-local caches - * are disabled because the ByteBufs are allocated by the event loop thread, but released by the - * executor thread rather than the event loop thread. Those thread-local caches actually delay - * the recycling of buffers, leading to larger memory usage. - */ - private def createPooledByteBufAllocator(): PooledByteBufAllocator = { - def getPrivateStaticField(name: String): Int = { - val f = PooledByteBufAllocator.DEFAULT.getClass.getDeclaredField(name) - f.setAccessible(true) - f.getInt(null) - } - new PooledByteBufAllocator( - PlatformDependent.directBufferPreferred(), - getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"), - getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"), - getPrivateStaticField("DEFAULT_PAGE_SIZE"), - getPrivateStaticField("DEFAULT_MAX_ORDER"), - 0, // tinyCacheSize - 0, // smallCacheSize - 0 // normalCacheSize - ) - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala deleted file mode 100644 index 5e28a07a461fa..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import java.util.concurrent.ConcurrentHashMap - -import scala.concurrent.Promise - -import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} - -import org.apache.spark.Logging -import org.apache.spark.network.{BlockFetchFailureException, BlockUploadFailureException, BlockFetchingListener} - - -/** - * Handler that processes server responses, in response to requests issued from [[BlockClient]]. - * It works by tracking the list of outstanding requests (and their callbacks). - * - * Concurrency: thread safe and can be called from multiple threads. - */ -private[netty] -class BlockClientHandler extends SimpleChannelInboundHandler[ServerResponse] with Logging { - - /** Tracks the list of outstanding requests and their listeners on success/failure. */ - private[this] val outstandingFetches: java.util.Map[String, BlockFetchingListener] = - new ConcurrentHashMap[String, BlockFetchingListener] - - private[this] val outstandingUploads: java.util.Map[String, Promise[Unit]] = - new ConcurrentHashMap[String, Promise[Unit]] - - def addFetchRequest(blockId: String, listener: BlockFetchingListener): Unit = { - outstandingFetches.put(blockId, listener) - } - - def removeFetchRequest(blockId: String): Unit = { - outstandingFetches.remove(blockId) - } - - def addUploadRequest(blockId: String, promise: Promise[Unit]): Unit = { - outstandingUploads.put(blockId, promise) - } - - /** - * Fire the failure callback for all outstanding requests. This is called when we have an - * uncaught exception or pre-mature connection termination. - */ - private def failOutstandingRequests(cause: Throwable): Unit = { - val iter1 = outstandingFetches.entrySet().iterator() - while (iter1.hasNext) { - val entry = iter1.next() - entry.getValue.onBlockFetchFailure(entry.getKey, cause) - } - // TODO(rxin): Maybe we need to synchronize the access? Otherwise we could clear new requests - // as well. But I guess that is ok given the caller will fail as soon as any requests fail. 
- outstandingFetches.clear() - - val iter2 = outstandingUploads.entrySet().iterator() - while (iter2.hasNext) { - val entry = iter2.next() - entry.getValue.failure(new RuntimeException(s"Failed to upload block ${entry.getKey}")) - } - outstandingUploads.clear() - } - - override def channelUnregistered(ctx: ChannelHandlerContext): Unit = { - if (outstandingFetches.size() > 0) { - logError("Still have " + outstandingFetches.size() + " requests outstanding " + - s"when connection from ${ctx.channel.remoteAddress} is closed") - failOutstandingRequests(new RuntimeException( - s"Connection from ${ctx.channel.remoteAddress} closed")) - } - } - - override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - if (outstandingFetches.size() > 0) { - logError( - s"Exception in connection from ${ctx.channel.remoteAddress}: ${cause.getMessage}", cause) - failOutstandingRequests(cause) - } - ctx.close() - } - - override def channelRead0(ctx: ChannelHandlerContext, response: ServerResponse) { - val server = ctx.channel.remoteAddress.toString - response match { - case BlockFetchSuccess(blockId, buf) => - val listener = outstandingFetches.get(blockId) - if (listener == null) { - logWarning(s"Got a response for block $blockId from $server but it is not outstanding") - buf.release() - } else { - outstandingFetches.remove(blockId) - listener.onBlockFetchSuccess(blockId, buf) - buf.release() - } - case BlockFetchFailure(blockId, errorMsg) => - val listener = outstandingFetches.get(blockId) - if (listener == null) { - logWarning( - s"Got a response for block $blockId from $server ($errorMsg) but it is not outstanding") - } else { - outstandingFetches.remove(blockId) - listener.onBlockFetchFailure(blockId, new BlockFetchFailureException(blockId, errorMsg)) - } - case BlockUploadSuccess(blockId) => - val p = outstandingUploads.get(blockId) - if (p == null) { - logWarning(s"Got a response for upload $blockId from $server but it is not outstanding") - } else { - outstandingUploads.remove(blockId) - p.success(Unit) - } - case BlockUploadFailure(blockId, error) => - val p = outstandingUploads.get(blockId) - if (p == null) { - logWarning(s"Got a response for upload $blockId from $server but it is not outstanding") - } else { - outstandingUploads.remove(blockId) - p.failure(new BlockUploadFailureException(blockId)) - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala deleted file mode 100644 index e2eb7c379f14d..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.netty - -import java.io.Closeable -import java.net.InetSocketAddress - -import io.netty.bootstrap.ServerBootstrap -import io.netty.buffer.PooledByteBufAllocator -import io.netty.channel.epoll.{Epoll, EpollEventLoopGroup, EpollServerSocketChannel} -import io.netty.channel.nio.NioEventLoopGroup -import io.netty.channel.socket.SocketChannel -import io.netty.channel.socket.nio.NioServerSocketChannel -import io.netty.channel.{ChannelInitializer, ChannelFuture, ChannelOption} - -import org.apache.spark.Logging -import org.apache.spark.network.BlockDataManager -import org.apache.spark.util.Utils - - -/** - * Server for the [[NettyBlockTransferService]]. - */ -private[netty] -class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) - extends Closeable with Logging { - - def port: Int = _port - - def hostName: String = _hostName - - private var _port: Int = conf.serverPort - private var _hostName: String = "" - private var bootstrap: ServerBootstrap = _ - private var channelFuture: ChannelFuture = _ - - init() - - /** Initialize the server. */ - private def init(): Unit = { - bootstrap = new ServerBootstrap - val threadFactory = Utils.namedThreadFactory("spark-netty-server") - - // Use only one thread to accept connections, and 2 * num_cores for worker. - def initNio(): Unit = { - val bossGroup = new NioEventLoopGroup(conf.serverThreads, threadFactory) - val workerGroup = bossGroup - bootstrap.group(bossGroup, workerGroup).channel(classOf[NioServerSocketChannel]) - } - def initEpoll(): Unit = { - val bossGroup = new EpollEventLoopGroup(conf.serverThreads, threadFactory) - val workerGroup = bossGroup - bootstrap.group(bossGroup, workerGroup).channel(classOf[EpollServerSocketChannel]) - } - - conf.ioMode match { - case "nio" => initNio() - case "epoll" => initEpoll() - case "auto" => if (Epoll.isAvailable) initEpoll() else initNio() - } - - // Use pooled buffers to reduce temporary buffer allocation - bootstrap.option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) - bootstrap.childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) - - // Various (advanced) user-configured settings. - conf.backLog.foreach { backLog => - bootstrap.option[java.lang.Integer](ChannelOption.SO_BACKLOG, backLog) - } - conf.receiveBuf.foreach { receiveBuf => - bootstrap.childOption[java.lang.Integer](ChannelOption.SO_RCVBUF, receiveBuf) - } - conf.sendBuf.foreach { sendBuf => - bootstrap.childOption[java.lang.Integer](ChannelOption.SO_SNDBUF, sendBuf) - } - - bootstrap.childHandler(new ChannelInitializer[SocketChannel] { - override def initChannel(ch: SocketChannel): Unit = { - ch.pipeline - .addLast("frameDecoder", ProtocolUtils.createFrameDecoder()) - .addLast("clientRequestDecoder", new ClientRequestDecoder) - .addLast("serverResponseEncoder", new ServerResponseEncoder) - .addLast("handler", new BlockServerHandler(dataProvider)) - } - }) - - channelFuture = bootstrap.bind(new InetSocketAddress(_port)) - channelFuture.sync() - - val addr = channelFuture.channel.localAddress.asInstanceOf[InetSocketAddress] - _port = addr.getPort - // _hostName = addr.getHostName - _hostName = Utils.localHostName() - - logInfo(s"Server started ${_hostName}:${_port}") - } - - /** Shutdown the server. 
*/ - def close(): Unit = { - if (channelFuture != null) { - channelFuture.channel().close().awaitUninterruptibly() - channelFuture = null - } - if (bootstrap != null && bootstrap.group() != null) { - bootstrap.group().shutdownGracefully() - } - if (bootstrap != null && bootstrap.childGroup() != null) { - bootstrap.childGroup().shutdownGracefully() - } - bootstrap = null - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala deleted file mode 100644 index 44687f0b770e9..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import io.netty.channel._ - -import org.apache.spark.Logging -import org.apache.spark.network.{ManagedBuffer, BlockDataManager} -import org.apache.spark.storage.StorageLevel - - -/** - * A handler that processes requests from clients and writes block data back. - * - * The messages should have been processed by the pipeline setup by BlockServerChannelInitializer. - */ -private[netty] class BlockServerHandler(dataProvider: BlockDataManager) - extends SimpleChannelInboundHandler[ClientRequest] with Logging { - - override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - logError(s"Exception in connection from ${ctx.channel.remoteAddress}", cause) - ctx.close() - } - - override def channelRead0(ctx: ChannelHandlerContext, request: ClientRequest): Unit = { - request match { - case BlockFetchRequest(blockIds) => - blockIds.foreach(processFetchRequest(ctx, _)) - case BlockUploadRequest(blockId, data, level) => - processUploadRequest(ctx, blockId, data, level) - } - } // end of channelRead0 - - private def processFetchRequest(ctx: ChannelHandlerContext, blockId: String): Unit = { - // A helper function to send error message back to the client. - def client = ctx.channel.remoteAddress.toString - - def respondWithError(error: String): Unit = { - ctx.writeAndFlush(new BlockFetchFailure(blockId, error)).addListener( - new ChannelFutureListener { - override def operationComplete(future: ChannelFuture) { - if (!future.isSuccess) { - // TODO: Maybe log the success case as well. - logError(s"Error sending error back to $client", future.cause) - ctx.close() - } - } - } - ) - } - - logTrace(s"Received request from $client to fetch block $blockId") - - // First make sure we can find the block. If not, send error back to the user. 
- var buf: ManagedBuffer = null - try { - buf = dataProvider.getBlockData(blockId) - } catch { - case e: Exception => - logError(s"Error opening block $blockId for request from $client", e) - respondWithError(e.getMessage) - return - } - - ctx.writeAndFlush(new BlockFetchSuccess(blockId, buf)).addListener( - new ChannelFutureListener { - override def operationComplete(future: ChannelFuture): Unit = { - if (future.isSuccess) { - logTrace(s"Sent block $blockId (${buf.size} B) back to $client") - } else { - logError( - s"Error sending block $blockId to $client; closing connection", future.cause) - ctx.close() - } - } - } - ) - } // end of processBlockRequest - - private def processUploadRequest( - ctx: ChannelHandlerContext, - blockId: String, - data: ManagedBuffer, - level: StorageLevel): Unit = { - // A helper function to send error message back to the client. - def client = ctx.channel.remoteAddress.toString - - try { - dataProvider.putBlockData(blockId, data, level) - ctx.writeAndFlush(BlockUploadSuccess(blockId)).addListener(new ChannelFutureListener { - override def operationComplete(future: ChannelFuture): Unit = { - if (!future.isSuccess) { - logError(s"Error sending an ACK back to client $client") - } - } - }) - } catch { - case e: Throwable => - logError(s"Error processing uploaded block $blockId", e) - ctx.writeAndFlush(BlockUploadFailure(blockId, e.getMessage)).addListener( - new ChannelFutureListener { - override def operationComplete(future: ChannelFuture): Unit = { - if (!future.isSuccess) { - logError(s"Error sending an ACK back to client $client") - } - } - }) - } - } // end of processUploadRequest -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala new file mode 100644 index 0000000000000..aefd8a6335b2a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.nio.ByteBuffer +import java.util + +import org.apache.spark.Logging +import org.apache.spark.network.BlockFetchingListener +import org.apache.spark.serializer.Serializer +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.client.{RpcResponseCallback, ChunkReceivedCallback, SluiceClient} +import org.apache.spark.storage.BlockId +import org.apache.spark.util.Utils + +/** + * Responsible for holding the state for a request for a single set of blocks. This assumes that + * the chunks will be returned in the same order as requested, and that there will be exactly + * one chunk per block. + * + * Upon receipt of any block, the listener will be called back. 
Upon failure part way through, + * the listener will receive a failure callback for each outstanding block. + */ +class NettyBlockFetcher( + serializer: Serializer, + client: SluiceClient, + blockIds: Seq[String], + listener: BlockFetchingListener) + extends Logging { + + require(blockIds.nonEmpty) + + val ser = serializer.newInstance() + + var streamHandle: ShuffleStreamHandle = _ + + val chunkCallback = new ChunkReceivedCallback { + // On receipt of a chunk, pass it upwards as a block. + def onSuccess(chunkIndex: Int, buffer: ManagedBuffer): Unit = Utils.logUncaughtExceptions { + buffer.retain() + listener.onBlockFetchSuccess(blockIds(chunkIndex), buffer) + } + + // On receipt of a failure, fail every block from chunkIndex onwards. + def onFailure(chunkIndex: Int, e: Throwable): Unit = { + blockIds.drop(chunkIndex).foreach { blockId => + listener.onBlockFetchFailure(blockId, e) + } + } + } + + // Send the RPC to open the given set of blocks. This will return a ShuffleStreamHandle. + client.sendRpc(ser.serialize(OpenBlocks(blockIds.map(BlockId.apply))).array(), + new RpcResponseCallback { + override def onSuccess(response: Array[Byte]): Unit = { + try { + streamHandle = ser.deserialize[ShuffleStreamHandle](ByteBuffer.wrap(response)) + logTrace(s"Successfully opened block set: $streamHandle! Preparing to fetch chunks.") + + // Immediately request all chunks -- we expect that the total size of the request is + // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]]. + for (i <- 0 until streamHandle.numChunks) { + client.fetchChunk(streamHandle.streamId, i, chunkCallback) + } + } catch { + case e: Exception => + logError("Failed while starting block fetches", e) + blockIds.foreach(listener.onBlockFetchFailure(_, e)) + } + } + + override def onFailure(e: Throwable): Unit = { + logError("Failed while starting block fetches", e) + blockIds.foreach(listener.onBlockFetchFailure(_, e)) + } + }) +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala new file mode 100644 index 0000000000000..c8658ec98b82c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.netty + +import java.nio.ByteBuffer + +import org.apache.spark.Logging +import org.apache.spark.network.BlockDataManager +import org.apache.spark.serializer.Serializer +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.client.RpcResponseCallback +import org.apache.spark.network.server.{DefaultStreamManager, RpcHandler} +import org.apache.spark.storage.BlockId + +import scala.collection.JavaConversions._ + +/** Request to read a set of blocks. Returns [[ShuffleStreamHandle]] to identify the stream. */ +case class OpenBlocks(blockIds: Seq[BlockId]) + +/** Identifier for a fixed number of chunks to read from a stream created by [[OpenBlocks]]. */ +case class ShuffleStreamHandle(streamId: Long, numChunks: Int) + +/** + * Serves requests to open blocks by simply registering one chunk per block requested. + */ +class NettyBlockRpcServer( + serializer: Serializer, + streamManager: DefaultStreamManager, + blockManager: BlockDataManager) + extends RpcHandler with Logging { + + override def receive(messageBytes: Array[Byte], responseContext: RpcResponseCallback): Unit = { + val ser = serializer.newInstance() + val message = ser.deserialize[AnyRef](ByteBuffer.wrap(messageBytes)) + logTrace(s"Received request: $message") + message match { + case OpenBlocks(blockIds) => + val blocks: Seq[ManagedBuffer] = blockIds.map(blockManager.getBlockData) + val streamId = streamManager.registerStream(blocks.iterator) + responseContext.onSuccess( + ser.serialize(new ShuffleStreamHandle(streamId, blocks.size)).array()) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index b7f979dccd0f5..7576d51e22175 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -17,38 +17,39 @@ package org.apache.spark.network.netty -import scala.concurrent.Future - import org.apache.spark.SparkConf import org.apache.spark.network._ +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.client.{SluiceClient, SluiceClientFactory} +import org.apache.spark.network.server.{DefaultStreamManager, SluiceServer} +import org.apache.spark.network.util.{ConfigProvider, SluiceConfig} +import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils +import scala.concurrent.Future /** - * A [[BlockTransferService]] implementation based on Netty. - * - * See protocol.scala for the communication protocol between server and client + * A BlockTransferService that uses Netty to fetch a set of blocks at a time. */ -private[spark] -final class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService { +class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService { + var client: SluiceClient = _ - private[this] val nettyConf: NettyConfig = new NettyConfig(conf) + // TODO: Don't use Java serialization, use a more cross-version compatible serialization format. + val serializer = new JavaSerializer(conf) - private[this] var server: BlockServer = _ - private[this] var clientFactory: BlockClientFactory = _ + // Create a SluiceConfig using SparkConf. 
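For readers following the new RPC flow introduced above, here is a minimal, hypothetical sketch (not part of this patch) of the OpenBlocks round trip against NettyBlockRpcServer: the client serializes an OpenBlocks request, the server registers one chunk per block with the DefaultStreamManager, and the reply is a ShuffleStreamHandle identifying the stream. The object name, the demo method, and the `blockManager` argument are assumptions for illustration only.

// Hypothetical usage sketch (not part of this patch).
import java.nio.ByteBuffer

import org.apache.spark.SparkConf
import org.apache.spark.network.BlockDataManager
import org.apache.spark.network.client.RpcResponseCallback
import org.apache.spark.network.netty.{NettyBlockRpcServer, OpenBlocks, ShuffleStreamHandle}
import org.apache.spark.network.server.DefaultStreamManager
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage.BlockId

object OpenBlocksRoundTripSketch {
  // `blockManager` is assumed to be an already-initialized BlockDataManager.
  def demo(blockManager: BlockDataManager): Unit = {
    val serializer = new JavaSerializer(new SparkConf)
    val ser = serializer.newInstance()
    val streamManager = new DefaultStreamManager
    val rpcServer = new NettyBlockRpcServer(serializer, streamManager, blockManager)

    // Client side: serialize an OpenBlocks request for two shuffle blocks.
    val request =
      ser.serialize(OpenBlocks(Seq(BlockId("shuffle_0_0_0"), BlockId("shuffle_0_1_0")))).array()

    // Server side: one chunk is registered per requested block; the reply is a
    // ShuffleStreamHandle that the client then uses to fetch the chunks.
    rpcServer.receive(request, new RpcResponseCallback {
      override def onSuccess(response: Array[Byte]): Unit = {
        val handle = ser.deserialize[ShuffleStreamHandle](ByteBuffer.wrap(response))
        println(s"streamId=${handle.streamId}, numChunks=${handle.numChunks}")
      }
      override def onFailure(e: Throwable): Unit = e.printStackTrace()
    })
  }
}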
+ private[this] val sluiceConf = new SluiceConfig( + new ConfigProvider { override def get(name: String) = conf.get(name) }) - override def init(blockDataManager: BlockDataManager): Unit = { - server = new BlockServer(nettyConf, blockDataManager) - clientFactory = new BlockClientFactory(nettyConf) - } + private[this] var server: SluiceServer = _ + private[this] var clientFactory: SluiceClientFactory = _ - override def close(): Unit = { - if (server != null) { - server.close() - } - if (clientFactory != null) { - clientFactory.close() - } + override def init(blockDataManager: BlockDataManager): Unit = { + val streamManager = new DefaultStreamManager + val rpcHandler = new NettyBlockRpcServer(serializer, streamManager, blockDataManager) + server = new SluiceServer(sluiceConf, streamManager, rpcHandler) + clientFactory = new SluiceClientFactory(sluiceConf) } override def fetchBlocks( @@ -56,29 +57,21 @@ final class NettyBlockTransferService(conf: SparkConf) extends BlockTransferServ port: Int, blockIds: Seq[String], listener: BlockFetchingListener): Unit = { - clientFactory.createClient(hostName, port).fetchBlocks(blockIds, listener) + val client = clientFactory.createClient(hostName, port) + new NettyBlockFetcher(serializer, client, blockIds, listener) } + override def hostName: String = Utils.localHostName() + + override def port: Int = server.getPort + + // TODO: Implement override def uploadBlock( hostname: String, port: Int, blockId: String, blockData: ManagedBuffer, - level: StorageLevel): Future[Unit] = { - clientFactory.createClient(hostName, port).uploadBlock(blockId, blockData, level) - } + level: StorageLevel): Future[Unit] = ??? - override def hostName: String = { - if (server == null) { - throw new IllegalStateException("Server has not been started") - } - server.hostName - } - - override def port: Int = { - if (server == null) { - throw new IllegalStateException("Server has not been started") - } - server.port - } + override def close(): Unit = server.close() } diff --git a/core/src/main/scala/org/apache/spark/network/netty/protocol.scala b/core/src/main/scala/org/apache/spark/network/netty/protocol.scala deleted file mode 100644 index 13942f3d0adcd..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/protocol.scala +++ /dev/null @@ -1,326 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
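For context, a minimal, hypothetical sketch (not part of this patch) of how a caller could drive the new fetch path through NettyBlockTransferService with a BlockFetchingListener. The object and method names are illustrative assumptions, and the service is assumed to have already been init()-ed with a BlockDataManager.

// Hypothetical usage sketch (not part of this patch).
import org.apache.spark.network.BlockFetchingListener
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.netty.NettyBlockTransferService

object FetchBlocksSketch {
  def fetch(service: NettyBlockTransferService, host: String, port: Int): Unit = {
    // Fetch two shuffle blocks asynchronously; callbacks fire per block.
    service.fetchBlocks(host, port, Seq("shuffle_0_0_0", "shuffle_0_1_0"),
      new BlockFetchingListener {
        override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
          // Buffers may be reference counted (e.g. Netty-backed); retain before handing
          // them to another thread and release when done.
          println(s"fetched $blockId (${data.size} bytes)")
        }
        override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
          println(s"failed to fetch $blockId: ${exception.getMessage}")
        }
      })
  }
}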
- */ - -package org.apache.spark.network.netty - -import java.nio.ByteBuffer -import java.util.{List => JList} - -import io.netty.buffer.ByteBuf -import io.netty.channel.ChannelHandlerContext -import io.netty.channel.ChannelHandler.Sharable -import io.netty.handler.codec._ - -import org.apache.spark.Logging -import org.apache.spark.network.{NioManagedBuffer, NettyManagedBuffer, ManagedBuffer} -import org.apache.spark.storage.StorageLevel - - -/** Messages from the client to the server. */ -private[netty] -sealed trait ClientRequest { - def id: Byte -} - -/** - * Request to fetch a sequence of blocks from the server. A single [[BlockFetchRequest]] can - * correspond to multiple [[ServerResponse]]s. - */ -private[netty] -final case class BlockFetchRequest(blocks: Seq[String]) extends ClientRequest { - override def id = 0 -} - -/** - * Request to upload a block to the server. Currently the server does not ack the upload request. - */ -private[netty] -final case class BlockUploadRequest( - blockId: String, - data: ManagedBuffer, - level: StorageLevel) - extends ClientRequest { - require(blockId.length <= Byte.MaxValue) - override def id = 1 -} - - -/** Messages from server to client (usually in response to some [[ClientRequest]]. */ -private[netty] -sealed trait ServerResponse { - def id: Byte -} - -/** Response to [[BlockFetchRequest]] when a block exists and has been successfully fetched. */ -private[netty] -final case class BlockFetchSuccess(blockId: String, data: ManagedBuffer) extends ServerResponse { - require(blockId.length <= Byte.MaxValue) - override def id = 0 -} - -/** Response to [[BlockFetchRequest]] when there is an error fetching the block. */ -private[netty] -final case class BlockFetchFailure(blockId: String, error: String) extends ServerResponse { - require(blockId.length <= Byte.MaxValue) - override def id = 1 -} - -/** Response to [[BlockUploadRequest]] when a block is successfully uploaded. */ -private[netty] -final case class BlockUploadSuccess(blockId: String) extends ServerResponse { - require(blockId.length <= Byte.MaxValue) - override def id = 2 -} - -/** Response to [[BlockUploadRequest]] when there is an error uploading the block. */ -private[netty] -final case class BlockUploadFailure(blockId: String, error: String) extends ServerResponse { - require(blockId.length <= Byte.MaxValue) - override def id = 3 -} - - -/** - * Encoder for [[ClientRequest]] used in client side. - * - * This encoder is stateless so it is safe to be shared by multiple threads. 
- */ -@Sharable -private[netty] -final class ClientRequestEncoder extends MessageToMessageEncoder[ClientRequest] { - override def encode(ctx: ChannelHandlerContext, in: ClientRequest, out: JList[Object]): Unit = { - in match { - case BlockFetchRequest(blocks) => - // 8 bytes: frame size - // 1 byte: BlockFetchRequest vs BlockUploadRequest - // 4 byte: num blocks - // then for each block id write 1 byte for blockId.length and then blockId itself - val frameLength = 8 + 1 + 4 + blocks.size + blocks.map(_.size).fold(0)(_ + _) - val buf = ctx.alloc().buffer(frameLength) - - buf.writeLong(frameLength) - buf.writeByte(in.id) - buf.writeInt(blocks.size) - blocks.foreach { blockId => - ProtocolUtils.writeBlockId(buf, blockId) - } - - assert(buf.writableBytes() == 0) - out.add(buf) - - case BlockUploadRequest(blockId, data, level) => - // 8 bytes: frame size - // 1 byte: msg id (BlockFetchRequest vs BlockUploadRequest) - // 1 byte: blockId.length - // data itself (length can be derived from: frame size - 1 - blockId.length) - val headerLength = 8 + 1 + 1 + blockId.length + 5 - val frameLength = headerLength + data.size - val header = ctx.alloc().buffer(headerLength) - - // Call this before we add header to out so in case of exceptions - // we don't send anything at all. - val body = data.convertToNetty() - - header.writeLong(frameLength) - header.writeByte(in.id) - ProtocolUtils.writeBlockId(header, blockId) - header.writeInt(level.toInt) - header.writeByte(level.replication) - - assert(header.writableBytes() == 0) - out.add(header) - out.add(body) - } - } -} - - -/** - * Decoder in the server side to decode client requests. - * This decoder is stateless so it is safe to be shared by multiple threads. - * - * This assumes the inbound messages have been processed by a frame decoder created by - * [[ProtocolUtils.createFrameDecoder()]]. - */ -@Sharable -private[netty] -final class ClientRequestDecoder extends MessageToMessageDecoder[ByteBuf] { - override protected def decode(ctx: ChannelHandlerContext, in: ByteBuf, out: JList[AnyRef]): Unit = - { - val msgTypeId = in.readByte() - val decoded = msgTypeId match { - case 0 => // BlockFetchRequest - val numBlocks = in.readInt() - val blockIds = Seq.fill(numBlocks) { ProtocolUtils.readBlockId(in) } - BlockFetchRequest(blockIds) - - case 1 => // BlockUploadRequest - val blockId = ProtocolUtils.readBlockId(in) - val level = new StorageLevel(in.readInt(), in.readByte()) - - val ret = ByteBuffer.allocate(in.readableBytes()) - ret.put(in.nioBuffer()) - ret.flip() - BlockUploadRequest(blockId, new NioManagedBuffer(ret), level) - } - - assert(decoded.id == msgTypeId) - out.add(decoded) - } -} - - -/** - * Encoder used by the server side to encode server-to-client responses. - * This encoder is stateless so it is safe to be shared by multiple threads. - */ -@Sharable -private[netty] -final class ServerResponseEncoder extends MessageToMessageEncoder[ServerResponse] with Logging { - override def encode(ctx: ChannelHandlerContext, in: ServerResponse, out: JList[Object]): Unit = { - in match { - case BlockFetchSuccess(blockId, data) => - // Handle the body first so if we encounter an error getting the body, we can respond - // with an error instead. - var body: AnyRef = null - try { - body = data.convertToNetty() - } catch { - case e: Exception => - // Re-encode this message as BlockFetchFailure. 
- logError(s"Error opening block $blockId for client ${ctx.channel.remoteAddress}", e) - encode(ctx, new BlockFetchFailure(blockId, e.getMessage), out) - return - } - - // If we got here, body cannot be null - // 8 bytes = long for frame length - // 1 byte = message id (type) - // 1 byte = block id length - // followed by block id itself - val headerLength = 8 + 1 + 1 + blockId.length - val frameLength = headerLength + data.size - val header = ctx.alloc().buffer(headerLength) - header.writeLong(frameLength) - header.writeByte(in.id) - ProtocolUtils.writeBlockId(header, blockId) - - assert(header.writableBytes() == 0) - out.add(header) - out.add(body) - - case BlockFetchFailure(blockId, error) => - val frameLength = 8 + 1 + 1 + blockId.length + error.length - val buf = ctx.alloc().buffer(frameLength) - buf.writeLong(frameLength) - buf.writeByte(in.id) - ProtocolUtils.writeBlockId(buf, blockId) - buf.writeBytes(error.getBytes) - - assert(buf.writableBytes() == 0) - out.add(buf) - - case BlockUploadSuccess(blockId) => - val frameLength = 8 + 1 + 1 + blockId.length - val buf = ctx.alloc().buffer(frameLength) - buf.writeLong(frameLength) - buf.writeByte(in.id) - ProtocolUtils.writeBlockId(buf, blockId) - - assert(buf.writableBytes() == 0) - out.add(buf) - - case BlockUploadFailure(blockId, error) => - val frameLength = 8 + 1 + 1 + blockId.length + + error.length - val buf = ctx.alloc().buffer(frameLength) - buf.writeLong(frameLength) - buf.writeByte(in.id) - ProtocolUtils.writeBlockId(buf, blockId) - buf.writeBytes(error.getBytes) - - assert(buf.writableBytes() == 0) - out.add(buf) - } - } -} - - -/** - * Decoder in the client side to decode server responses. - * This decoder is stateless so it is safe to be shared by multiple threads. - * - * This assumes the inbound messages have been processed by a frame decoder created by - * [[ProtocolUtils.createFrameDecoder()]]. - */ -@Sharable -private[netty] -final class ServerResponseDecoder extends MessageToMessageDecoder[ByteBuf] { - override def decode(ctx: ChannelHandlerContext, in: ByteBuf, out: JList[AnyRef]): Unit = { - val msgId = in.readByte() - val decoded = msgId match { - case 0 => // BlockFetchSuccess - val blockId = ProtocolUtils.readBlockId(in) - in.retain() - BlockFetchSuccess(blockId, new NettyManagedBuffer(in)) - - case 1 => // BlockFetchFailure - val blockId = ProtocolUtils.readBlockId(in) - val errorBytes = new Array[Byte](in.readableBytes()) - in.readBytes(errorBytes) - BlockFetchFailure(blockId, new String(errorBytes)) - - case 2 => // BlockUploadSuccess - BlockUploadSuccess(ProtocolUtils.readBlockId(in)) - - case 3 => // BlockUploadFailure - val blockId = ProtocolUtils.readBlockId(in) - val errorBytes = new Array[Byte](in.readableBytes()) - in.readBytes(errorBytes) - BlockUploadFailure(blockId, new String(errorBytes)) - } - - assert(decoded.id == msgId) - out.add(decoded) - } -} - - -private[netty] object ProtocolUtils { - - /** LengthFieldBasedFrameDecoder used before all decoders. */ - def createFrameDecoder(): ByteToMessageDecoder = { - // maxFrameLength = 2G - // lengthFieldOffset = 0 - // lengthFieldLength = 8 - // lengthAdjustment = -8, i.e. exclude the 8 byte length itself - // initialBytesToStrip = 8, i.e. strip out the length field itself - new LengthFieldBasedFrameDecoder(Int.MaxValue, 0, 8, -8, 8) - } - - // TODO(rxin): Make sure these work for all charsets. 
- def readBlockId(in: ByteBuf): String = { - val numBytesToRead = in.readByte().toInt - val bytes = new Array[Byte](numBytesToRead) - in.readBytes(bytes) - new String(bytes) - } - - def writeBlockId(out: ByteBuf, blockId: String): Unit = { - out.writeByte(blockId.length) - out.writeBytes(blockId.getBytes) - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index e942b43d9cc4a..bce1069548437 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -19,12 +19,13 @@ package org.apache.spark.network.nio import java.nio.ByteBuffer -import scala.concurrent.Future - -import org.apache.spark.{SparkException, Logging, SecurityManager, SparkConf} import org.apache.spark.network._ +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} + +import scala.concurrent.Future /** @@ -153,12 +154,11 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get) Some(new BlockMessageArray(responseMessages).toBufferMessage) } catch { - case e: Exception => { + case e: Exception => logError("Exception handling buffer message", e) val errorMessage = Message.createBufferMessage(msg.id) errorMessage.hasError = true Some(errorMessage) - } } case otherMessage: Any => @@ -174,13 +174,13 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa case BlockMessage.TYPE_PUT_BLOCK => val msg = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) logDebug("Received [" + msg + "]") - putBlock(msg.id.toString, msg.data, msg.level) + putBlock(msg.id, msg.data, msg.level) None case BlockMessage.TYPE_GET_BLOCK => val msg = new GetBlock(blockMessage.getId) logDebug("Received [" + msg + "]") - val buffer = getBlock(msg.id.toString) + val buffer = getBlock(msg.id) if (buffer == null) { return None } @@ -190,7 +190,7 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa } } - private def putBlock(blockId: String, bytes: ByteBuffer, level: StorageLevel) { + private def putBlock(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel) { val startTimeMs = System.currentTimeMillis() logDebug("PutBlock " + blockId + " started from " + startTimeMs + " with data: " + bytes) blockDataManager.putBlockData(blockId, new NioManagedBuffer(bytes), level) @@ -198,7 +198,7 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa + " with data size: " + bytes.limit) } - private def getBlock(blockId: String): ByteBuffer = { + private def getBlock(blockId: BlockId): ByteBuffer = { val startTimeMs = System.currentTimeMillis() logDebug("GetBlock " + blockId + " started from " + startTimeMs) val buffer = blockDataManager.getBlockData(blockId) diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index 439981d232349..c35aa2481ad03 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ 
-24,14 +24,14 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConversions._ -import org.apache.spark.{SparkEnv, SparkConf, Logging} import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.network.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.FileShuffleBlockManager.ShuffleFileGroup import org.apache.spark.storage._ -import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} +import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} +import org.apache.spark.{Logging, SparkConf, SparkEnv} /** A group of writers for a ShuffleMapTask, one writer per reducer. */ private[spark] trait ShuffleWriterGroup { diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala index 4ab34336d3f01..6a9fa4ec65d5d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala @@ -21,7 +21,7 @@ import java.io._ import java.nio.ByteBuffer import org.apache.spark.SparkEnv -import org.apache.spark.network.{ManagedBuffer, FileSegmentManagedBuffer} +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.storage._ /** diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala index 63863cc0250a3..b521f0c7fc77e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala @@ -18,8 +18,7 @@ package org.apache.spark.shuffle import java.nio.ByteBuffer - -import org.apache.spark.network.ManagedBuffer +import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.storage.ShuffleBlockId private[spark] diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index ac0599f30ef22..4d8b5c1e1b084 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -17,15 +17,13 @@ package org.apache.spark.storage -import java.io.{File, InputStream, OutputStream, BufferedOutputStream, ByteArrayOutputStream} +import java.io.{BufferedOutputStream, ByteArrayOutputStream, File, InputStream, OutputStream} import java.nio.{ByteBuffer, MappedByteBuffer} -import scala.concurrent.ExecutionContext.Implicits.global - -import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap} -import scala.concurrent.{Await, Future} +import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ +import scala.concurrent.{Await, Future} import scala.util.Random import akka.actor.{ActorSystem, Props} @@ -35,11 +33,11 @@ import org.apache.spark._ import org.apache.spark.executor._ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleManager import 
org.apache.spark.util._ - private[spark] sealed trait BlockValues private[spark] case class ByteBufferValues(buffer: ByteBuffer) extends BlockValues private[spark] case class IteratorValues(iterator: Iterator[Any]) extends BlockValues @@ -215,17 +213,17 @@ private[spark] class BlockManager( * Interface to get local block data. Throws an exception if the block cannot be found or * cannot be read successfully. */ - override def getBlockData(blockId: String): ManagedBuffer = { - val bid = BlockId(blockId) - if (bid.isShuffle) { - shuffleManager.shuffleBlockManager.getBlockData(bid.asInstanceOf[ShuffleBlockId]) + override def getBlockData(blockId: BlockId): ManagedBuffer = { + if (blockId.isShuffle) { + shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) } else { - val blockBytesOpt = doGetLocal(bid, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] + val blockBytesOpt = doGetLocal(blockId, asBlockResult = false) + .asInstanceOf[Option[ByteBuffer]] if (blockBytesOpt.isDefined) { val buffer = blockBytesOpt.get new NioManagedBuffer(buffer) } else { - throw new BlockNotFoundException(blockId) + throw new BlockNotFoundException(blockId.toString) } } } @@ -233,8 +231,8 @@ private[spark] class BlockManager( /** * Put the block locally, using the given storage level. */ - override def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit = { - putBytes(BlockId(blockId), data.nioByteBuffer(), level) + override def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit = { + putBytes(blockId, data.nioByteBuffer(), level) } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala index 9ef453605f4f1..81f5f2d31dbd8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala @@ -17,5 +17,4 @@ package org.apache.spark.storage - class BlockNotFoundException(blockId: String) extends Exception(s"Block $blockId not found") diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index d095452a261db..23313fe9271fd 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -19,14 +19,13 @@ package org.apache.spark.storage import java.util.concurrent.LinkedBlockingQueue -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashSet -import scala.collection.mutable.Queue +import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} -import org.apache.spark.{Logging, TaskContext} -import org.apache.spark.network.{ManagedBuffer, BlockFetchingListener, BlockTransferService} +import org.apache.spark.network.{BlockFetchingListener, BlockTransferService} import org.apache.spark.serializer.Serializer +import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.util.{CompletionIterator, Utils} +import org.apache.spark.{Logging, TaskContext} /** @@ -228,7 +227,7 @@ final class ShuffleBlockFetcherIterator( while (iter.hasNext) { val blockId = iter.next() try { - val buf = blockManager.getBlockData(blockId.toString) + val buf = blockManager.getBlockData(blockId) shuffleMetrics.localBlocksFetched += 1 buf.retain() results.put(new FetchResult(blockId, 0, 
buf)) diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala index 2fc7c7d9b8312..1e35abaab5353 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala @@ -42,7 +42,7 @@ class StorageLevel private( extends Externalizable { // TODO: Also add fields for caching priority, dataset ID, and flushing. - private[spark] def this(flags: Int, replication: Int) { + private def this(flags: Int, replication: Int) { this((flags & 8) != 0, (flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication) } @@ -98,7 +98,6 @@ class StorageLevel private( } override def writeExternal(out: ObjectOutput) { - /* If the wire protocol changes, please also update [[ClientRequestEncoder]] */ out.writeByte(toInt) out.writeByte(_replication) } diff --git a/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala deleted file mode 100644 index 2d4baafcf03d0..0000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.netty - -import scala.concurrent.{Await, future} -import scala.concurrent.duration._ -import scala.concurrent.ExecutionContext.Implicits.global - -import org.scalatest.{BeforeAndAfterAll, FunSuite} - -import org.apache.spark.SparkConf - - -class BlockClientFactorySuite extends FunSuite with BeforeAndAfterAll { - - private val conf = new SparkConf - private var server1: BlockServer = _ - private var server2: BlockServer = _ - - override def beforeAll() { - server1 = new BlockServer(new NettyConfig(conf), null) - server2 = new BlockServer(new NettyConfig(conf), null) - } - - override def afterAll() { - if (server1 != null) { - server1.close() - } - if (server2 != null) { - server2.close() - } - } - - test("BlockClients created are active and reused") { - val factory = new BlockClientFactory(conf) - val c1 = factory.createClient(server1.hostName, server1.port) - val c2 = factory.createClient(server1.hostName, server1.port) - val c3 = factory.createClient(server2.hostName, server2.port) - assert(c1.isActive) - assert(c3.isActive) - assert(c1 === c2) - assert(c1 !== c3) - factory.close() - } - - test("never return inactive clients") { - val factory = new BlockClientFactory(conf) - val c1 = factory.createClient(server1.hostName, server1.port) - c1.close() - - // Block until c1 is no longer active - val f = future { - while (c1.isActive) { - Thread.sleep(10) - } - } - Await.result(f, 3.seconds) - assert(!c1.isActive) - - // Create c2, which should be different from c1 - val c2 = factory.createClient(server1.hostName, server1.port) - assert(c1 !== c2) - factory.close() - } - - test("BlockClients are close when BlockClientFactory is stopped") { - val factory = new BlockClientFactory(conf) - val c1 = factory.createClient(server1.hostName, server1.port) - val c2 = factory.createClient(server2.hostName, server2.port) - assert(c1.isActive) - assert(c2.isActive) - factory.close() - assert(!c1.isActive) - assert(!c2.isActive) - } -} diff --git a/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala deleted file mode 100644 index 4c3a649081574..0000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.netty - -import java.nio.ByteBuffer - -import io.netty.buffer.Unpooled -import io.netty.channel.embedded.EmbeddedChannel - -import org.mockito.Mockito._ -import org.mockito.Matchers.{any, eq => meq} - -import org.scalatest.{FunSuite, PrivateMethodTester} - -import org.apache.spark.network._ - - -class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { - - /** Helper method to get num. outstanding requests from a private field using reflection. */ - private def sizeOfOutstandingRequests(handler: BlockClientHandler): Int = { - val f = handler.getClass.getDeclaredField( - "org$apache$spark$network$netty$BlockClientHandler$$outstandingFetches") - f.setAccessible(true) - f.get(handler).asInstanceOf[java.util.Map[_, _]].size - } - - test("handling block data (successful fetch)") { - val blockId = "test_block" - val blockData = "blahblahblahblahblah" - val handler = new BlockClientHandler - val listener = mock(classOf[BlockFetchingListener]) - handler.addFetchRequest(blockId, listener) - assert(sizeOfOutstandingRequests(handler) === 1) - - val channel = new EmbeddedChannel(handler) - val buf = ByteBuffer.allocate(blockData.size) // 4 bytes for the length field itself - buf.put(blockData.getBytes) - buf.flip() - - channel.writeInbound(BlockFetchSuccess(blockId, new NioManagedBuffer(buf))) - verify(listener, times(1)).onBlockFetchSuccess(meq(blockId), any()) - assert(sizeOfOutstandingRequests(handler) === 0) - assert(channel.finish() === false) - } - - test("handling error message (failed fetch)") { - val blockId = "test_block" - val handler = new BlockClientHandler - val listener = mock(classOf[BlockFetchingListener]) - handler.addFetchRequest(blockId, listener) - assert(sizeOfOutstandingRequests(handler) === 1) - - val channel = new EmbeddedChannel(handler) - channel.writeInbound(BlockFetchFailure(blockId, "some error msg")) - verify(listener, times(0)).onBlockFetchSuccess(any(), any()) - verify(listener, times(1)).onBlockFetchFailure(meq(blockId), any()) - assert(sizeOfOutstandingRequests(handler) === 0) - assert(channel.finish() === false) - } - - test("clear all outstanding request upon uncaught exception") { - val handler = new BlockClientHandler - val listener = mock(classOf[BlockFetchingListener]) - handler.addFetchRequest("b1", listener) - handler.addFetchRequest("b2", listener) - handler.addFetchRequest("b3", listener) - assert(sizeOfOutstandingRequests(handler) === 3) - - val channel = new EmbeddedChannel(handler) - channel.writeInbound(BlockFetchSuccess("b1", new NettyManagedBuffer(Unpooled.buffer()))) - channel.pipeline().fireExceptionCaught(new Exception("duh duh duh")) - - // should fail both b2 and b3 - verify(listener, times(1)).onBlockFetchSuccess(any(), any()) - verify(listener, times(2)).onBlockFetchFailure(any(), any()) - assert(sizeOfOutstandingRequests(handler) === 0) - assert(channel.finish() === false) - } - - test("clear all outstanding request upon connection close") { - val handler = new BlockClientHandler - val listener = mock(classOf[BlockFetchingListener]) - handler.addFetchRequest("c1", listener) - handler.addFetchRequest("c2", listener) - handler.addFetchRequest("c3", listener) - assert(sizeOfOutstandingRequests(handler) === 3) - - val channel = new EmbeddedChannel(handler) - channel.writeInbound(BlockFetchSuccess("c1", new NettyManagedBuffer(Unpooled.buffer()))) - channel.finish() - - // should fail both b2 and b3 - verify(listener, times(1)).onBlockFetchSuccess(any(), any()) - verify(listener, 
times(2)).onBlockFetchFailure(any(), any()) - assert(sizeOfOutstandingRequests(handler) === 0) - assert(channel.finish() === false) - } -} diff --git a/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala deleted file mode 100644 index 8d1b7276f4082..0000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import io.netty.channel.embedded.EmbeddedChannel - -import org.scalatest.FunSuite - -import org.apache.spark.api.java.StorageLevels - - -/** - * Test client/server encoder/decoder protocol. - */ -class ProtocolSuite extends FunSuite { - - /** - * Helper to test server to client message protocol by encoding a message and decoding it. - */ - private def testServerToClient(msg: ServerResponse) { - val serverChannel = new EmbeddedChannel(new ServerResponseEncoder) - serverChannel.writeOutbound(msg) - - val clientChannel = new EmbeddedChannel( - ProtocolUtils.createFrameDecoder(), - new ServerResponseDecoder) - - // Drain all server outbound messages and write them to the client's server decoder. - while (!serverChannel.outboundMessages().isEmpty) { - clientChannel.writeInbound(serverChannel.readOutbound()) - } - - assert(clientChannel.inboundMessages().size === 1) - // Must put "msg === ..." instead of "... === msg" since only TestManagedBuffer equals is - // overridden. - assert(msg === clientChannel.readInbound()) - } - - /** - * Helper to test client to server message protocol by encoding a message and decoding it. - */ - private def testClientToServer(msg: ClientRequest) { - val clientChannel = new EmbeddedChannel(new ClientRequestEncoder) - clientChannel.writeOutbound(msg) - - val serverChannel = new EmbeddedChannel( - ProtocolUtils.createFrameDecoder(), - new ClientRequestDecoder) - - // Drain all client outbound messages and write them to the server's decoder. - while (!clientChannel.outboundMessages().isEmpty) { - serverChannel.writeInbound(clientChannel.readOutbound()) - } - - assert(serverChannel.inboundMessages().size === 1) - // Must put "msg === ..." instead of "... === msg" since only TestManagedBuffer equals is - // overridden. 
- assert(msg === serverChannel.readInbound()) - } - - test("server to client protocol - BlockFetchSuccess(\"a1234\", new TestManagedBuffer(10))") { - testServerToClient(BlockFetchSuccess("a1234", new TestManagedBuffer(10))) - } - - test("server to client protocol - BlockFetchSuccess(\"\", new TestManagedBuffer(0))") { - testServerToClient(BlockFetchSuccess("", new TestManagedBuffer(0))) - } - - test("server to client protocol - BlockFetchFailure(\"abcd\", \"this is an error\")") { - testServerToClient(BlockFetchFailure("abcd", "this is an error")) - } - - test("server to client protocol - BlockFetchFailure(\"\", \"\")") { - testServerToClient(BlockFetchFailure("", "")) - } - - test("client to server protocol - BlockFetchRequest(Seq.empty[String])") { - testClientToServer(BlockFetchRequest(Seq.empty[String])) - } - - test("client to server protocol - BlockFetchRequest(Seq(\"b1\"))") { - testClientToServer(BlockFetchRequest(Seq("b1"))) - } - - test("client to server protocol - BlockFetchRequest(Seq(\"b1\", \"b2\", \"b3\"))") { - testClientToServer(BlockFetchRequest(Seq("b1", "b2", "b3"))) - } - - test("client to server protocol - BlockUploadRequest(\"\", new TestManagedBuffer(0))") { - testClientToServer( - BlockUploadRequest("", new TestManagedBuffer(0), StorageLevels.MEMORY_AND_DISK)) - } - - test("client to server protocol - BlockUploadRequest(\"b_upload\", new TestManagedBuffer(10))") { - testClientToServer( - BlockUploadRequest("b_upload", new TestManagedBuffer(10), StorageLevels.MEMORY_AND_DISK_2)) - } -} diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala deleted file mode 100644 index 35ff90a2dabc5..0000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala +++ /dev/null @@ -1,174 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import java.io.{RandomAccessFile, File} -import java.nio.ByteBuffer -import java.util.{Collections, HashSet} -import java.util.concurrent.{TimeUnit, Semaphore} - -import scala.collection.JavaConversions._ - -import io.netty.buffer.Unpooled - -import org.scalatest.{BeforeAndAfterAll, FunSuite} -import org.scalatest.concurrent.Eventually._ -import org.scalatest.time.Span -import org.scalatest.time.Seconds - -import org.apache.spark.SparkConf -import org.apache.spark.network._ -import org.apache.spark.storage.{BlockNotFoundException, StorageLevel} - - -/** -* Test cases that create real clients and servers and connect. 
-*/ -class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { - - val bufSize = 100000 - var buf: ByteBuffer = _ - var testFile: File = _ - var server: BlockServer = _ - var clientFactory: BlockClientFactory = _ - - val bufferBlockId = "buffer_block" - val fileBlockId = "file_block" - - val fileContent = new Array[Byte](1024) - scala.util.Random.nextBytes(fileContent) - - override def beforeAll() = { - buf = ByteBuffer.allocate(bufSize) - for (i <- 1 to bufSize) { - buf.put(i.toByte) - } - buf.flip() - - testFile = File.createTempFile("netty-test-file", "txt") - val fp = new RandomAccessFile(testFile, "rw") - fp.write(fileContent) - fp.close() - - server = new BlockServer(new NettyConfig(new SparkConf), new BlockDataManager { - override def getBlockData(blockId: String): ManagedBuffer = { - if (blockId == bufferBlockId) { - new NioManagedBuffer(buf) - } else if (blockId == fileBlockId) { - new FileSegmentManagedBuffer(testFile, 10, testFile.length - 25) - } else { - throw new BlockNotFoundException(blockId) - } - } - - /** - * Put the block locally, using the given storage level. - */ - def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit = ??? - }) - - clientFactory = new BlockClientFactory(new SparkConf) - } - - override def afterAll() = { - server.close() - clientFactory.close() - } - - /** A ByteBuf for buffer_block */ - lazy val byteBufferBlockReference = Unpooled.wrappedBuffer(buf) - - /** A ByteBuf for file_block */ - lazy val fileBlockReference = Unpooled.wrappedBuffer(fileContent, 10, fileContent.length - 25) - - def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ManagedBuffer], Set[String]) = { - val client = clientFactory.createClient(server.hostName, server.port) - val sem = new Semaphore(0) - val receivedBlockIds = Collections.synchronizedSet(new HashSet[String]) - val errorBlockIds = Collections.synchronizedSet(new HashSet[String]) - val receivedBuffers = Collections.synchronizedSet(new HashSet[ManagedBuffer]) - - client.fetchBlocks( - blockIds, - new BlockFetchingListener { - override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = { - errorBlockIds.add(blockId) - sem.release() - } - - override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { - data.retain() - receivedBlockIds.add(blockId) - receivedBuffers.add(data) - sem.release() - } - } - ) - if (!sem.tryAcquire(blockIds.size, 5, TimeUnit.SECONDS)) { - fail("Timeout getting response from the server") - } - client.close() - (receivedBlockIds.toSet, receivedBuffers.toSet, errorBlockIds.toSet) - } - - test("fetch a ByteBuffer block") { - val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId)) - assert(blockIds === Set(bufferBlockId)) - assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference)) - assert(failBlockIds.isEmpty) - buffers.foreach(_.release()) - } - - test("fetch a FileSegment block via zero-copy send") { - val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(fileBlockId)) - assert(blockIds === Set(fileBlockId)) - assert(buffers.map(_.convertToNetty()) === Set(fileBlockReference)) - assert(failBlockIds.isEmpty) - buffers.foreach(_.release()) - } - - test("fetch a non-existent block") { - val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq("random-block")) - assert(blockIds.isEmpty) - assert(buffers.isEmpty) - assert(failBlockIds === Set("random-block")) - buffers.foreach(_.release()) - } - - test("fetch both ByteBuffer block and FileSegment block") { - val (blockIds, buffers, 
failBlockIds) = fetchBlocks(Seq(bufferBlockId, fileBlockId)) - assert(blockIds === Set(bufferBlockId, fileBlockId)) - assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference, fileBlockReference)) - assert(failBlockIds.isEmpty) - buffers.foreach(_.release()) - } - - test("fetch both ByteBuffer block and a non-existent block") { - val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, "random-block")) - assert(blockIds === Set(bufferBlockId)) - assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference)) - assert(failBlockIds === Set("random-block")) - buffers.foreach(_.release()) - } - - test("shutting down server should also close client") { - val client = clientFactory.createClient(server.hostName, server.port) - server.close() - eventually(timeout(Span(5, Seconds))) { assert(!client.isActive) } - } -} diff --git a/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala b/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala deleted file mode 100644 index e47e4d03fa898..0000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import java.io.InputStream -import java.nio.ByteBuffer - -import io.netty.buffer.Unpooled - -import org.apache.spark.network.{NettyManagedBuffer, ManagedBuffer} - - -/** - * A ManagedBuffer implementation that contains 0, 1, 2, 3, ..., (len-1). - * - * Used for testing. 
- */ -class TestManagedBuffer(len: Int) extends ManagedBuffer { - - require(len <= Byte.MaxValue) - - private val byteArray: Array[Byte] = Array.tabulate[Byte](len)(_.toByte) - - private val underlying = new NettyManagedBuffer(Unpooled.wrappedBuffer(byteArray)) - - override def size: Long = underlying.size - - override private[network] def convertToNetty(): AnyRef = underlying.convertToNetty() - - override def nioByteBuffer(): ByteBuffer = underlying.nioByteBuffer() - - override def inputStream(): InputStream = underlying.inputStream() - - override def toString: String = s"${getClass.getName}($len)" - - override def equals(other: Any): Boolean = other match { - case otherBuf: ManagedBuffer => - val nioBuf = otherBuf.nioByteBuffer() - if (nioBuf.remaining() != len) { - return false - } else { - var i = 0 - while (i < len) { - if (nioBuf.get() != i) { - return false - } - i += 1 - } - return true - } - case _ => false - } - - override def retain(): this.type = this - - override def release(): this.type = this -} diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala index ba47fe5e25b9b..6790388f96603 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.FunSuite import org.apache.spark.{SparkEnv, SparkContext, LocalSparkContext, SparkConf} import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.network.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.shuffle.FileShuffleBlockManager import org.apache.spark.storage.{ShuffleBlockId, FileSegment} @@ -36,9 +36,9 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext { private def checkSegments(expected: FileSegment, buffer: ManagedBuffer) { assert(buffer.isInstanceOf[FileSegmentManagedBuffer]) val segment = buffer.asInstanceOf[FileSegmentManagedBuffer] - assert(expected.file.getCanonicalPath === segment.file.getCanonicalPath) - assert(expected.offset === segment.offset) - assert(expected.length === segment.length) + assert(expected.file.getCanonicalPath === segment.getFile.getCanonicalPath) + assert(expected.offset === segment.getOffset) + assert(expected.length === segment.getLength) } test("consolidated shuffle can write to shuffle group without messing existing offsets/lengths") { diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 7d4086313fcc1..3beb503b206f2 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -31,6 +31,7 @@ import org.scalatest.FunSuite import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.network._ +import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.serializer.TestSerializer @@ -71,7 +72,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer])) localBlocks.foreach { case (blockId, buf) => - 
doReturn(buf).when(blockManager).getBlockData(meq(blockId.toString)) + doReturn(buf).when(blockManager).getBlockData(meq(blockId)) } // Make sure remote blocks would return diff --git a/network/common/pom.xml b/network/common/pom.xml new file mode 100644 index 0000000000000..e3b7e328701b4 --- /dev/null +++ b/network/common/pom.xml @@ -0,0 +1,94 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.2.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + network + jar + Shuffle Streaming Service + http://spark.apache.org/ + + network + + + + + + io.netty + netty-all + + + org.slf4j + slf4j-api + + + + + com.google.guava + guava + provided + + + + + junit + junit + test + + + log4j + log4j + test + + + org.mockito + mockito-all + test + + + + + + target/java/classes + target/java/test-classes + + + org.apache.maven.plugins + maven-surefire-plugin + 2.17 + + false + + **/Test*.java + **/*Test.java + **/*Suite.java + + + + + + diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java new file mode 100644 index 0000000000000..224f1e6c515ea --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.buffer; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.RandomAccessFile; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; + +import com.google.common.base.Objects; +import com.google.common.io.ByteStreams; +import io.netty.channel.DefaultFileRegion; + +import org.apache.spark.network.util.JavaUtils; + +/** + * A {@link ManagedBuffer} backed by a segment in a file. + */ +public final class FileSegmentManagedBuffer extends ManagedBuffer { + + /** + * Memory mapping is expensive and can destabilize the JVM (SPARK-1145, SPARK-3889). + * Avoid unless there's a good reason not to. + */ + private static final long MIN_MEMORY_MAP_BYTES = 2 * 1024 * 1024; + + private final File file; + private final long offset; + private final long length; + + public FileSegmentManagedBuffer(File file, long offset, long length) { + this.file = file; + this.offset = offset; + this.length = length; + } + + @Override + public long size() { + return length; + } + + @Override + public ByteBuffer nioByteBuffer() throws IOException { + FileChannel channel = null; + try { + channel = new RandomAccessFile(file, "r").getChannel(); + // Just copy the buffer if it's sufficiently small, as memory mapping has a high overhead. 
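+      // MIN_MEMORY_MAP_BYTES (2 MB, defined above) is the cutoff: below it a plain channel.read()
+      // into a heap buffer is cheaper than setting up a mapping, and it sidesteps the mmap
+      // instability referenced in SPARK-1145 / SPARK-3889.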
+ if (length < MIN_MEMORY_MAP_BYTES) { + ByteBuffer buf = ByteBuffer.allocate((int) length); + channel.read(buf, offset); + buf.flip(); + return buf; + } else { + return channel.map(FileChannel.MapMode.READ_ONLY, offset, length); + } + } catch (IOException e) { + try { + if (channel != null) { + long size = channel.size(); + throw new IOException("Error in reading " + this + " (actual file length " + size + ")", + e); + } + } catch (IOException ignored) { + // ignore + } + throw new IOException("Error in opening " + this, e); + } finally { + JavaUtils.closeQuietly(channel); + } + } + + @Override + public InputStream inputStream() throws IOException { + FileInputStream is = null; + try { + is = new FileInputStream(file); + is.skip(offset); + return ByteStreams.limit(is, length); + } catch (IOException e) { + try { + if (is != null) { + long size = file.length(); + throw new IOException("Error in reading " + this + " (actual file length " + size + ")", + e); + } + } catch (IOException ignored) { + // ignore + } finally { + JavaUtils.closeQuietly(is); + } + throw new IOException("Error in opening " + this, e); + } catch (RuntimeException e) { + JavaUtils.closeQuietly(is); + throw e; + } + } + + @Override + public ManagedBuffer retain() { + return this; + } + + @Override + public ManagedBuffer release() { + return this; + } + + @Override + public Object convertToNetty() throws IOException { + FileChannel fileChannel = new FileInputStream(file).getChannel(); + return new DefaultFileRegion(fileChannel, offset, length); + } + + public File getFile() { return file; } + + public long getOffset() { return offset; } + + public long getLength() { return length; } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("file", file) + .add("offset", offset) + .add("length", length) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java new file mode 100644 index 0000000000000..1735f5540c61b --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.buffer; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; + +/** + * This interface provides an immutable view for data in the form of bytes. 
The implementation + * should specify how the data is provided: + * + * - {@link FileSegmentManagedBuffer}: data backed by part of a file + * - {@link NioManagedBuffer}: data backed by a NIO ByteBuffer + * - {@link NettyManagedBuffer}: data backed by a Netty ByteBuf + * + * The concrete buffer implementation might be managed outside the JVM garbage collector. + * For example, in the case of {@link NettyManagedBuffer}, the buffers are reference counted. + * In that case, if the buffer is going to be passed around to a different thread, retain/release + * should be called. + */ +public abstract class ManagedBuffer { + + /** Number of bytes of the data. */ + public abstract long size(); + + /** + * Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the + * returned ByteBuffer should not affect the content of this buffer. + */ + public abstract ByteBuffer nioByteBuffer() throws IOException; + + /** + * Exposes this buffer's data as an InputStream. The underlying implementation does not + * necessarily check for the length of bytes read, so the caller is responsible for making sure + * it does not go over the limit. + */ + public abstract InputStream inputStream() throws IOException; + + /** + * Increment the reference count by one if applicable. + */ + public abstract ManagedBuffer retain(); + + /** + * If applicable, decrement the reference count by one and deallocates the buffer if the + * reference count reaches zero. + */ + public abstract ManagedBuffer release(); + + /** + * Convert the buffer into an Netty object, used to write the data out. + */ + public abstract Object convertToNetty() throws IOException; +} diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java new file mode 100644 index 0000000000000..d928980423f1f --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.buffer; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufInputStream; + +/** + * A {@link ManagedBuffer} backed by a Netty {@link ByteBuf}. 
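+ *
+ * retain() and release() delegate to the underlying ByteBuf's reference count, so a buffer that
+ * is handed off to another thread should be retain()'d first and release()'d once consumed
+ * (see the note on {@link ManagedBuffer}).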
+ */ +public final class NettyManagedBuffer extends ManagedBuffer { + private final ByteBuf buf; + + public NettyManagedBuffer(ByteBuf buf) { + this.buf = buf; + } + + @Override + public long size() { + return buf.readableBytes(); + } + + @Override + public ByteBuffer nioByteBuffer() throws IOException { + return buf.nioBuffer(); + } + + @Override + public InputStream inputStream() throws IOException { + return new ByteBufInputStream(buf); + } + + @Override + public ManagedBuffer retain() { + buf.retain(); + return this; + } + + @Override + public ManagedBuffer release() { + buf.release(); + return this; + } + + @Override + public Object convertToNetty() throws IOException { + return buf; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("buf", buf) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java new file mode 100644 index 0000000000000..3953ef89fbf88 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.buffer; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBufInputStream; +import io.netty.buffer.Unpooled; + +/** + * A {@link ManagedBuffer} backed by {@link ByteBuffer}. 
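+ *
+ * Unlike {@link NettyManagedBuffer}, the wrapped ByteBuffer is managed by the JVM garbage
+ * collector, so retain() and release() are no-ops here.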
+ */ +public final class NioManagedBuffer extends ManagedBuffer { + private final ByteBuffer buf; + + public NioManagedBuffer(ByteBuffer buf) { + this.buf = buf; + } + + @Override + public long size() { + return buf.remaining(); + } + + @Override + public ByteBuffer nioByteBuffer() throws IOException { + return buf.duplicate(); + } + + @Override + public InputStream inputStream() throws IOException { + return new ByteBufInputStream(Unpooled.wrappedBuffer(buf)); + } + + @Override + public ManagedBuffer retain() { + return this; + } + + @Override + public ManagedBuffer release() { + return this; + } + + @Override + public Object convertToNetty() throws IOException { + return Unpooled.wrappedBuffer(buf); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("buf", buf) + .toString(); + } +} + diff --git a/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java b/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java new file mode 100644 index 0000000000000..40a1fe67b1c5b --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +/** + * General exception caused by a remote exception while fetching a chunk. + */ +public class ChunkFetchFailureException extends RuntimeException { + private final int chunkIndex; + + public ChunkFetchFailureException(int chunkIndex, String errorMsg, Throwable cause) { + super(errorMsg, cause); + this.chunkIndex = chunkIndex; + } + + public ChunkFetchFailureException(int chunkIndex, String errorMsg) { + super(errorMsg); + this.chunkIndex = chunkIndex; + } + + public int getChunkIndex() { return chunkIndex; } +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java b/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java new file mode 100644 index 0000000000000..519e6cb470d0d --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * Callback for the result of a single chunk result. For a single stream, the callbacks are + * guaranteed to be called by the same thread in the same order as the requests for chunks were + * made. + * + * Note that if a general stream failure occurs, all outstanding chunk requests may be failed. + */ +public interface ChunkReceivedCallback { + /** + * Called upon receipt of a particular chunk. + * + * The given buffer will initially have a refcount of 1, but will be release()'d as soon as this + * call returns. You must therefore either retain() the buffer or copy its contents before + * returning. + */ + void onSuccess(int chunkIndex, ManagedBuffer buffer); + + /** + * Called upon failure to fetch a particular chunk. Note that this may actually be called due + * to failure to fetch a prior chunk in this stream. + * + * After receiving a failure, the stream may or may not be valid. The client should not assume + * that the server's side of the stream has been closed. + */ + void onFailure(int chunkIndex, Throwable e); +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java new file mode 100644 index 0000000000000..6ec960d795420 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +/** + * Callback for the result of a single RPC. This will be invoked once with either success or + * failure. + */ +public interface RpcResponseCallback { + /** Successful serialized result from server. */ + void onSuccess(byte[] response); + + /** Exception either propagated from server or raised on client side. 
*/ + void onFailure(Throwable e); +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/SluiceClient.java b/network/common/src/main/java/org/apache/spark/network/client/SluiceClient.java new file mode 100644 index 0000000000000..1f7d3b0234e38 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/SluiceClient.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +import java.io.Closeable; +import java.util.UUID; + +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.protocol.request.ChunkFetchRequest; +import org.apache.spark.network.protocol.request.RpcRequest; + +/** + * Client for fetching consecutive chunks of a pre-negotiated stream. This API is intended to allow + * efficient transfer of a large amount of data, broken up into chunks with size ranging from + * hundreds of KB to a few MB. + * + * Note that while this client deals with the fetching of chunks from a stream (i.e., data plane), + * the actual setup of the streams is done outside the scope of Sluice. The convenience method + * "sendRPC" is provided to enable control plane communication between the client and server to + * perform this setup. + * + * For example, a typical workflow might be: + * client.sendRPC(new OpenFile("/foo")) --> returns StreamId = 100 + * client.fetchChunk(streamId = 100, chunkIndex = 0, callback) + * client.fetchChunk(streamId = 100, chunkIndex = 1, callback) + * ... + * client.sendRPC(new CloseStream(100)) + * + * Construct an instance of SluiceClient using {@link SluiceClientFactory}. A single SluiceClient + * may be used for multiple streams, but any given stream must be restricted to a single client, + * in order to avoid out-of-order responses. + * + * NB: This class is used to make requests to the server, while {@link SluiceClientHandler} is + * responsible for handling responses from the server. + * + * Concurrency: thread safe and can be called from multiple threads. 
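+ *
+ * In Java terms, the fetch step of the workflow above looks roughly like the following sketch
+ * (the factory, stream id, and host are placeholders; the stream id would be negotiated out of
+ * band, e.g. via sendRpc):
+ * <pre>
+ *   SluiceClient client = factory.createClient("remote-host", 12345);
+ *   client.fetchChunk(streamId, 0, new ChunkReceivedCallback() {
+ *     public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
+ *       buffer.retain();  // released as soon as this call returns, so retain or copy it
+ *     }
+ *     public void onFailure(int chunkIndex, Throwable e) {
+ *       // a failure may also be triggered by a failed earlier chunk in the same stream
+ *     }
+ *   });
+ * </pre>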
+ */ +public class SluiceClient implements Closeable { + private final Logger logger = LoggerFactory.getLogger(SluiceClient.class); + + private final ChannelFuture cf; + private final SluiceClientHandler handler; + + private final String serverAddr; + + SluiceClient(ChannelFuture cf, SluiceClientHandler handler) { + this.cf = cf; + this.handler = handler; + + if (cf != null && cf.channel() != null && cf.channel().remoteAddress() != null) { + serverAddr = cf.channel().remoteAddress().toString(); + } else { + serverAddr = ""; + } + } + + public boolean isActive() { + return cf.channel().isActive(); + } + + /** + * Requests a single chunk from the remote side, from the pre-negotiated streamId. + * + * Chunk indices go from 0 onwards. It is valid to request the same chunk multiple times, though + * some streams may not support this. + * + * Multiple fetchChunk requests may be outstanding simultaneously, and the chunks are guaranteed + * to be returned in the same order that they were requested, assuming only a single SluiceClient + * is used to fetch the chunks. + * + * @param streamId Identifier that refers to a stream in the remote StreamManager. This should + * be agreed upon by client and server beforehand. + * @param chunkIndex 0-based index of the chunk to fetch + * @param callback Callback invoked upon successful receipt of chunk, or upon any failure. + */ + public void fetchChunk( + long streamId, + final int chunkIndex, + final ChunkReceivedCallback callback) { + final long startTime = System.currentTimeMillis(); + logger.debug("Sending fetch chunk request {} to {}", chunkIndex, serverAddr); + + final StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex); + handler.addFetchRequest(streamChunkId, callback); + + cf.channel().writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener( + new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + long timeTaken = System.currentTimeMillis() - startTime; + logger.debug("Sending request {} to {} took {} ms", streamChunkId, serverAddr, + timeTaken); + } else { + // Fail all blocks. + String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId, + serverAddr, future.cause().getMessage()); + logger.error(errorMsg, future.cause()); + future.cause().printStackTrace(); + handler.removeFetchRequest(streamChunkId); + callback.onFailure(chunkIndex, new RuntimeException(errorMsg)); + } + } + }); + } + + /** + * Sends an opaque message to the RpcHandler on the server-side. The callback will be invoked + * with the server's response or upon any failure. + */ + public void sendRpc(byte[] message, final RpcResponseCallback callback) { + final long startTime = System.currentTimeMillis(); + logger.debug("Sending RPC to {}", serverAddr); + + final long tag = UUID.randomUUID().getLeastSignificantBits(); + handler.addRpcRequest(tag, callback); + + cf.channel().writeAndFlush(new RpcRequest(tag, message)).addListener( + new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + long timeTaken = System.currentTimeMillis() - startTime; + logger.debug("Sending request {} to {} took {} ms", tag, serverAddr, timeTaken); + } else { + // Fail all blocks. 
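+              // (For an RPC there is only this single request to fail: drop its callback from
+              // the outstanding map and surface the cause to the caller.)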
+ String errorMsg = String.format("Failed to send request %s to %s: %s", tag, + serverAddr, future.cause().getMessage()); + logger.error(errorMsg, future.cause()); + handler.removeRpcRequest(tag); + callback.onFailure(new RuntimeException(errorMsg)); + } + } + }); + } + + @Override + public void close() { + cf.channel().close(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java new file mode 100644 index 0000000000000..17491dc3f8720 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +import java.io.Closeable; +import java.lang.reflect.Field; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeoutException; + +import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.util.internal.PlatformDependent; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.protocol.request.ClientRequestEncoder; +import org.apache.spark.network.protocol.response.ServerResponseDecoder; +import org.apache.spark.network.util.IOMode; +import org.apache.spark.network.util.NettyUtils; +import org.apache.spark.network.util.SluiceConfig; + +/** + * Factory for creating {@link SluiceClient}s by using createClient. + * + * The factory maintains a connection pool to other hosts and should return the same + * {@link SluiceClient} for the same remote host. It also shares a single worker thread pool for + * all {@link SluiceClient}s. 
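+ *
+ * A minimal usage sketch (host and port are placeholders): clients are pooled per remote
+ * address and reused, and everything is torn down via close():
+ * <pre>
+ *   SluiceClientFactory factory = new SluiceClientFactory(conf);  // conf is a SluiceConfig
+ *   SluiceClient client = factory.createClient("remote-host", 12345);  // may throw TimeoutException
+ *   ...
+ *   factory.close();  // closes pooled clients and shuts down the worker event loop
+ * </pre>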
+ */ +public class SluiceClientFactory implements Closeable { + private final Logger logger = LoggerFactory.getLogger(SluiceClientFactory.class); + + private final SluiceConfig conf; + private final Map connectionPool; + private final ClientRequestEncoder encoder; + private final ServerResponseDecoder decoder; + + private final Class socketChannelClass; + private final EventLoopGroup workerGroup; + + public SluiceClientFactory(SluiceConfig conf) { + this.conf = conf; + this.connectionPool = new ConcurrentHashMap(); + this.encoder = new ClientRequestEncoder(); + this.decoder = new ServerResponseDecoder(); + + IOMode ioMode = IOMode.valueOf(conf.ioMode()); + this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode); + this.workerGroup = NettyUtils.createEventLoop(ioMode, conf.clientThreads(), "shuffle-client"); + } + + /** + * Create a new BlockFetchingClient connecting to the given remote host / port. + * + * This blocks until a connection is successfully established. + * + * Concurrency: This method is safe to call from multiple threads. + */ + public SluiceClient createClient(String remoteHost, int remotePort) throws TimeoutException { + // Get connection from the connection pool first. + // If it is not found or not active, create a new one. + InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); + SluiceClient cachedClient = connectionPool.get(address); + if (cachedClient != null && cachedClient.isActive()) { + return cachedClient; + } + + logger.debug("Creating new connection to " + address); + + // There is a chance two threads are creating two different clients connecting to the same host. + // But that's probably ok, as long as the caller hangs on to their client for a single stream. + final SluiceClientHandler handler = new SluiceClientHandler(); + + Bootstrap bootstrap = new Bootstrap(); + bootstrap.group(workerGroup) + .channel(socketChannelClass) + // Disable Nagle's Algorithm since we don't want packets to wait + .option(ChannelOption.TCP_NODELAY, true) + .option(ChannelOption.SO_KEEPALIVE, true) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs()); + + // Use pooled buffers to reduce temporary buffer allocation + bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator()); + + bootstrap.handler(new ChannelInitializer() { + @Override + public void initChannel(SocketChannel ch) { + ch.pipeline() + .addLast("clientRequestEncoder", encoder) + .addLast("frameDecoder", NettyUtils.createFrameDecoder()) + .addLast("serverResponseDecoder", decoder) + .addLast("handler", handler); + } + }); + + // Connect to the remote server + ChannelFuture cf = bootstrap.connect(address); + if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) { + throw new TimeoutException( + String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs())); + } + + SluiceClient client = new SluiceClient(cf, handler); + connectionPool.put(address, client); + return client; + } + + /** Close all connections in the connection pool, and shutdown the worker thread pool. */ + @Override + public void close() { + for (SluiceClient client : connectionPool.values()) { + client.close(); + } + connectionPool.clear(); + + if (workerGroup != null) { + workerGroup.shutdownGracefully(); + } + } + + /** + * Create a pooled ByteBuf allocator but disables the thread-local cache. 
Thread-local caches + * are disabled because the ByteBufs are allocated by the event loop thread, but released by the + * executor thread rather than the event loop thread. Those thread-local caches actually delay + * the recycling of buffers, leading to larger memory usage. + */ + private PooledByteBufAllocator createPooledByteBufAllocator() { + return new PooledByteBufAllocator( + PlatformDependent.directBufferPreferred(), + getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"), + getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"), + getPrivateStaticField("DEFAULT_PAGE_SIZE"), + getPrivateStaticField("DEFAULT_MAX_ORDER"), + 0, // tinyCacheSize + 0, // smallCacheSize + 0 // normalCacheSize + ); + } + + /** Used to get defaults from Netty's private static fields. */ + private int getPrivateStaticField(String name) { + try { + Field f = PooledByteBufAllocator.DEFAULT.getClass().getDeclaredField(name); + f.setAccessible(true); + return f.getInt(null); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/SluiceClientHandler.java b/network/common/src/main/java/org/apache/spark/network/client/SluiceClientHandler.java new file mode 100644 index 0000000000000..ed20b032931c3 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/SluiceClientHandler.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +import java.net.SocketAddress; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import com.google.common.annotations.VisibleForTesting; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.protocol.response.ChunkFetchFailure; +import org.apache.spark.network.protocol.response.ChunkFetchSuccess; +import org.apache.spark.network.protocol.response.RpcFailure; +import org.apache.spark.network.protocol.response.RpcResponse; +import org.apache.spark.network.protocol.response.ServerResponse; + +/** + * Handler that processes server responses, in response to requests issued from [[SluiceClient]]. + * It works by tracking the list of outstanding requests (and their callbacks). + * + * Concurrency: thread safe and can be called from multiple threads. 
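+ *
+ * The expected calling pattern (as used by SluiceClient) is to register a callback before the
+ * request is flushed, so a response can never arrive for an unknown id:
+ * <pre>
+ *   handler.addFetchRequest(streamChunkId, callback);
+ *   channel.writeAndFlush(new ChunkFetchRequest(streamChunkId));
+ *   // the handler removes the entry and invokes the callback when the matching
+ *   // ChunkFetchSuccess or ChunkFetchFailure arrives; on an uncaught exception or a closed
+ *   // connection, all outstanding callbacks are failed
+ * </pre>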
+ */
+public class SluiceClientHandler extends SimpleChannelInboundHandler<ServerResponse> {
+  private final Logger logger = LoggerFactory.getLogger(SluiceClientHandler.class);
+
+  private final Map<StreamChunkId, ChunkReceivedCallback> outstandingFetches =
+    new ConcurrentHashMap<StreamChunkId, ChunkReceivedCallback>();
+
+  private final Map<Long, RpcResponseCallback> outstandingRpcs =
+    new ConcurrentHashMap<Long, RpcResponseCallback>();
+
+  public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) {
+    outstandingFetches.put(streamChunkId, callback);
+  }
+
+  public void removeFetchRequest(StreamChunkId streamChunkId) {
+    outstandingFetches.remove(streamChunkId);
+  }
+
+  public void addRpcRequest(long tag, RpcResponseCallback callback) {
+    outstandingRpcs.put(tag, callback);
+  }
+
+  public void removeRpcRequest(long tag) {
+    outstandingRpcs.remove(tag);
+  }
+
+  /**
+   * Fire the failure callback for all outstanding requests. This is called when we have an
+   * uncaught exception or premature connection termination.
+   */
+  private void failOutstandingRequests(Throwable cause) {
+    for (Map.Entry<StreamChunkId, ChunkReceivedCallback> entry : outstandingFetches.entrySet()) {
+      entry.getValue().onFailure(entry.getKey().chunkIndex, cause);
+    }
+    // TODO(rxin): Maybe we need to synchronize the access? Otherwise we could clear new requests
+    // as well. But I guess that is ok given the caller will fail as soon as any requests fail.
+    outstandingFetches.clear();
+  }
+
+  @Override
+  public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
+    if (outstandingFetches.size() > 0) {
+      SocketAddress remoteAddress = ctx.channel().remoteAddress();
+      logger.error("Still have {} requests outstanding when connection from {} is closed",
+        outstandingFetches.size(), remoteAddress);
+      failOutstandingRequests(new RuntimeException("Connection from " + remoteAddress + " closed"));
+    }
+    super.channelUnregistered(ctx);
+  }
+
+  @Override
+  public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
+    if (outstandingFetches.size() > 0) {
+      logger.error(String.format("Exception in connection from %s: %s",
+        ctx.channel().remoteAddress(), cause.getMessage()), cause);
+      failOutstandingRequests(cause);
+    }
+    ctx.close();
+  }
+
+  @Override
+  public void channelRead0(ChannelHandlerContext ctx, ServerResponse message) {
+    String server = ctx.channel().remoteAddress().toString();
+    if (message instanceof ChunkFetchSuccess) {
+      ChunkFetchSuccess resp = (ChunkFetchSuccess) message;
+      ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
+      if (listener == null) {
+        logger.warn("Got a response for block {} from {} but it is not outstanding",
+          resp.streamChunkId, server);
+        resp.buffer.release();
+      } else {
+        outstandingFetches.remove(resp.streamChunkId);
+        listener.onSuccess(resp.streamChunkId.chunkIndex, resp.buffer);
+        resp.buffer.release();
+      }
+    } else if (message instanceof ChunkFetchFailure) {
+      ChunkFetchFailure resp = (ChunkFetchFailure) message;
+      ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
+      if (listener == null) {
+        logger.warn("Got a response for block {} from {} ({}) but it is not outstanding",
+          resp.streamChunkId, server, resp.errorString);
+      } else {
+        outstandingFetches.remove(resp.streamChunkId);
+        listener.onFailure(resp.streamChunkId.chunkIndex,
+          new ChunkFetchFailureException(resp.streamChunkId.chunkIndex, resp.errorString));
+      }
+    } else if (message instanceof RpcResponse) {
+      RpcResponse resp = (RpcResponse) message;
+      RpcResponseCallback listener = outstandingRpcs.get(resp.tag);
+      if (listener == null) {
+        logger.warn("Got a response for RPC {} from 
{} ({} bytes) but it is not outstanding", + resp.tag, server, resp.response.length); + } else { + outstandingRpcs.remove(resp.tag); + listener.onSuccess(resp.response); + } + } else if (message instanceof RpcFailure) { + RpcFailure resp = (RpcFailure) message; + RpcResponseCallback listener = outstandingRpcs.get(resp.tag); + if (listener == null) { + logger.warn("Got a response for RPC {} from {} ({}) but it is not outstanding", + resp.tag, server, resp.errorString); + } else { + outstandingRpcs.remove(resp.tag); + listener.onFailure(new RuntimeException(resp.errorString)); + } + } + } + + @VisibleForTesting + public int numOutstandingRequests() { + return outstandingFetches.size(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java b/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java new file mode 100644 index 0000000000000..363ea5ecfa936 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import io.netty.buffer.ByteBuf; + +/** + * Interface for an object which can be encoded into a ByteBuf. Multiple Encodable objects are + * stored in a single, pre-allocated ByteBuf, so Encodables must also provide their length. + */ +public interface Encodable { + /** Number of bytes of the encoded form of this object. */ + int encodedLength(); + + /** + * Serializes this object by writing into the given ByteBuf. + * This method must write exactly encodedLength() bytes. + */ + void encode(ByteBuf buf); +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java new file mode 100644 index 0000000000000..d46a263884807 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** +* Encapsulates a request for a particular chunk of a stream. +*/ +public final class StreamChunkId implements Encodable { + public final long streamId; + public final int chunkIndex; + + public StreamChunkId(long streamId, int chunkIndex) { + this.streamId = streamId; + this.chunkIndex = chunkIndex; + } + + @Override + public int encodedLength() { + return 8 + 4; + } + + public void encode(ByteBuf buffer) { + buffer.writeLong(streamId); + buffer.writeInt(chunkIndex); + } + + public static StreamChunkId decode(ByteBuf buffer) { + assert buffer.readableBytes() >= 8 + 4; + long streamId = buffer.readLong(); + int chunkIndex = buffer.readInt(); + return new StreamChunkId(streamId, chunkIndex); + } + + @Override + public int hashCode() { + return Objects.hashCode(streamId, chunkIndex); + } + + @Override + public boolean equals(Object other) { + if (other instanceof StreamChunkId) { + StreamChunkId o = (StreamChunkId) other; + return streamId == o.streamId && chunkIndex == o.chunkIndex; + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamId", streamId) + .add("chunkIndex", chunkIndex) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/ChunkFetchRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/request/ChunkFetchRequest.java new file mode 100644 index 0000000000000..a79eb363cf58c --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/request/ChunkFetchRequest.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.request; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.StreamChunkId; + +/** + * Request to fetch a sequence of a single chunk of a stream. This will correspond to a single + * {@link org.apache.spark.network.protocol.response.ServerResponse} (either success or failure). 
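+ *
+ * As a rough sketch of the framed wire format (inferred from ClientRequestEncoder and
+ * StreamChunkId, not an authoritative spec): an 8-byte frame length, a 1-byte type id
+ * (0 for ChunkFetchRequest), an 8-byte streamId and a 4-byte chunkIndex -- 21 bytes in total.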
+ */ +public final class ChunkFetchRequest implements ClientRequest { + public final StreamChunkId streamChunkId; + + public ChunkFetchRequest(StreamChunkId streamChunkId) { + this.streamChunkId = streamChunkId; + } + + @Override + public Type type() { return Type.ChunkFetchRequest; } + + @Override + public int encodedLength() { + return streamChunkId.encodedLength(); + } + + @Override + public void encode(ByteBuf buf) { + streamChunkId.encode(buf); + } + + public static ChunkFetchRequest decode(ByteBuf buf) { + return new ChunkFetchRequest(StreamChunkId.decode(buf)); + } + + @Override + public boolean equals(Object other) { + if (other instanceof ChunkFetchRequest) { + ChunkFetchRequest o = (ChunkFetchRequest) other; + return streamChunkId.equals(o.streamChunkId); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamChunkId", streamChunkId) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequest.java new file mode 100644 index 0000000000000..db075c44b4cda --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequest.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.request; + +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encodable; + +/** Messages from the client to the server. */ +public interface ClientRequest extends Encodable { + /** Used to identify this request type. */ + Type type(); + + /** + * Preceding every serialized ClientRequest is the type, which allows us to deserialize + * the request. 
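+   *
+   * A minimal round-trip sketch (illustrative only; Unpooled.buffer() is just a convenient
+   * heap buffer for the example):
+   * <pre>
+   *   ByteBuf buf = Unpooled.buffer();
+   *   Type.ChunkFetchRequest.encode(buf);                 // writes the single type byte (0)
+   *   assert Type.decode(buf) == Type.ChunkFetchRequest;  // reads it back
+   * </pre>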
+ */ + public static enum Type implements Encodable { + ChunkFetchRequest(0), RpcRequest(1); + + private final byte id; + + private Type(int id) { + assert id < 128 : "Cannot have more than 128 request types"; + this.id = (byte) id; + } + + public byte id() { return id; } + + @Override public int encodedLength() { return 1; } + + @Override public void encode(ByteBuf buf) { buf.writeByte(id); } + + public static Type decode(ByteBuf buf) { + byte id = buf.readByte(); + switch(id) { + case 0: return ChunkFetchRequest; + case 1: return RpcRequest; + default: throw new IllegalArgumentException("Unknown request type: " + id); + } + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestDecoder.java new file mode 100644 index 0000000000000..a937da4cecae0 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestDecoder.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.request; + +import java.util.List; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.MessageToMessageDecoder; + +/** + * Decoder in the server side to decode client requests. + * This decoder is stateless so it is safe to be shared by multiple threads. + * + * This assumes the inbound messages have been processed by a frame decoder created by + * {@link org.apache.spark.network.util.NettyUtils#createFrameDecoder()}. 
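+ *
+ * Put differently (an informal note, not new behavior): the frame decoder strips the 8-byte
+ * length prefix, so each ByteBuf handed to this decoder contains exactly one message -- the
+ * 1-byte type followed by the message body -- and decode() is expected to consume it entirely.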
+ */ +@ChannelHandler.Sharable +public final class ClientRequestDecoder extends MessageToMessageDecoder { + + @Override + public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { + ClientRequest.Type msgType = ClientRequest.Type.decode(in); + ClientRequest decoded = decode(msgType, in); + assert decoded.type() == msgType; + assert in.readableBytes() == 0; + out.add(decoded); + } + + private ClientRequest decode(ClientRequest.Type msgType, ByteBuf in) { + switch (msgType) { + case ChunkFetchRequest: + return ChunkFetchRequest.decode(in); + + case RpcRequest: + return RpcRequest.decode(in); + + default: throw new IllegalArgumentException("Unexpected message type: " + msgType); + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestEncoder.java new file mode 100644 index 0000000000000..bcff4a0a25568 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestEncoder.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.request; + +import java.util.List; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.MessageToMessageEncoder; + +/** + * Encoder for {@link ClientRequest} used in client side. + * + * This encoder is stateless so it is safe to be shared by multiple threads. + */ +@ChannelHandler.Sharable +public final class ClientRequestEncoder extends MessageToMessageEncoder { + @Override + public void encode(ChannelHandlerContext ctx, ClientRequest in, List out) { + ClientRequest.Type msgType = in.type(); + // Write 8 bytes for the frame's length, followed by the request type and request itself. + int frameLength = 8 + msgType.encodedLength() + in.encodedLength(); + ByteBuf buf = ctx.alloc().buffer(frameLength); + buf.writeLong(frameLength); + msgType.encode(buf); + in.encode(buf); + assert buf.writableBytes() == 0; + out.add(buf); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/request/RpcRequest.java new file mode 100644 index 0000000000000..126370330f723 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/request/RpcRequest.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.request; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** + * A generic RPC which is handled by a remote {@link org.apache.spark.network.server.RpcHandler}. + * This will correspond to a single {@link org.apache.spark.network.protocol.response.ServerResponse} + * (either success or failure). + */ +public final class RpcRequest implements ClientRequest { + /** Tag is used to link an RPC request with its response. */ + public final long tag; + + /** Serialized message to send to remote RpcHandler. */ + public final byte[] message; + + public RpcRequest(long tag, byte[] message) { + this.tag = tag; + this.message = message; + } + + @Override + public Type type() { return Type.RpcRequest; } + + @Override + public int encodedLength() { + return 8 + 4 + message.length; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(tag); + buf.writeInt(message.length); + buf.writeBytes(message); + } + + public static RpcRequest decode(ByteBuf buf) { + long tag = buf.readLong(); + int messageLen = buf.readInt(); + byte[] message = new byte[messageLen]; + buf.readBytes(message); + return new RpcRequest(tag, message); + } + + @Override + public boolean equals(Object other) { + if (other instanceof RpcRequest) { + RpcRequest o = (RpcRequest) other; + return tag == o.tag && Arrays.equals(message, o.message); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("tag", tag) + .add("message", message) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchFailure.java new file mode 100644 index 0000000000000..3a57d71b4f3ea --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchFailure.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.protocol.response; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.StreamChunkId; + +/** + * Response to {@link org.apache.spark.network.protocol.request.ChunkFetchRequest} when there is an + * error fetching the chunk. + */ +public final class ChunkFetchFailure implements ServerResponse { + public final StreamChunkId streamChunkId; + public final String errorString; + + public ChunkFetchFailure(StreamChunkId streamChunkId, String errorString) { + this.streamChunkId = streamChunkId; + this.errorString = errorString; + } + + @Override + public Type type() { return Type.ChunkFetchFailure; } + + @Override + public int encodedLength() { + return streamChunkId.encodedLength() + 4 + errorString.getBytes().length; + } + + @Override + public void encode(ByteBuf buf) { + streamChunkId.encode(buf); + byte[] errorBytes = errorString.getBytes(); + buf.writeInt(errorBytes.length); + buf.writeBytes(errorBytes); + } + + public static ChunkFetchFailure decode(ByteBuf buf) { + StreamChunkId streamChunkId = StreamChunkId.decode(buf); + int numErrorStringBytes = buf.readInt(); + byte[] errorBytes = new byte[numErrorStringBytes]; + buf.readBytes(errorBytes); + return new ChunkFetchFailure(streamChunkId, new String(errorBytes)); + } + + @Override + public boolean equals(Object other) { + if (other instanceof ChunkFetchFailure) { + ChunkFetchFailure o = (ChunkFetchFailure) other; + return streamChunkId.equals(o.streamChunkId) && errorString.equals(o.errorString); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamChunkId", streamChunkId) + .add("errorString", errorString) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchSuccess.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchSuccess.java new file mode 100644 index 0000000000000..874dc4f5940cf --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchSuccess.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.response; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; +import org.apache.spark.network.protocol.StreamChunkId; + +/** + * Response to {@link org.apache.spark.network.protocol.request.ChunkFetchRequest} when a chunk + * exists and has been successfully fetched. 
+ * + * Note that the server-side encoding of this messages does NOT include the buffer itself, as this + * may be written by Netty in a more efficient manner (i.e., zero-copy write). + * Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer. + */ +public final class ChunkFetchSuccess implements ServerResponse { + public final StreamChunkId streamChunkId; + public final ManagedBuffer buffer; + + public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) { + this.streamChunkId = streamChunkId; + this.buffer = buffer; + } + + @Override + public Type type() { return Type.ChunkFetchSuccess; } + + @Override + public int encodedLength() { + return streamChunkId.encodedLength(); + } + + /** Encoding does NOT include buffer itself. See {@link ServerResponseEncoder}. */ + @Override + public void encode(ByteBuf buf) { + streamChunkId.encode(buf); + } + + /** Decoding uses the given ByteBuf as our data, and will retain() it. */ + public static ChunkFetchSuccess decode(ByteBuf buf) { + StreamChunkId streamChunkId = StreamChunkId.decode(buf); + buf.retain(); + NettyManagedBuffer managedBuf = new NettyManagedBuffer(buf.duplicate()); + return new ChunkFetchSuccess(streamChunkId, managedBuf); + } + + @Override + public boolean equals(Object other) { + if (other instanceof ChunkFetchSuccess) { + ChunkFetchSuccess o = (ChunkFetchSuccess) other; + return streamChunkId.equals(o.streamChunkId) && buffer.equals(o.buffer); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamChunkId", streamChunkId) + .add("buffer", buffer) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcFailure.java new file mode 100644 index 0000000000000..274920b28bced --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcFailure.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.response; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** Response to {@link org.apache.spark.network.protocol.request.RpcRequest} for a failed RPC. 
*/ +public final class RpcFailure implements ServerResponse { + public final long tag; + public final String errorString; + + public RpcFailure(long tag, String errorString) { + this.tag = tag; + this.errorString = errorString; + } + + @Override + public Type type() { return Type.RpcFailure; } + + @Override + public int encodedLength() { + return 8 + 4 + errorString.getBytes().length; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(tag); + byte[] errorBytes = errorString.getBytes(); + buf.writeInt(errorBytes.length); + buf.writeBytes(errorBytes); + } + + public static RpcFailure decode(ByteBuf buf) { + long tag = buf.readLong(); + int numErrorStringBytes = buf.readInt(); + byte[] errorBytes = new byte[numErrorStringBytes]; + buf.readBytes(errorBytes); + return new RpcFailure(tag, new String(errorBytes)); + } + + @Override + public boolean equals(Object other) { + if (other instanceof RpcFailure) { + RpcFailure o = (RpcFailure) other; + return tag == o.tag && errorString.equals(o.errorString); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("tag", tag) + .add("errorString", errorString) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcResponse.java new file mode 100644 index 0000000000000..0c6f8acdcdc4b --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcResponse.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.response; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** Response to {@link org.apache.spark.network.protocol.request.RpcRequest} for a successful RPC. 
*/ +public final class RpcResponse implements ServerResponse { + public final long tag; + public final byte[] response; + + public RpcResponse(long tag, byte[] response) { + this.tag = tag; + this.response = response; + } + + @Override + public Type type() { return Type.RpcResponse; } + + @Override + public int encodedLength() { return 8 + 4 + response.length; } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(tag); + buf.writeInt(response.length); + buf.writeBytes(response); + } + + public static RpcResponse decode(ByteBuf buf) { + long tag = buf.readLong(); + int responseLen = buf.readInt(); + byte[] response = new byte[responseLen]; + buf.readBytes(response); + return new RpcResponse(tag, response); + } + + @Override + public boolean equals(Object other) { + if (other instanceof RpcResponse) { + RpcResponse o = (RpcResponse) other; + return tag == o.tag && Arrays.equals(response, o.response); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("tag", tag) + .add("response", response) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponse.java new file mode 100644 index 0000000000000..335f9e8ea69f9 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponse.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.response; + +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encodable; + +/** + * Messages from server to client (usually in response to some + * {@link org.apache.spark.network.protocol.request.ClientRequest}. + */ +public interface ServerResponse extends Encodable { + /** Used to identify this response type. */ + Type type(); + + /** + * Preceding every serialized ServerResponse is the type, which allows us to deserialize + * the response. 
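+   *
+   * An informal framing note (see ServerResponseEncoder): each response is preceded by an
+   * 8-byte frame length and this 1-byte type id; for ChunkFetchSuccess the chunk data itself
+   * is appended after the encoded header so it can be written with zero-copy.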
+ */
+  public static enum Type implements Encodable {
+    ChunkFetchSuccess(0), ChunkFetchFailure(1), RpcResponse(2), RpcFailure(3);
+
+    private final byte id;
+
+    private Type(int id) {
+      assert id < 128 : "Cannot have more than 128 response types";
+      this.id = (byte) id;
+    }
+
+    public byte id() { return id; }
+
+    @Override public int encodedLength() { return 1; }
+
+    @Override public void encode(ByteBuf buf) { buf.writeByte(id); }
+
+    public static Type decode(ByteBuf buf) {
+      byte id = buf.readByte();
+      switch(id) {
+        case 0: return ChunkFetchSuccess;
+        case 1: return ChunkFetchFailure;
+        case 2: return RpcResponse;
+        case 3: return RpcFailure;
+        default: throw new IllegalArgumentException("Unknown response type: " + id);
+      }
+    }
+  }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseDecoder.java
new file mode 100644
index 0000000000000..e06198284e620
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseDecoder.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol.response;
+
+import java.util.List;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.MessageToMessageDecoder;
+
+/**
+ * Decoder used by the client side to decode server-to-client responses.
+ * This decoder is stateless so it is safe to be shared by multiple threads. 
+ */ +@ChannelHandler.Sharable +public final class ServerResponseDecoder extends MessageToMessageDecoder { + + @Override + public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { + ServerResponse.Type msgType = ServerResponse.Type.decode(in); + ServerResponse decoded = decode(msgType, in); + assert decoded.type() == msgType; + out.add(decoded); + } + + private ServerResponse decode(ServerResponse.Type msgType, ByteBuf in) { + switch (msgType) { + case ChunkFetchSuccess: + return ChunkFetchSuccess.decode(in); + + case ChunkFetchFailure: + return ChunkFetchFailure.decode(in); + + case RpcResponse: + return RpcResponse.decode(in); + + case RpcFailure: + return RpcFailure.decode(in); + + default: + throw new IllegalArgumentException("Unexpected message type: " + msgType); + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseEncoder.java new file mode 100644 index 0000000000000..069f42463a8fe --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseEncoder.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.response; + +import java.util.List; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.MessageToMessageEncoder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Encoder used by the server side to encode server-to-client responses. + * This encoder is stateless so it is safe to be shared by multiple threads. + */ +@ChannelHandler.Sharable +public final class ServerResponseEncoder extends MessageToMessageEncoder { + + private final Logger logger = LoggerFactory.getLogger(ServerResponseEncoder.class); + + @Override + public void encode(ChannelHandlerContext ctx, ServerResponse in, List out) { + Object body = null; + long bodyLength = 0; + + // Only ChunkFetchSuccesses have data besides the header. + // The body is used in order to enable zero-copy transfer for the payload. + if (in instanceof ChunkFetchSuccess) { + ChunkFetchSuccess resp = (ChunkFetchSuccess) in; + try { + bodyLength = resp.buffer.size(); + body = resp.buffer.convertToNetty(); + } catch (Exception e) { + // Re-encode this message as BlockFetchFailure. 
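+        // convertToNetty() can fail for a disk-backed buffer (for example if the underlying
+        // file has been deleted), in which case the zero-copy write is impossible. Rather than
+        // closing the channel, fall back to a ChunkFetchFailure for the same streamChunkId so
+        // only this chunk's fetch fails on the client side.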
+ logger.error(String.format("Error opening block %s for client %s", + resp.streamChunkId, ctx.channel().remoteAddress()), e); + encode(ctx, new ChunkFetchFailure(resp.streamChunkId, e.getMessage()), out); + return; + } + } + + ServerResponse.Type msgType = in.type(); + // All messages have the frame length, message type, and message itself. + int headerLength = 8 + msgType.encodedLength() + in.encodedLength(); + long frameLength = headerLength + bodyLength; + ByteBuf header = ctx.alloc().buffer(headerLength); + header.writeLong(frameLength); + msgType.encode(header); + in.encode(header); + assert header.writableBytes() == 0; + + out.add(header); + if (body != null && bodyLength > 0) { + out.add(body); + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java new file mode 100644 index 0000000000000..04814d9a88c4a --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import java.util.Iterator; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * StreamManager which allows registration of an Iterator, which are individually + * fetched as chunks by the client. + */ +public class DefaultStreamManager extends StreamManager { + private final AtomicLong nextStreamId; + private final Map streams; + + /** State of a single stream. */ + private static class StreamState { + final Iterator buffers; + + int curChunk = 0; + + StreamState(Iterator buffers) { + this.buffers = buffers; + } + } + + public DefaultStreamManager() { + // Start with a random stream id to help identifying different streams. + nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000); + streams = new ConcurrentHashMap(); + } + + @Override + public ManagedBuffer getChunk(long streamId, int chunkIndex) { + StreamState state = streams.get(streamId); + if (chunkIndex != state.curChunk) { + throw new IllegalStateException(String.format( + "Received out-of-order chunk index %s (expected %s)", chunkIndex, state.curChunk)); + } else if (!state.buffers.hasNext()) { + throw new IllegalStateException(String.format( + "Requested chunk index beyond end %s", chunkIndex)); + } + state.curChunk += 1; + return state.buffers.next(); + } + + @Override + public void connectionTerminated(long streamId) { + // Release all remaining buffers. 
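+    // Chunks are handed out lazily from the iterator, so if the client disconnects mid-stream
+    // the buffers it never fetched are still held here; releasing them prevents leaking the
+    // underlying resources (e.g. file handles or reference-counted Netty buffers).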
+ StreamState state = streams.remove(streamId); + if (state != null && state.buffers != null) { + while (state.buffers.hasNext()) { + state.buffers.next().release(); + } + } + } + + public long registerStream(Iterator buffers) { + long myStreamId = nextStreamId.getAndIncrement(); + streams.put(myStreamId, new StreamState(buffers)); + return myStreamId; + } + + public void unregisterStream(long streamId) { + streams.remove(streamId); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java new file mode 100644 index 0000000000000..abfbe66d875e8 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import org.apache.spark.network.client.RpcResponseCallback; + +/** + * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.SluiceClient}s. + */ +public interface RpcHandler { + /** + * Receive a single RPC message. Any exception thrown while in this method will be sent back to + * the client in string form as a standard RPC failure. + */ + void receive(byte[] message, RpcResponseCallback callback); +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/SluiceServer.java b/network/common/src/main/java/org/apache/spark/network/server/SluiceServer.java new file mode 100644 index 0000000000000..aa81271024156 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/SluiceServer.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.server; + +import java.io.Closeable; +import java.net.InetSocketAddress; + +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.protocol.request.ClientRequestDecoder; +import org.apache.spark.network.protocol.response.ServerResponseEncoder; +import org.apache.spark.network.util.IOMode; +import org.apache.spark.network.util.NettyUtils; +import org.apache.spark.network.util.SluiceConfig; + +/** + * Server for the efficient, low-level streaming service. + */ +public class SluiceServer implements Closeable { + private final Logger logger = LoggerFactory.getLogger(SluiceServer.class); + + private final SluiceConfig conf; + private final StreamManager streamManager; + private final RpcHandler rpcHandler; + + private ServerBootstrap bootstrap; + private ChannelFuture channelFuture; + private int port; + + public SluiceServer(SluiceConfig conf, StreamManager streamManager, RpcHandler rpcHandler) { + this.conf = conf; + this.streamManager = streamManager; + this.rpcHandler = rpcHandler; + + init(); + } + + public int getPort() { return port; } + + private void init() { + + IOMode ioMode = IOMode.valueOf(conf.ioMode()); + EventLoopGroup bossGroup = + NettyUtils.createEventLoop(ioMode, conf.serverThreads(), "shuffle-server"); + EventLoopGroup workerGroup = bossGroup; + + bootstrap = new ServerBootstrap() + .group(bossGroup, workerGroup) + .channel(NettyUtils.getServerChannelClass(ioMode)) + .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) + .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT); + + if (conf.backLog() > 0) { + bootstrap.option(ChannelOption.SO_BACKLOG, conf.backLog()); + } + + if (conf.receiveBuf() > 0) { + bootstrap.childOption(ChannelOption.SO_RCVBUF, conf.receiveBuf()); + } + + if (conf.sendBuf() > 0) { + bootstrap.childOption(ChannelOption.SO_SNDBUF, conf.sendBuf()); + } + + bootstrap.childHandler(new ChannelInitializer() { + + @Override + protected void initChannel(SocketChannel ch) throws Exception { + ch.pipeline() + .addLast("frameDecoder", NettyUtils.createFrameDecoder()) + .addLast("clientRequestDecoder", new ClientRequestDecoder()) + .addLast("serverResponseEncoder", new ServerResponseEncoder()) + // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this + // would require more logic to guarantee if this were not part of the same event loop. 
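+          // Pipeline order matters: the frame decoder splits the inbound byte stream into
+          // whole frames, the request decoder turns each frame into a ClientRequest, the
+          // response encoder serializes outbound ServerResponses, and the handler sits last.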
+ .addLast("handler", new SluiceServerHandler(streamManager, rpcHandler)); + } + }); + + channelFuture = bootstrap.bind(new InetSocketAddress(conf.serverPort())); + channelFuture.syncUninterruptibly(); + + port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort(); + logger.debug("Shuffle server started on port :" + port); + } + + @Override + public void close() { + if (channelFuture != null) { + channelFuture.channel().close().awaitUninterruptibly(); + channelFuture = null; + } + if (bootstrap != null && bootstrap.group() != null) { + bootstrap.group().shutdownGracefully(); + } + if (bootstrap != null && bootstrap.childGroup() != null) { + bootstrap.childGroup().shutdownGracefully(); + } + bootstrap = null; + } + +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/SluiceServerHandler.java b/network/common/src/main/java/org/apache/spark/network/server/SluiceServerHandler.java new file mode 100644 index 0000000000000..fad72fbfc711b --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/SluiceServerHandler.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import java.util.Set; + +import com.google.common.base.Throwables; +import com.google.common.collect.Sets; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.request.ChunkFetchRequest; +import org.apache.spark.network.protocol.request.ClientRequest; +import org.apache.spark.network.protocol.request.RpcRequest; +import org.apache.spark.network.protocol.response.ChunkFetchFailure; +import org.apache.spark.network.protocol.response.ChunkFetchSuccess; +import org.apache.spark.network.protocol.response.RpcFailure; +import org.apache.spark.network.protocol.response.RpcResponse; + +/** + * A handler that processes requests from clients and writes chunk data back. Each handler keeps + * track of which streams have been fetched via this channel, in order to clean them up if the + * channel is terminated (see #channelUnregistered). + * + * The messages should have been processed by the pipeline setup by {@link SluiceServer}. + */ +public class SluiceServerHandler extends SimpleChannelInboundHandler { + private final Logger logger = LoggerFactory.getLogger(SluiceServerHandler.class); + + /** Returns each chunk part of a stream. 
*/ + private final StreamManager streamManager; + + /** Handles all RPC messages. */ + private final RpcHandler rpcHandler; + + /** List of all stream ids that have been read on this handler, used for cleanup. */ + private final Set streamIds; + + public SluiceServerHandler(StreamManager streamManager, RpcHandler rpcHandler) { + this.streamManager = streamManager; + this.rpcHandler = rpcHandler; + this.streamIds = Sets.newHashSet(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + logger.error("Exception in connection from " + ctx.channel().remoteAddress(), cause); + ctx.close(); + super.exceptionCaught(ctx, cause); + } + + @Override + public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { + super.channelUnregistered(ctx); + // Inform the StreamManager that these streams will no longer be read from. + for (long streamId : streamIds) { + streamManager.connectionTerminated(streamId); + } + } + + @Override + public void channelRead0(ChannelHandlerContext ctx, ClientRequest request) { + if (request instanceof ChunkFetchRequest) { + processFetchRequest(ctx, (ChunkFetchRequest) request); + } else if (request instanceof RpcRequest) { + processRpcRequest(ctx, (RpcRequest) request); + } else { + throw new IllegalArgumentException("Unknown request type: " + request); + } + } + + private void processFetchRequest(final ChannelHandlerContext ctx, final ChunkFetchRequest req) { + final String client = ctx.channel().remoteAddress().toString(); + streamIds.add(req.streamChunkId.streamId); + + logger.trace("Received req from {} to fetch block {}", client, req.streamChunkId); + + ManagedBuffer buf; + try { + buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex); + } catch (Exception e) { + logger.error(String.format( + "Error opening block %s for request from %s", req.streamChunkId, client), e); + respond(ctx, new ChunkFetchFailure(req.streamChunkId, Throwables.getStackTraceAsString(e))); + return; + } + + respond(ctx, new ChunkFetchSuccess(req.streamChunkId, buf)); + } + + private void processRpcRequest(final ChannelHandlerContext ctx, final RpcRequest req) { + try { + rpcHandler.receive(req.message, new RpcResponseCallback() { + @Override + public void onSuccess(byte[] response) { + respond(ctx, new RpcResponse(req.tag, response)); + } + + @Override + public void onFailure(Throwable e) { + respond(ctx, new RpcFailure(req.tag, Throwables.getStackTraceAsString(e))); + } + }); + } catch (Exception e) { + logger.error("Error while invoking RpcHandler#receive() on RPC tag " + req.tag, e); + respond(ctx, new RpcFailure(req.tag, Throwables.getStackTraceAsString(e))); + } + } + + /** + * Responds to a single message with some Encodable object. If a failure occurs while sending, + * it will be logged and the channel closed. 
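+   *
+   * Note (informal): this is used for both chunk and RPC responses; the actual byte encoding,
+   * including any zero-copy transfer of chunk data, is left to the ServerResponseEncoder that
+   * SluiceServer installs in the pipeline.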
+ */ + private void respond(final ChannelHandlerContext ctx, final Encodable result) { + final String remoteAddress = ctx.channel().remoteAddress().toString(); + ctx.writeAndFlush(result).addListener( + new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + logger.trace(String.format("Sent result %s to client %s", result, remoteAddress)); + } else { + logger.error(String.format("Error sending result %s to %s; closing connection", + result, remoteAddress), future.cause()); + ctx.close(); + } + } + } + ); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java new file mode 100644 index 0000000000000..2e07f5a270cb9 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * The StreamManager is used to fetch individual chunks from a stream. This is used in + * {@link SluiceServerHandler} in order to respond to fetchChunk() requests. Creation of the + * stream is outside the scope of Sluice, but a given stream is guaranteed to be read by only one + * client connection, meaning that getChunk() for a particular stream will be called serially and + * that once the connection associated with the stream is closed, that stream will never be used + * again. + */ +public abstract class StreamManager { + /** + * Called in response to a fetchChunk() request. The returned buffer will be passed as-is to the + * client. A single stream will be associated with a single TCP connection, so this method + * will not be called in parallel for a particular stream. + * + * Chunks may be requested in any order, and requests may be repeated, but it is not required + * that implementations support this behavior. + * + * The returned ManagedBuffer will be release()'d after being written to the network. + * + * @param streamId id of a stream that has been previously registered with the StreamManager. + * @param chunkIndex 0-indexed chunk of the stream that's requested + */ + public abstract ManagedBuffer getChunk(long streamId, int chunkIndex); + + /** + * Indicates that the TCP connection that was tied to the given stream has been terminated. After + * this occurs, we are guaranteed not to read from the stream again, so any state can be cleaned + * up. 
+ */ + public void connectionTerminated(long streamId) { } +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java b/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java new file mode 100644 index 0000000000000..2dc0e248ae835 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.util.NoSuchElementException; + +/** + * Provides a mechanism for constructing a {@link SluiceConfig} using some sort of configuration. + */ +public abstract class ConfigProvider { + /** Obtains the value of the given config, throws NoSuchElementException if it doesn't exist. */ + public abstract String get(String name); + + public String get(String name, String defaultValue) { + try { + return get(name); + } catch (NoSuchElementException e) { + return defaultValue; + } + } + + public int getInt(String name, int defaultValue) { + return Integer.parseInt(get(name, Integer.toString(defaultValue))); + } + + public long getLong(String name, long defaultValue) { + return Long.parseLong(get(name, Long.toString(defaultValue))); + } + + public double getDouble(String name, double defaultValue) { + return Double.parseDouble(get(name, Double.toString(defaultValue))); + } + + public boolean getBoolean(String name, boolean defaultValue) { + return Boolean.parseBoolean(get(name, Boolean.toString(defaultValue))); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/DefaultConfigProvider.java b/network/common/src/main/java/org/apache/spark/network/util/DefaultConfigProvider.java new file mode 100644 index 0000000000000..cef88c0091eff --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/DefaultConfigProvider.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.util; + +import java.util.NoSuchElementException; + +/** Uses System properties to obtain config values. */ +public class DefaultConfigProvider extends ConfigProvider { + @Override + public String get(String name) { + String value = System.getProperty(name); + if (value == null) { + throw new NoSuchElementException(name); + } + return value; + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/IOMode.java b/network/common/src/main/java/org/apache/spark/network/util/IOMode.java new file mode 100644 index 0000000000000..91cb3e0e6f8f6 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/IOMode.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +/** + * Selector for which form of low-level IO we should use. + * NIO is always available, while EPOLL is only available on certain machines. + * AUTO is used to select EPOLL if it's available, or NIO otherwise. + */ +public enum IOMode { + NIO, EPOLL, AUTO +} diff --git a/core/src/main/scala/org/apache/spark/network/exceptions.scala b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java similarity index 65% rename from core/src/main/scala/org/apache/spark/network/exceptions.scala rename to network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index d918d358c4adb..fafdcad04aeb6 100644 --- a/core/src/main/scala/org/apache/spark/network/exceptions.scala +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -15,17 +15,16 @@ * limitations under the License. */ -package org.apache.spark.network +package org.apache.spark.network.util; -class BlockFetchFailureException(blockId: String, errorMsg: String, cause: Throwable) - extends Exception(errorMsg, cause) { +import java.io.Closeable; - def this(blockId: String, errorMsg: String) = this(blockId, errorMsg, null) -} - - -class BlockUploadFailureException(blockId: String, cause: Throwable) - extends Exception(s"Failed to fetch block $blockId", cause) { +import com.google.common.io.Closeables; - def this(blockId: String) = this(blockId, null) +public class JavaUtils { + /** Closes the given object, ignoring IOExceptions. 
*/ + @SuppressWarnings("deprecation") + public static void closeQuietly(Closeable closable) { + Closeables.closeQuietly(closable); + } } diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java new file mode 100644 index 0000000000000..3d20dc9e1c1cd --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.util.concurrent.ThreadFactory; + +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import io.netty.channel.Channel; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.ServerChannel; +import io.netty.channel.epoll.Epoll; +import io.netty.channel.epoll.EpollEventLoopGroup; +import io.netty.channel.epoll.EpollServerSocketChannel; +import io.netty.channel.epoll.EpollSocketChannel; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.LengthFieldBasedFrameDecoder; + +/** + * Utilities for creating various Netty constructs based on whether we're using EPOLL or NIO. + */ +public class NettyUtils { + /** Creates a Netty EventLoopGroup based on the IOMode. */ + public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) { + if (mode == IOMode.AUTO) { + mode = autoselectMode(); + } + + ThreadFactory threadFactory = new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat(threadPrefix + "-%d") + .build(); + + switch(mode) { + case NIO: + return new NioEventLoopGroup(numThreads, threadFactory); + case EPOLL: + return new EpollEventLoopGroup(numThreads, threadFactory); + default: + throw new IllegalArgumentException("Unknown io mode: " + mode); + } + } + + /** Returns the correct (client) SocketChannel class based on IOMode. */ + public static Class getClientChannelClass(IOMode mode) { + if (mode == IOMode.AUTO) { + mode = autoselectMode(); + } + switch(mode) { + case NIO: + return NioSocketChannel.class; + case EPOLL: + return EpollSocketChannel.class; + default: + throw new IllegalArgumentException("Unknown io mode: " + mode); + } + } + + /** Returns the correct ServerSocketChannel class based on IOMode. 
*/ + public static Class getServerChannelClass(IOMode mode) { + if (mode == IOMode.AUTO) { + mode = autoselectMode(); + } + switch(mode) { + case NIO: + return NioServerSocketChannel.class; + case EPOLL: + return EpollServerSocketChannel.class; + default: + throw new IllegalArgumentException("Unknown io mode: " + mode); + } + } + + /** + * Creates a LengthFieldBasedFrameDecoder where the first 8 bytes are the length of the frame. + * This is used before all decoders. + */ + public static ByteToMessageDecoder createFrameDecoder() { + // maxFrameLength = 2G + // lengthFieldOffset = 0 + // lengthFieldLength = 8 + // lengthAdjustment = -8, i.e. exclude the 8 byte length itself + // initialBytesToStrip = 8, i.e. strip out the length field itself + return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 8, -8, 8); + } + + /** Returns EPOLL if it's available on this system, NIO otherwise. */ + private static IOMode autoselectMode() { + return Epoll.isAvailable() ? IOMode.EPOLL : IOMode.NIO; + } +} + diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala b/network/common/src/main/java/org/apache/spark/network/util/SluiceConfig.java similarity index 58% rename from core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala rename to network/common/src/main/java/org/apache/spark/network/util/SluiceConfig.java index 7c3074e939794..26fa3229c4721 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala +++ b/network/common/src/main/java/org/apache/spark/network/util/SluiceConfig.java @@ -15,35 +15,37 @@ * limitations under the License. */ -package org.apache.spark.network.netty - -import org.apache.spark.SparkConf +package org.apache.spark.network.util; /** - * A central location that tracks all the settings we exposed to users. + * A central location that tracks all the settings we expose to users. */ -private[spark] -class NettyConfig(conf: SparkConf) { +public class SluiceConfig { + private final ConfigProvider conf; + + public SluiceConfig(ConfigProvider conf) { + this.conf = conf; + } /** Port the server listens on. Default to a random port. */ - private[netty] val serverPort = conf.getInt("spark.shuffle.io.port", 0) + public int serverPort() { return conf.getInt("spark.shuffle.io.port", 0); } - /** IO mode: nio, oio, epoll, or auto (try epoll first and then nio). */ - private[netty] val ioMode = conf.get("spark.shuffle.io.mode", "nio").toLowerCase + /** IO mode: nio, epoll, or auto (try epoll first and then nio). */ + public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); } /** Connect timeout in secs. Default 120 secs. */ - private[netty] val connectTimeoutMs = { - conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000 + public int connectionTimeoutMs() { + return conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000; } - /** Requested maximum length of the queue of incoming connections. */ - private[netty] val backLog: Option[Int] = conf.getOption("spark.shuffle.io.backLog").map(_.toInt) + /** Requested maximum length of the queue of incoming connections. Default -1 for no backlog. */ + public int backLog() { return conf.getInt("spark.shuffle.io.backLog", -1); } /** Number of threads used in the server thread pool. Default to 0, which is 2x#cores. */ - private[netty] val serverThreads: Int = conf.getInt("spark.shuffle.io.serverThreads", 0) + public int serverThreads() { return conf.getInt("spark.shuffle.io.serverThreads", 0); } /** Number of threads used in the client thread pool. 
Default to 0, which is 2x#cores. */ - private[netty] val clientThreads: Int = conf.getInt("spark.shuffle.io.clientThreads", 0) + public int clientThreads() { return conf.getInt("spark.shuffle.io.clientThreads", 0); } /** * Receive buffer size (SO_RCVBUF). @@ -52,10 +54,8 @@ class NettyConfig(conf: SparkConf) { * Assuming latency = 1ms, network_bandwidth = 10Gbps * buffer size should be ~ 1.25MB */ - private[netty] val receiveBuf: Option[Int] = - conf.getOption("spark.shuffle.io.receiveBuffer").map(_.toInt) + public int receiveBuf() { return conf.getInt("spark.shuffle.io.receiveBuffer", -1); } /** Send buffer size (SO_SNDBUF). */ - private[netty] val sendBuf: Option[Int] = - conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt) + public int sendBuf() { return conf.getInt("spark.shuffle.io.sendBuffer", -1); } } diff --git a/network/common/src/test/java/org/apache/spark/network/IntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/IntegrationSuite.java new file mode 100644 index 0000000000000..d20528558cae1 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/IntegrationSuite.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network; + +import java.io.File; +import java.io.RandomAccessFile; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; + +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.buffer.FileSegmentManagedBuffer; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.SluiceClient; +import org.apache.spark.network.client.SluiceClientFactory; +import org.apache.spark.network.server.SluiceServer; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.util.DefaultConfigProvider; +import org.apache.spark.network.util.SluiceConfig; + +public class IntegrationSuite { + static final long STREAM_ID = 1; + static final int BUFFER_CHUNK_INDEX = 0; + static final int FILE_CHUNK_INDEX = 1; + + static SluiceServer server; + static SluiceClientFactory clientFactory; + static StreamManager streamManager; + static File testFile; + + static ManagedBuffer bufferChunk; + static ManagedBuffer fileChunk; + + @BeforeClass + public static void setUp() throws Exception { + int bufSize = 100000; + final ByteBuffer buf = ByteBuffer.allocate(bufSize); + for (int i = 0; i < bufSize; i ++) { + buf.put((byte) i); + } + buf.flip(); + bufferChunk = new NioManagedBuffer(buf); + + testFile = File.createTempFile("shuffle-test-file", "txt"); + testFile.deleteOnExit(); + RandomAccessFile fp = new RandomAccessFile(testFile, "rw"); + byte[] fileContent = new byte[1024]; + new Random().nextBytes(fileContent); + fp.write(fileContent); + fp.close(); + fileChunk = new FileSegmentManagedBuffer(testFile, 10, testFile.length() - 25); + + SluiceConfig conf = new SluiceConfig(new DefaultConfigProvider()); + streamManager = new StreamManager() { + @Override + public ManagedBuffer getChunk(long streamId, int chunkIndex) { + assertEquals(STREAM_ID, streamId); + if (chunkIndex == BUFFER_CHUNK_INDEX) { + return new NioManagedBuffer(buf); + } else if (chunkIndex == FILE_CHUNK_INDEX) { + return new FileSegmentManagedBuffer(testFile, 10, testFile.length() - 25); + } else { + throw new IllegalArgumentException("Invalid chunk index: " + chunkIndex); + } + } + }; + server = new SluiceServer(conf, streamManager, new NoOpRpcHandler()); + clientFactory = new SluiceClientFactory(conf); + } + + @AfterClass + public static void tearDown() { + server.close(); + clientFactory.close(); + testFile.delete(); + } + + class FetchResult { + public Set successChunks; + public Set failedChunks; + public List buffers; + + public void releaseBuffers() { + for (ManagedBuffer buffer : buffers) { + buffer.release(); + } + } + } + + private FetchResult fetchChunks(List chunkIndices) throws Exception { + SluiceClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + final Semaphore sem = new Semaphore(0); + + final FetchResult res = new FetchResult(); + res.successChunks = Collections.synchronizedSet(new HashSet()); + res.failedChunks = Collections.synchronizedSet(new HashSet()); + res.buffers = Collections.synchronizedList(new 
LinkedList()); + + ChunkReceivedCallback callback = new ChunkReceivedCallback() { + @Override + public void onSuccess(int chunkIndex, ManagedBuffer buffer) { + buffer.retain(); + res.successChunks.add(chunkIndex); + res.buffers.add(buffer); + sem.release(); + } + + @Override + public void onFailure(int chunkIndex, Throwable e) { + res.failedChunks.add(chunkIndex); + sem.release(); + } + }; + + for (int chunkIndex : chunkIndices) { + client.fetchChunk(STREAM_ID, chunkIndex, callback); + } + if (!sem.tryAcquire(chunkIndices.size(), 5, TimeUnit.SECONDS)) { + fail("Timeout getting response from the server"); + } + client.close(); + return res; + } + + @Test + public void fetchBufferChunk() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX)); + assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX)); + assertTrue(res.failedChunks.isEmpty()); + assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk)); + res.releaseBuffers(); + } + + @Test + public void fetchFileChunk() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(FILE_CHUNK_INDEX)); + assertEquals(res.successChunks, Sets.newHashSet(FILE_CHUNK_INDEX)); + assertTrue(res.failedChunks.isEmpty()); + assertBufferListsEqual(res.buffers, Lists.newArrayList(fileChunk)); + res.releaseBuffers(); + } + + @Test + public void fetchBothChunks() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX)); + assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX)); + assertTrue(res.failedChunks.isEmpty()); + assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk, fileChunk)); + res.releaseBuffers(); + } + + @Test + public void fetchNonExistentChunk() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(12345)); + assertTrue(res.successChunks.isEmpty()); + assertEquals(res.failedChunks, Sets.newHashSet(12345)); + assertTrue(res.buffers.isEmpty()); + } + + @Test + public void fetchChunkAndNonExistent() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, 12345)); + assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX)); + assertEquals(res.failedChunks, Sets.newHashSet(12345)); + assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk)); + res.releaseBuffers(); + } + + private void assertBufferListsEqual(List list0, List list1) + throws Exception { + assertEquals(list0.size(), list1.size()); + for (int i = 0; i < list0.size(); i ++) { + assertBuffersEqual(list0.get(i), list1.get(i)); + } + } + + private void assertBuffersEqual(ManagedBuffer buffer0, ManagedBuffer buffer1) throws Exception { + ByteBuffer nio0 = buffer0.nioByteBuffer(); + ByteBuffer nio1 = buffer1.nioByteBuffer(); + + int len = nio0.remaining(); + assertEquals(nio0.remaining(), nio1.remaining()); + for (int i = 0; i < len; i ++) { + assertEquals(nio0.get(), nio1.get()); + } + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java b/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java new file mode 100644 index 0000000000000..af35709319957 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java @@ -0,0 +1,26 @@ +package org.apache.spark.network;/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.server.RpcHandler; + +public class NoOpRpcHandler implements RpcHandler { + @Override + public void receive(byte[] message, RpcResponseCallback callback) { + callback.onSuccess(new byte[0]); + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java new file mode 100644 index 0000000000000..cf74a9d8993fe --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network; + +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.protocol.request.ChunkFetchRequest; +import org.apache.spark.network.protocol.request.ClientRequest; +import org.apache.spark.network.protocol.request.ClientRequestDecoder; +import org.apache.spark.network.protocol.request.ClientRequestEncoder; +import org.apache.spark.network.protocol.response.ChunkFetchFailure; +import org.apache.spark.network.protocol.response.ChunkFetchSuccess; +import org.apache.spark.network.protocol.response.ServerResponse; +import org.apache.spark.network.protocol.response.ServerResponseDecoder; +import org.apache.spark.network.protocol.response.ServerResponseEncoder; +import org.apache.spark.network.util.NettyUtils; + +public class ProtocolSuite { + private void testServerToClient(ServerResponse msg) { + EmbeddedChannel serverChannel = new EmbeddedChannel(new ServerResponseEncoder()); + serverChannel.writeOutbound(msg); + + EmbeddedChannel clientChannel = new EmbeddedChannel( + NettyUtils.createFrameDecoder(), new ServerResponseDecoder()); + + while (!serverChannel.outboundMessages().isEmpty()) { + clientChannel.writeInbound(serverChannel.readOutbound()); + } + + assertEquals(1, clientChannel.inboundMessages().size()); + assertEquals(msg, clientChannel.readInbound()); + } + + private void testClientToServer(ClientRequest msg) { + EmbeddedChannel clientChannel = new EmbeddedChannel(new ClientRequestEncoder()); + clientChannel.writeOutbound(msg); + + EmbeddedChannel serverChannel = new EmbeddedChannel( + NettyUtils.createFrameDecoder(), new ClientRequestDecoder()); + + while (!clientChannel.outboundMessages().isEmpty()) { + serverChannel.writeInbound(clientChannel.readOutbound()); + } + + assertEquals(1, serverChannel.inboundMessages().size()); + assertEquals(msg, serverChannel.readInbound()); + } + + @Test + public void s2cChunkFetchSuccess() { + testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(10))); + testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0))); + } + + @Test + public void s2cBlockFetchFailure() { + testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "this is an error")); + testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "")); + } + + @Test + public void c2sChunkFetchRequest() { + testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2))); + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/SluiceClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/SluiceClientFactorySuite.java new file mode 100644 index 0000000000000..e6b59b9ad8e5c --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/SluiceClientFactorySuite.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import java.util.concurrent.TimeoutException; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import org.apache.spark.network.client.SluiceClient; +import org.apache.spark.network.client.SluiceClientFactory; +import org.apache.spark.network.server.DefaultStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.SluiceServer; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.util.DefaultConfigProvider; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.SluiceConfig; + +public class SluiceClientFactorySuite { + private SluiceConfig conf; + private SluiceServer server1; + private SluiceServer server2; + + @Before + public void setUp() { + conf = new SluiceConfig(new DefaultConfigProvider()); + StreamManager streamManager = new DefaultStreamManager(); + RpcHandler rpcHandler = new NoOpRpcHandler(); + server1 = new SluiceServer(conf, streamManager, rpcHandler); + server2 = new SluiceServer(conf, streamManager, rpcHandler); + } + + @After + public void tearDown() { + JavaUtils.closeQuietly(server1); + JavaUtils.closeQuietly(server2); + } + + @Test + public void createAndReuseBlockClients() throws TimeoutException { + SluiceClientFactory factory = new SluiceClientFactory(conf); + SluiceClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + SluiceClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + SluiceClient c3 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); + assertTrue(c1.isActive()); + assertTrue(c3.isActive()); + assertTrue(c1 == c2); + assertTrue(c1 != c3); + factory.close(); + } + + @Test + public void neverReturnInactiveClients() throws Exception { + SluiceClientFactory factory = new SluiceClientFactory(conf); + SluiceClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + c1.close(); + + long start = System.currentTimeMillis(); + while (c1.isActive() && (System.currentTimeMillis() - start) < 3000) { + Thread.sleep(10); + } + assertFalse(c1.isActive()); + + SluiceClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + assertFalse(c1 == c2); + assertTrue(c2.isActive()); + factory.close(); + } + + @Test + public void closeBlockClientsWithFactory() throws TimeoutException { + SluiceClientFactory factory = new SluiceClientFactory(conf); + SluiceClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + SluiceClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); + assertTrue(c1.isActive()); + assertTrue(c2.isActive()); + factory.close(); + assertFalse(c1.isActive()); + assertFalse(c2.isActive()); + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/SluiceClientHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/SluiceClientHandlerSuite.java new file mode 100644 index 
0000000000000..cab0597fb948a --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/SluiceClientHandlerSuite.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.*; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.SluiceClientHandler; +import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.protocol.response.ChunkFetchFailure; +import org.apache.spark.network.protocol.response.ChunkFetchSuccess; + +public class SluiceClientHandlerSuite { + @Test + public void handleSuccessfulFetch() { + StreamChunkId streamChunkId = new StreamChunkId(1, 0); + + SluiceClientHandler handler = new SluiceClientHandler(); + ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + handler.addFetchRequest(streamChunkId, callback); + assertEquals(1, handler.numOutstandingRequests()); + + EmbeddedChannel channel = new EmbeddedChannel(handler); + + channel.writeInbound(new ChunkFetchSuccess(streamChunkId, new TestManagedBuffer(123))); + verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any()); + assertEquals(0, handler.numOutstandingRequests()); + assertFalse(channel.finish()); + } + + @Test + public void handleFailedFetch() { + StreamChunkId streamChunkId = new StreamChunkId(1, 0); + SluiceClientHandler handler = new SluiceClientHandler(); + ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + handler.addFetchRequest(streamChunkId, callback); + assertEquals(1, handler.numOutstandingRequests()); + + EmbeddedChannel channel = new EmbeddedChannel(handler); + channel.writeInbound(new ChunkFetchFailure(streamChunkId, "some error msg")); + verify(callback, times(1)).onFailure(eq(0), (Throwable) any()); + assertEquals(0, handler.numOutstandingRequests()); + assertFalse(channel.finish()); + } + + @Test + public void clearAllOutstandingRequests() { + SluiceClientHandler handler = new SluiceClientHandler(); + ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + handler.addFetchRequest(new StreamChunkId(1, 0), callback); + handler.addFetchRequest(new StreamChunkId(1, 1), callback); + handler.addFetchRequest(new StreamChunkId(1, 2), callback); + assertEquals(3, handler.numOutstandingRequests()); + + EmbeddedChannel channel = new EmbeddedChannel(handler); + + channel.writeInbound(new ChunkFetchSuccess(new StreamChunkId(1, 0), new 
TestManagedBuffer(12))); + channel.pipeline().fireExceptionCaught(new Exception("duh duh duhhhh")); + + // should fail both b2 and b3 + verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any()); + verify(callback, times(1)).onFailure(eq(1), (Throwable) any()); + verify(callback, times(1)).onFailure(eq(2), (Throwable) any()); + assertEquals(0, handler.numOutstandingRequests()); + assertFalse(channel.finish()); + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java new file mode 100644 index 0000000000000..7e7554af70f42 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; + +import com.google.common.base.Preconditions; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; + +/** + * A ManagedBuffer implementation that contains 0, 1, 2, 3, ..., (len-1). + * + * Used for testing. 
+ */ +public class TestManagedBuffer extends ManagedBuffer { + + private final int len; + private NettyManagedBuffer underlying; + + public TestManagedBuffer(int len) { + Preconditions.checkArgument(len <= Byte.MAX_VALUE); + this.len = len; + byte[] byteArray = new byte[len]; + for (int i = 0; i < len; i ++) { + byteArray[i] = (byte) i; + } + this.underlying = new NettyManagedBuffer(Unpooled.wrappedBuffer(byteArray)); + } + + + @Override + public long size() { + return underlying.size(); + } + + @Override + public ByteBuffer nioByteBuffer() throws IOException { + return underlying.nioByteBuffer(); + } + + @Override + public InputStream inputStream() throws IOException { + return underlying.inputStream(); + } + + @Override + public ManagedBuffer retain() { + underlying.retain(); + return this; + } + + @Override + public ManagedBuffer release() { + underlying.release(); + return this; + } + + @Override + public Object convertToNetty() throws IOException { + return underlying.convertToNetty(); + } + + @Override + public boolean equals(Object other) { + if (other instanceof ManagedBuffer) { + try { + ByteBuffer nioBuf = ((ManagedBuffer) other).nioByteBuffer(); + if (nioBuf.remaining() != len) { + return false; + } else { + for (int i = 0; i < len; i ++) { + if (nioBuf.get() != i) { + return false; + } + } + return true; + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + return false; + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/TestUtils.java b/network/common/src/test/java/org/apache/spark/network/TestUtils.java new file mode 100644 index 0000000000000..56a2b805f154c --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/TestUtils.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network; + +import java.net.InetAddress; + +public class TestUtils { + public static String getLocalHost() { + try { + return InetAddress.getLocalHost().getHostAddress(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/pom.xml b/pom.xml index 7756c89b00cad..b0d39cfec1e8d 100644 --- a/pom.xml +++ b/pom.xml @@ -91,6 +91,7 @@ graphx mllib tools + network/common streaming sql/catalyst sql/core diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 8a1b2d3b91327..71041e7fe1a14 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -51,8 +51,6 @@ object MimaExcludes { // MapStatus should be private[spark] ProblemFilters.exclude[IncompatibleTemplateDefProblem]( "org.apache.spark.scheduler.MapStatus"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.foreachAsync"), ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.network.netty.PathResolver"), ProblemFilters.exclude[MissingClassProblem]( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 7149dbc12a365..190373e0cb5f2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -122,7 +122,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { sender: ActorRef ) { if (!receiverInputStreamMap.contains(streamId)) { - throw new Exception("Register received for unexpected id " + streamId) + throw new Exception("Register received for unexpected type " + streamId) } receiverInfo(streamId) = ReceiverInfo( streamId, s"${typ}-${streamId}", receiverActor, true, host) From ae4083aafd762a71dc52f0963efefeaf215040a2 Mon Sep 17 00:00:00 2001 From: Anand Avati Date: Fri, 10 Oct 2014 00:46:56 -0700 Subject: [PATCH 29/46] [SPARK-2805] Upgrade Akka to 2.3.4 This is a second rev of the Akka upgrade (earlier merged, but reverted). I made a slight modification which is that I also upgrade Hive to deal with a compatibility issue related to the protocol buffers library. 
Author: Anand Avati Author: Patrick Wendell Closes #2752 from pwendell/akka-upgrade and squashes the following commits: 4c7ca3f [Patrick Wendell] Upgrading to new hive->protobuf version 57a2315 [Anand Avati] SPARK-1812: streaming - remove tests which depend on akka.actor.IO 2a551d3 [Anand Avati] SPARK-1812: core - upgrade to akka 2.3.4 --- .../org/apache/spark/deploy/Client.scala | 2 +- .../spark/deploy/client/AppClient.scala | 2 +- .../spark/deploy/worker/WorkerWatcher.scala | 2 +- .../apache/spark/MapOutputTrackerSuite.scala | 4 +- pom.xml | 4 +- .../spark/streaming/InputStreamsSuite.scala | 71 ------------------- 6 files changed, 7 insertions(+), 78 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 065ddda50e65e..f2687ce6b42b4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -130,7 +130,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") System.exit(-1) - case AssociationErrorEvent(cause, _, remoteAddress, _) => + case AssociationErrorEvent(cause, _, remoteAddress, _, _) => println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") println(s"Cause was: $cause") System.exit(-1) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 32790053a6be8..98a93d1fcb2a3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -154,7 +154,7 @@ private[spark] class AppClient( logWarning(s"Connection to $address failed; waiting for master to reconnect...") markDisconnected() - case AssociationErrorEvent(cause, _, address, _) if isPossibleMaster(address) => + case AssociationErrorEvent(cause, _, address, _, _) if isPossibleMaster(address) => logWarning(s"Could not connect to $address: $cause") case StopAppClient => diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 6d0d0bbe5ecec..63a8ac817b618 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -54,7 +54,7 @@ private[spark] class WorkerWatcher(workerUrl: String) case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => logInfo(s"Successfully connected to $workerUrl") - case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound) + case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) if isWorker(remoteAddress) => // These logs may not be seen if the worker (and associated pipe) has died logError(s"Could not initialize connection to worker $workerUrl. 
Exiting.") diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 1fef79ad1001f..cbc0bd178d894 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -146,7 +146,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val masterTracker = new MapOutputTrackerMaster(conf) val actorSystem = ActorSystem("test") val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem) + Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) val masterActor = actorRef.underlyingActor // Frame size should be ~123B, and no exception should be thrown @@ -164,7 +164,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val masterTracker = new MapOutputTrackerMaster(conf) val actorSystem = ActorSystem("test") val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem) + Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) val masterActor = actorRef.underlyingActor // Frame size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception. diff --git a/pom.xml b/pom.xml index b0d39cfec1e8d..f72baeb0c6dd1 100644 --- a/pom.xml +++ b/pom.xml @@ -119,7 +119,7 @@ 0.18.1 shaded-protobuf org.spark-project.akka - 2.2.3-shaded-protobuf + 2.3.4-spark 1.7.5 1.2.17 1.0.4 @@ -128,7 +128,7 @@ 0.94.6 1.4.0 3.4.5 - 0.12.0 + 0.12.0-protobuf 1.4.3 1.2.3 8.1.14.v20131031 diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index a44a45a3e9bd6..fa04fa326e370 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -18,8 +18,6 @@ package org.apache.spark.streaming import akka.actor.Actor -import akka.actor.IO -import akka.actor.IOManager import akka.actor.Props import akka.util.ByteString @@ -143,59 +141,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") } - // TODO: This test works in IntelliJ but not through SBT - ignore("actor input stream") { - // Start the server - val testServer = new TestServer() - val port = testServer.port - testServer.start() - - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val networkStream = ssc.actorStream[String](Props(new TestActor(port)), "TestActor", - // Had to pass the local value of port to prevent from closing over entire scope - StorageLevel.MEMORY_AND_DISK) - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(networkStream, outputBuffer) - def output = outputBuffer.flatMap(x => x) - outputStream.register() - ssc.start() - - // Feed data to the server to send to the network receiver - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val input = 1 to 9 - val expectedOutput = input.map(x => x.toString) - Thread.sleep(1000) - for (i <- 0 until input.size) { - testServer.send(input(i).toString) - Thread.sleep(500) - clock.addToTime(batchDuration.milliseconds) - } - Thread.sleep(1000) - logInfo("Stopping 
server") - testServer.stop() - logInfo("Stopping context") - ssc.stop() - - // Verify whether data received was as expected - logInfo("--------------------------------") - logInfo("output.size = " + outputBuffer.size) - logInfo("output") - outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("expected output.size = " + expectedOutput.size) - logInfo("expected output") - expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("--------------------------------") - - // Verify whether all the elements received are as expected - // (whether the elements were received one in each interval is not verified) - assert(output.size === expectedOutput.size) - for (i <- 0 until output.size) { - assert(output(i) === expectedOutput(i)) - } - } - - test("multi-thread receiver") { // set up the test receiver val numThreads = 10 @@ -377,22 +322,6 @@ class TestServer(portToBind: Int = 0) extends Logging { def port = serverSocket.getLocalPort } -/** This is an actor for testing actor input stream */ -class TestActor(port: Int) extends Actor with ActorHelper { - - def bytesToString(byteString: ByteString) = byteString.utf8String - - override def preStart(): Unit = { - @deprecated("suppress compile time deprecation warning", "1.0.0") - val unit = IOManager(context.system).connect(new InetSocketAddress(port)) - } - - def receive = { - case IO.Read(socket, bytes) => - store(bytesToString(bytes)) - } -} - /** This is a receiver to test multiple threads inserting data using block generator */ class MultiThreadTestReceiver(numThreads: Int, numRecordsPerThread: Int) extends Receiver[Int](StorageLevel.MEMORY_ONLY_SER) with Logging { From 020691e28cb22cf356d44bb801a80aca74495778 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 10 Oct 2014 14:14:05 -0700 Subject: [PATCH 30/46] [SPARK-3886] [PySpark] use AutoBatchedSerializer by default Use AutoBatchedSerializer by default, which will choose the proper batch size based on size of serialized objects, let the size of serialized batch fall in into [64k - 640k]. In JVM, the serializer will also track the objects in batch to figure out duplicated objects, larger batch may cause OOM in JVM. Author: Davies Liu Closes #2740 from davies/batchsize and squashes the following commits: 52cdb88 [Davies Liu] update docs 185f2b9 [Davies Liu] use AutoBatchedSerializer by default --- python/pyspark/context.py | 11 +++++++---- python/pyspark/serializers.py | 4 ++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 6fb30d65c5edd..85c04624da4a6 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -29,7 +29,7 @@ from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ - PairDeserializer, CompressedSerializer + PairDeserializer, CompressedSerializer, AutoBatchedSerializer from pyspark.storagelevel import StorageLevel from pyspark.rdd import RDD from pyspark.traceback_utils import CallSite, first_spark_call @@ -67,7 +67,7 @@ class SparkContext(object): _default_batch_size_for_serialized_input = 10 def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, - environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None, + environment=None, batchSize=0, serializer=PickleSerializer(), conf=None, gateway=None): """ Create a new SparkContext. 
At least the master and app name should be set, @@ -83,8 +83,9 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, :param environment: A dictionary of environment variables to set on worker nodes. :param batchSize: The number of Python objects represented as a single - Java object. Set 1 to disable batching or -1 to use an - unlimited batch size. + Java object. Set 1 to disable batching, 0 to automatically choose + the batch size based on object sizes, or -1 to use an unlimited + batch size :param serializer: The serializer for RDDs. :param conf: A L{SparkConf} object setting Spark properties. :param gateway: Use an existing gateway and JVM, otherwise a new JVM @@ -117,6 +118,8 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self._unbatched_serializer = serializer if batchSize == 1: self.serializer = self._unbatched_serializer + elif batchSize == 0: + self.serializer = AutoBatchedSerializer(self._unbatched_serializer) else: self.serializer = BatchedSerializer(self._unbatched_serializer, batchSize) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 099fa54cf2bd7..3d1a34b281acc 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -220,7 +220,7 @@ class AutoBatchedSerializer(BatchedSerializer): Choose the size of batch automatically based on the size of object """ - def __init__(self, serializer, bestSize=1 << 20): + def __init__(self, serializer, bestSize=1 << 16): BatchedSerializer.__init__(self, serializer, -1) self.bestSize = bestSize @@ -247,7 +247,7 @@ def __eq__(self, other): other.serializer == self.serializer) def __str__(self): - return "BatchedSerializer<%s>" % str(self.serializer) + return "AutoBatchedSerializer<%s>" % str(self.serializer) class CartesianDeserializer(FramedSerializer): From 2c5d9dc7712579410e65d76be93bb980c6ec2fd0 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Fri, 10 Oct 2014 16:49:19 -0700 Subject: [PATCH 31/46] HOTFIX: Fix build issue with Akka 2.3.4 upgrade. We had to upgrade our Hive 0.12 version as well to deal with a protobuf conflict (both hive and akka have been using a shaded protobuf version). This is testing a correctly patched version of Hive 0.12. Author: Patrick Wendell Closes #2756 from pwendell/hotfix and squashes the following commits: cc979d0 [Patrick Wendell] HOTFIX: Fix build issue with Akka 2.3.4 upgrade. --- pom.xml | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index f72baeb0c6dd1..061585cd012ed 100644 --- a/pom.xml +++ b/pom.xml @@ -128,7 +128,7 @@ 0.94.6 1.4.0 3.4.5 - 0.12.0-protobuf + 0.12.0-protobuf-2.5 1.4.3 1.2.3 8.1.14.v20131031 @@ -224,6 +224,18 @@ false + + + spark-staging + Spring Staging Repository + https://oss.sonatype.org/content/repositories/orgspark-project-1085 + + true + + + false + + From 5b5dbe62576df2964039874c24173cc244b38c5f Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Fri, 10 Oct 2014 18:39:55 -0700 Subject: [PATCH 32/46] [SPARK-2924] Required by scala 2.11, only one fun/ctor amongst overriden alternatives, can have default argument(s). ...riden alternatives, can have default argument. Author: Prashant Sharma Closes #2750 from ScrapCodes/SPARK-2924/default-args-removed and squashes the following commits: d9785c3 [Prashant Sharma] [SPARK-2924] Required by scala 2.11, only one function/ctor amongst overriden alternatives, can have default argument. 
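A minimal sketch of the restriction this commit works around, using illustrative names only (EventLogger and its parameters are not taken from the patch): as the commit message states, Scala 2.11 refuses to compile a class in which more than one overloaded alternative of the same constructor or method declares default arguments, so the fix keeps defaults on a single alternative and expands the remaining combinations into explicit overloads, the same shape as the FileLogger change below.

```scala
// Hypothetical reproduction of the Scala 2.11 restriction; names are illustrative only.
class EventLogger(
    logDir: String,
    bufferSize: Int = 8 * 1024,
    compress: Boolean = false,
    overwrite: Boolean = true) {

  // Rejected under Scala 2.11: the primary constructor above already declares defaults,
  // and at most one overloaded alternative may declare default arguments.
  // def this(logDir: String, compress: Boolean = false, overwrite: Boolean = true) =
  //   this(logDir, 8 * 1024, compress, overwrite)

  // Workaround: explicit overloads that declare no defaults of their own and
  // delegate to the single alternative that keeps them.
  def this(logDir: String, compress: Boolean, overwrite: Boolean) =
    this(logDir, 8 * 1024, compress, overwrite)

  def this(logDir: String, compress: Boolean) =
    this(logDir, 8 * 1024, compress, overwrite = true)
}
```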
--- .../org/apache/spark/util/FileLogger.scala | 19 +++++++++++++++++-- .../apache/spark/util/FileLoggerSuite.scala | 8 ++++---- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/FileLogger.scala b/core/src/main/scala/org/apache/spark/util/FileLogger.scala index 6d1fc05a15d2c..fdc73f08261a6 100644 --- a/core/src/main/scala/org/apache/spark/util/FileLogger.scala +++ b/core/src/main/scala/org/apache/spark/util/FileLogger.scala @@ -51,12 +51,27 @@ private[spark] class FileLogger( def this( logDir: String, sparkConf: SparkConf, - compress: Boolean = false, - overwrite: Boolean = true) = { + compress: Boolean, + overwrite: Boolean) = { this(logDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf), compress = compress, overwrite = overwrite) } + def this( + logDir: String, + sparkConf: SparkConf, + compress: Boolean) = { + this(logDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf), compress = compress, + overwrite = true) + } + + def this( + logDir: String, + sparkConf: SparkConf) = { + this(logDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf), compress = false, + overwrite = true) + } + private val dateFormat = new ThreadLocal[SimpleDateFormat]() { override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") } diff --git a/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala b/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala index dc2a05631d83d..72466a3aa1130 100644 --- a/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala @@ -74,13 +74,13 @@ class FileLoggerSuite extends FunSuite with BeforeAndAfter { test("Logging when directory already exists") { // Create the logging directory multiple times - new FileLogger(logDirPathString, new SparkConf, overwrite = true).start() - new FileLogger(logDirPathString, new SparkConf, overwrite = true).start() - new FileLogger(logDirPathString, new SparkConf, overwrite = true).start() + new FileLogger(logDirPathString, new SparkConf, compress = false, overwrite = true).start() + new FileLogger(logDirPathString, new SparkConf, compress = false, overwrite = true).start() + new FileLogger(logDirPathString, new SparkConf, compress = false, overwrite = true).start() // If overwrite is not enabled, an exception should be thrown intercept[IOException] { - new FileLogger(logDirPathString, new SparkConf, overwrite = false).start() + new FileLogger(logDirPathString, new SparkConf, compress = false, overwrite = false).start() } } From 8dc1ded0cec5952bf73ee58e2bbd16d9479dbdcc Mon Sep 17 00:00:00 2001 From: cocoatomo Date: Sat, 11 Oct 2014 11:26:17 -0700 Subject: [PATCH 33/46] [SPARK-3867][PySpark] ./python/run-tests failed when it run with Python 2.6 and unittest2 is not installed ./python/run-tests search a Python 2.6 executable on PATH and use it if available. When using Python 2.6, it is going to import unittest2 module which is not a standard library in Python 2.6, so it fails with ImportError. 
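The change below guards the unittest2 import with the usual optional-dependency fallback; the fragment here only restates that pattern in one readable block for context, and introduces nothing beyond what the diff itself adds.

```python
import sys

# On Python <= 2.6 the standard-library unittest lacks the newer 2.7 features,
# so the tests rely on the unittest2 backport; fail with a clear message if it
# is not installed instead of crashing with an ImportError.
if sys.version_info[:2] <= (2, 6):
    try:
        import unittest2 as unittest
    except ImportError:
        sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
        sys.exit(1)
else:
    import unittest
```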
Author: cocoatomo Closes #2759 from cocoatomo/issues/3867-unittest2-import-error and squashes the following commits: f068eb5 [cocoatomo] [SPARK-3867] ./python/run-tests failed when it run with Python 2.6 and unittest2 is not installed --- python/pyspark/mllib/tests.py | 6 +++++- python/pyspark/tests.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 5c20e100e144f..463faf7b6f520 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -25,7 +25,11 @@ from numpy import array, array_equal if sys.version_info[:2] <= (2, 6): - import unittest2 as unittest + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) else: import unittest diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 7f05d48ade2b3..ceab57464f013 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -34,7 +34,11 @@ from platform import python_implementation if sys.version_info[:2] <= (2, 6): - import unittest2 as unittest + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) else: import unittest From aa58f67bb6973776f2d3e2a9b99b2e1428548649 Mon Sep 17 00:00:00 2001 From: cocoatomo Date: Sat, 11 Oct 2014 11:51:59 -0700 Subject: [PATCH 34/46] [SPARK-3909][PySpark][Doc] A corrupted format in Sphinx documents and building warnings Sphinx documents contains a corrupted ReST format and have some warnings. The purpose of this issue is same as https://issues.apache.org/jira/browse/SPARK-3773. commit: 0e8203f4fb721158fb27897680da476174d24c4b output ``` $ cd ./python/docs $ make clean html rm -rf _build/* sphinx-build -b html -d _build/doctrees . _build/html Making output directory... Running Sphinx v1.2.3 loading pickled environment... not yet created building [html]: targets for 4 source files that are out of date updating environment: 4 added, 0 changed, 0 removed reading sources... [100%] pyspark.sql /Users//MyRepos/Scala/spark/python/pyspark/mllib/feature.py:docstring of pyspark.mllib.feature.Word2VecModel.findSynonyms:4: WARNING: Field list ends without a blank line; unexpected unindent. /Users//MyRepos/Scala/spark/python/pyspark/mllib/feature.py:docstring of pyspark.mllib.feature.Word2VecModel.transform:3: WARNING: Field list ends without a blank line; unexpected unindent. /Users//MyRepos/Scala/spark/python/pyspark/sql.py:docstring of pyspark.sql:4: WARNING: Bullet list ends without a blank line; unexpected unindent. looking for now-outdated files... none found pickling environment... done checking consistency... done preparing documents... done writing output... [100%] pyspark.sql writing additional files... (12 module code pages) _modules/index search copying static files... WARNING: html_static_path entry u'/Users//MyRepos/Scala/spark/python/docs/_static' does not exist done copying extra files... done dumping search index... done dumping object inventory... done build succeeded, 4 warnings. Build finished. The HTML pages are in _build/html. 
``` Author: cocoatomo Closes #2766 from cocoatomo/issues/3909-sphinx-build-warnings and squashes the following commits: 2c7faa8 [cocoatomo] [SPARK-3909][PySpark][Doc] A corrupted format in Sphinx documents and building warnings --- python/docs/conf.py | 2 +- python/pyspark/mllib/feature.py | 2 ++ python/pyspark/rdd.py | 2 +- python/pyspark/sql.py | 10 +++++----- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/python/docs/conf.py b/python/docs/conf.py index 8e6324f058251..e58d97ae6a746 100644 --- a/python/docs/conf.py +++ b/python/docs/conf.py @@ -131,7 +131,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +#html_static_path = ['_static'] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index a44a27fd3b6a6..f4cbf31b94fe2 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -44,6 +44,7 @@ def transform(self, word): """ :param word: a word :return: vector representation of word + Transforms a word to its vector representation Note: local use only @@ -57,6 +58,7 @@ def findSynonyms(self, x, num): :param x: a word or a vector representation of word :param num: number of synonyms to find :return: array of (word, cosineSimilarity) + Find synonyms of a word Note: local use only diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 6797d50659a92..e13bab946c44a 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2009,7 +2009,7 @@ def countApproxDistinct(self, relativeSD=0.05): of The Art Cardinality Estimation Algorithm", available here. - :param relativeSD Relative accuracy. Smaller values create + :param relativeSD: Relative accuracy. Smaller values create counters that require more space. It must be greater than 0.000017. diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index d3d36eb995ab6..b31a82f9b19ac 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -19,14 +19,14 @@ public classes of Spark SQL: - L{SQLContext} - Main entry point for SQL functionality. + Main entry point for SQL functionality. - L{SchemaRDD} - A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In - addition to normal RDD operations, SchemaRDDs also support SQL. + A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In + addition to normal RDD operations, SchemaRDDs also support SQL. - L{Row} - A Row of data returned by a Spark SQL query. + A Row of data returned by a Spark SQL query. - L{HiveContext} - Main entry point for accessing data stored in Apache Hive.. + Main entry point for accessing data stored in Apache Hive.. """ import itertools From 939f276fe42feb3c333e39f371e9c6400fe22ddc Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Sun, 12 Oct 2014 16:45:55 -0700 Subject: [PATCH 35/46] Attempt to make comm. 
bidirectional --- .../spark/network/BlockFetchingListener.scala | 4 +- .../spark/network/BlockTransferService.scala | 15 ++- .../network/netty/NettyBlockFetcher.scala | 6 +- .../network/netty/NettyBlockRpcServer.scala | 33 ++++-- .../netty/NettyBlockTransferService.scala | 56 ++++++--- .../network/nio/NioBlockTransferService.scala | 4 +- .../apache/spark/serializer/Serializer.scala | 51 +++++++- .../apache/spark/storage/BlockManager.scala | 6 +- .../org/apache/spark/util/AkkaUtils.scala | 2 +- .../org/apache/spark/ShuffleNettySuite.scala | 4 +- .../apache/spark/network/SluiceContext.java | 111 ++++++++++++++++++ .../spark/network/client/SluiceClient.java | 50 ++++---- .../network/client/SluiceClientFactory.java | 55 +++++---- ...andler.java => SluiceResponseHandler.java} | 55 +++++---- .../ClientRequest.java => Message.java} | 26 ++-- .../protocol/request/ChunkFetchRequest.java | 4 +- .../request/ClientRequestDecoder.java | 57 --------- .../request/ClientRequestEncoder.java | 46 -------- .../protocol/request/RequestMessage.java | 25 ++++ .../network/protocol/request/RpcRequest.java | 6 +- .../protocol/response/ChunkFetchFailure.java | 7 +- .../protocol/response/ChunkFetchSuccess.java | 4 +- ...sponseDecoder.java => MessageDecoder.java} | 22 +++- ...sponseEncoder.java => MessageEncoder.java} | 16 ++- .../protocol/response/ResponseMessage.java | 25 ++++ .../network/protocol/response/RpcFailure.java | 7 +- .../protocol/response/RpcResponse.java | 2 +- .../protocol/response/ServerResponse.java | 63 ---------- .../network/server/DefaultStreamManager.java | 14 ++- .../spark/network/server/MessageHandler.java | 36 ++++++ .../spark/network/server/RpcHandler.java | 9 +- .../network/server/SluiceChannelHandler.java | 88 ++++++++++++++ ...Handler.java => SluiceRequestHandler.java} | 71 ++++++----- .../spark/network/server/SluiceServer.java | 26 ++-- .../spark/network/server/StreamManager.java | 2 +- .../org/apache/spark/network/util/IOMode.java | 2 +- .../apache/spark/network/util/JavaUtils.java | 14 ++- .../apache/spark/network/util/NettyUtils.java | 21 ++-- .../spark/network/IntegrationSuite.java | 22 ++-- .../apache/spark/network/NoOpRpcHandler.java | 3 +- .../apache/spark/network/ProtocolSuite.java | 21 ++-- .../network/SluiceClientFactorySuite.java | 12 +- .../network/SluiceClientHandlerSuite.java | 26 ++-- 43 files changed, 702 insertions(+), 427 deletions(-) create mode 100644 network/common/src/main/java/org/apache/spark/network/SluiceContext.java rename network/common/src/main/java/org/apache/spark/network/client/{SluiceClientHandler.java => SluiceResponseHandler.java} (75%) rename network/common/src/main/java/org/apache/spark/network/protocol/{request/ClientRequest.java => Message.java} (67%) delete mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestDecoder.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestEncoder.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/request/RequestMessage.java rename network/common/src/main/java/org/apache/spark/network/protocol/response/{ServerResponseDecoder.java => MessageDecoder.java} (70%) rename network/common/src/main/java/org/apache/spark/network/protocol/response/{ServerResponseEncoder.java => MessageEncoder.java} (78%) create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/response/ResponseMessage.java delete mode 100644 
network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponse.java create mode 100644 network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java create mode 100644 network/common/src/main/java/org/apache/spark/network/server/SluiceChannelHandler.java rename network/common/src/main/java/org/apache/spark/network/server/{SluiceServerHandler.java => SluiceRequestHandler.java} (65%) diff --git a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala index e35fdb4e95899..645793fde806d 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala @@ -29,7 +29,9 @@ private[spark] trait BlockFetchingListener extends EventListener { /** - * Called once per successfully fetched block. + * Called once per successfully fetched block. After this call returns, data will be released + * automatically. If the data will be passed to another thread, the receiver should retain() + * and release() the buffer on their own, or copy the data to a new buffer. */ def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index 8287a0fc81cfe..b083f465334fe 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -18,14 +18,14 @@ package org.apache.spark.network import java.io.Closeable - -import org.apache.spark.network.buffer.ManagedBuffer +import java.nio.ByteBuffer import scala.concurrent.{Await, Future} import scala.concurrent.duration.Duration import org.apache.spark.Logging -import org.apache.spark.storage.StorageLevel +import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer} +import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.Utils private[spark] @@ -72,7 +72,7 @@ abstract class BlockTransferService extends Closeable with Logging { def uploadBlock( hostname: String, port: Int, - blockId: String, + blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel): Future[Unit] @@ -94,7 +94,10 @@ abstract class BlockTransferService extends Closeable with Logging { } override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { lock.synchronized { - result = Left(data) + val ret = ByteBuffer.allocate(data.size.toInt) + ret.put(data.nioByteBuffer()) + ret.flip() + result = Left(new NioManagedBuffer(ret)) lock.notify() } } @@ -126,7 +129,7 @@ abstract class BlockTransferService extends Closeable with Logging { def uploadBlockSync( hostname: String, port: Int, - blockId: String, + blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel): Unit = { Await.result(uploadBlock(hostname, port, blockId, blockData, level), Duration.Inf) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala index aefd8a6335b2a..a03e7c39428ee 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala @@ -20,9 +20,10 @@ package org.apache.spark.network.netty import java.nio.ByteBuffer import java.util -import org.apache.spark.Logging +import 
org.apache.spark.{SparkConf, Logging} import org.apache.spark.network.BlockFetchingListener -import org.apache.spark.serializer.Serializer +import org.apache.spark.network.netty.NettyMessages._ +import org.apache.spark.serializer.{JavaSerializer, Serializer} import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.{RpcResponseCallback, ChunkReceivedCallback, SluiceClient} import org.apache.spark.storage.BlockId @@ -52,7 +53,6 @@ class NettyBlockFetcher( val chunkCallback = new ChunkReceivedCallback { // On receipt of a chunk, pass it upwards as a block. def onSuccess(chunkIndex: Int, buffer: ManagedBuffer): Unit = Utils.logUncaughtExceptions { - buffer.retain() listener.onBlockFetchSuccess(blockIds(chunkIndex), buffer) } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index c8658ec98b82c..9206237256e0b 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -22,18 +22,24 @@ import java.nio.ByteBuffer import org.apache.spark.Logging import org.apache.spark.network.BlockDataManager import org.apache.spark.serializer.Serializer -import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.network.client.RpcResponseCallback +import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer} +import org.apache.spark.network.client.{SluiceClient, RpcResponseCallback} import org.apache.spark.network.server.{DefaultStreamManager, RpcHandler} -import org.apache.spark.storage.BlockId +import org.apache.spark.storage.{StorageLevel, BlockId} import scala.collection.JavaConversions._ -/** Request to read a set of blocks. Returns [[ShuffleStreamHandle]] to identify the stream. */ -case class OpenBlocks(blockIds: Seq[BlockId]) +object NettyMessages { -/** Identifier for a fixed number of chunks to read from a stream created by [[OpenBlocks]]. */ -case class ShuffleStreamHandle(streamId: Long, numChunks: Int) + /** Request to read a set of blocks. Returns [[ShuffleStreamHandle]] to identify the stream. */ + case class OpenBlocks(blockIds: Seq[BlockId]) + + /** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */ + case class UploadBlock(blockId: BlockId, blockData: Array[Byte], level: StorageLevel) + + /** Identifier for a fixed number of chunks to read from a stream created by [[OpenBlocks]]. */ + case class ShuffleStreamHandle(streamId: Long, numChunks: Int) +} /** * Serves requests to open blocks by simply registering one chunk per block requested. 
@@ -44,16 +50,27 @@ class NettyBlockRpcServer( blockManager: BlockDataManager) extends RpcHandler with Logging { - override def receive(messageBytes: Array[Byte], responseContext: RpcResponseCallback): Unit = { + import NettyMessages._ + + override def receive( + client: SluiceClient, + messageBytes: Array[Byte], + responseContext: RpcResponseCallback): Unit = { val ser = serializer.newInstance() val message = ser.deserialize[AnyRef](ByteBuffer.wrap(messageBytes)) logTrace(s"Received request: $message") + message match { case OpenBlocks(blockIds) => val blocks: Seq[ManagedBuffer] = blockIds.map(blockManager.getBlockData) val streamId = streamManager.registerStream(blocks.iterator) + logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") responseContext.onSuccess( ser.serialize(new ShuffleStreamHandle(streamId, blocks.size)).array()) + + case UploadBlock(blockId, blockData, level) => + blockManager.putBlockData(blockId, new NioManagedBuffer(ByteBuffer.wrap(blockData)), level) + responseContext.onSuccess(new Array[Byte](0)) } } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 7576d51e22175..6145c86c65617 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -17,24 +17,23 @@ package org.apache.spark.network.netty +import scala.concurrent.{Promise, Future} + import org.apache.spark.SparkConf import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.network.client.{SluiceClient, SluiceClientFactory} -import org.apache.spark.network.server.{DefaultStreamManager, SluiceServer} +import org.apache.spark.network.client.{RpcResponseCallback, SluiceClient, SluiceClientFactory} +import org.apache.spark.network.netty.NettyMessages.UploadBlock +import org.apache.spark.network.server._ import org.apache.spark.network.util.{ConfigProvider, SluiceConfig} import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.storage.StorageLevel +import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.Utils -import scala.concurrent.Future - /** * A BlockTransferService that uses Netty to fetch a set of blocks at at time. */ class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService { - var client: SluiceClient = _ - // TODO: Don't use Java serialization, use a more cross-version compatible serialization format. 
val serializer = new JavaSerializer(conf) @@ -42,22 +41,24 @@ class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService { private[this] val sluiceConf = new SluiceConfig( new ConfigProvider { override def get(name: String) = conf.get(name) }) + private[this] var sluiceContext: SluiceContext = _ private[this] var server: SluiceServer = _ private[this] var clientFactory: SluiceClientFactory = _ override def init(blockDataManager: BlockDataManager): Unit = { val streamManager = new DefaultStreamManager val rpcHandler = new NettyBlockRpcServer(serializer, streamManager, blockDataManager) - server = new SluiceServer(sluiceConf, streamManager, rpcHandler) - clientFactory = new SluiceClientFactory(sluiceConf) + sluiceContext = new SluiceContext(sluiceConf, streamManager, rpcHandler) + clientFactory = sluiceContext.createClientFactory() + server = sluiceContext.createServer() } override def fetchBlocks( - hostName: String, + hostname: String, port: Int, blockIds: Seq[String], listener: BlockFetchingListener): Unit = { - val client = clientFactory.createClient(hostName, port) + val client = clientFactory.createClient(hostname, port) new NettyBlockFetcher(serializer, client, blockIds, listener) } @@ -65,13 +66,40 @@ class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService { override def port: Int = server.getPort - // TODO: Implement override def uploadBlock( hostname: String, port: Int, - blockId: String, + blockId: BlockId, blockData: ManagedBuffer, - level: StorageLevel): Future[Unit] = ??? + level: StorageLevel): Future[Unit] = { + val result = Promise[Unit]() + val client = clientFactory.createClient(hostname, port) + + // Convert or copy nio buffer into array in order to serialize it. + val nioBuffer = blockData.nioByteBuffer() + val array = if (nioBuffer.hasArray) { + nioBuffer.array() + } else { + val data = new Array[Byte](nioBuffer.remaining()) + nioBuffer.get(data) + data + } + + val ser = serializer.newInstance() + client.sendRpc(ser.serialize(new UploadBlock(blockId, array, level)).array(), + new RpcResponseCallback { + override def onSuccess(response: Array[Byte]): Unit = { + logTrace(s"Successfully uploaded block $blockId") + result.success() + } + override def onFailure(e: Throwable): Unit = { + logError(s"Error while uploading block $blockId", e) + result.failure(e) + } + }) + + result.future + } override def close(): Unit = server.close() } diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index bce1069548437..e91f0af0e87a7 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -127,12 +127,12 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa override def uploadBlock( hostname: String, port: Int, - blockId: String, + blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel) : Future[Unit] = { checkInit() - val msg = PutBlock(BlockId(blockId), blockData.nioByteBuffer(), level) + val msg = PutBlock(blockId, blockData.nioByteBuffer(), level) val blockMessageArray = new BlockMessageArray(BlockMessage.fromPutBlock(msg)) val remoteCmId = new ConnectionManagerId(hostName, port) val reply = cm.sendMessageReliably(remoteCmId, blockMessageArray.toBufferMessage) diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala 
b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index a9144cdd97b8c..4024dea31845c 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -17,14 +17,14 @@ package org.apache.spark.serializer -import java.io.{ByteArrayOutputStream, EOFException, InputStream, OutputStream} +import java.io._ import java.nio.ByteBuffer import scala.reflect.ClassTag -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.{ByteBufferInputStream, NextIterator} +import org.apache.spark.util.{Utils, ByteBufferInputStream, NextIterator} /** * :: DeveloperApi :: @@ -142,3 +142,48 @@ abstract class DeserializationStream { } } } + + +class NoOpReadSerializer(conf: SparkConf) extends Serializer with Serializable { + override def newInstance(): SerializerInstance = { + new NoOpReadSerializerInstance() + } +} + +private[spark] class NoOpReadSerializerInstance() + extends SerializerInstance { + + override def serialize[T: ClassTag](t: T): ByteBuffer = { + val bos = new ByteArrayOutputStream() + val out = serializeStream(bos) + out.writeObject(t) + out.close() + ByteBuffer.wrap(bos.toByteArray) + } + + override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { + null.asInstanceOf[T] + } + + override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { + null.asInstanceOf[T] + } + + override def serializeStream(s: OutputStream): SerializationStream = { + new JavaSerializationStream(s, 100) + } + + override def deserializeStream(s: InputStream): DeserializationStream = { + new NoOpDeserializationStream(s, Utils.getContextOrSparkClassLoader) + } + + def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = { + new NoOpDeserializationStream(s, loader) + } +} + +private[spark] class NoOpDeserializationStream(in: InputStream, loader: ClassLoader) + extends DeserializationStream { + def readObject[T: ClassTag](): T = throw new EOFException() + def close() { } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 4d8b5c1e1b084..6bbc49f9de829 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -855,9 +855,9 @@ private[spark] class BlockManager( data.rewind() logTrace(s"Trying to replicate $blockId of ${data.limit()} bytes to $peer") blockTransferService.uploadBlockSync( - peer.host, peer.port, blockId.toString, new NioManagedBuffer(data), tLevel) - logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %f ms" - .format((System.currentTimeMillis - onePeerStartTime))) + peer.host, peer.port, blockId, new NioManagedBuffer(data), tLevel) + logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %s ms" + .format(System.currentTimeMillis - onePeerStartTime)) peersReplicatedTo += peer peersForReplication -= peer replicationFailed = false diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index e2d32c859bbda..f41c8d0315cb3 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -77,7 +77,7 @@ private[spark] object AkkaUtils extends Logging { val logAkkaConfig = if 
(conf.getBoolean("spark.akka.logAkkaConfig", false)) "on" else "off" - val akkaHeartBeatPauses = conf.getInt("spark.akka.heartbeat.pauses", 600) + val akkaHeartBeatPauses = conf.getInt("spark.akka.heartbeat.pauses", 6000) val akkaFailureDetector = conf.getDouble("spark.akka.failure-detector.threshold", 300.0) val akkaHeartBeatInterval = conf.getInt("spark.akka.heartbeat.interval", 1000) diff --git a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala index d7b2d2e1e330f..840d8273cb6a8 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala @@ -24,10 +24,10 @@ class ShuffleNettySuite extends ShuffleSuite with BeforeAndAfterAll { // This test suite should run all tests in ShuffleSuite with Netty shuffle mode. override def beforeAll() { - System.setProperty("spark.shuffle.use.netty", "true") + System.setProperty("spark.shuffle.blockTransferService", "netty") } override def afterAll() { - System.clearProperty("spark.shuffle.use.netty") + System.clearProperty("spark.shuffle.blockTransferService") } } diff --git a/network/common/src/main/java/org/apache/spark/network/SluiceContext.java b/network/common/src/main/java/org/apache/spark/network/SluiceContext.java new file mode 100644 index 0000000000000..7845ceb8b7d06 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/SluiceContext.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import io.netty.channel.Channel; +import io.netty.channel.socket.SocketChannel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.SluiceClient; +import org.apache.spark.network.client.SluiceClientFactory; +import org.apache.spark.network.client.SluiceResponseHandler; +import org.apache.spark.network.protocol.response.MessageDecoder; +import org.apache.spark.network.protocol.response.MessageEncoder; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.SluiceChannelHandler; +import org.apache.spark.network.server.SluiceRequestHandler; +import org.apache.spark.network.server.SluiceServer; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.util.NettyUtils; +import org.apache.spark.network.util.SluiceConfig; + +/** + * Contains the context to create a {@link SluiceServer}, {@link SluiceClientFactory}, and to setup + * Netty Channel pipelines with a {@link SluiceChannelHandler}. + * + * The SluiceServer and SluiceClientFactory both create a SluiceChannelHandler for each channel. 
+ * As each SluiceChannelHandler contains a SluiceClient, this enables server processes to send + * messages back to the client on an existing channel. + */ +public class SluiceContext { + private final Logger logger = LoggerFactory.getLogger(SluiceContext.class); + + private final SluiceConfig conf; + private final StreamManager streamManager; + private final RpcHandler rpcHandler; + + private final MessageEncoder encoder; + private final MessageDecoder decoder; + + public SluiceContext(SluiceConfig conf, StreamManager streamManager, RpcHandler rpcHandler) { + this.conf = conf; + this.streamManager = streamManager; + this.rpcHandler = rpcHandler; + this.encoder = new MessageEncoder(); + this.decoder = new MessageDecoder(); + } + + public SluiceClientFactory createClientFactory() { + return new SluiceClientFactory(this); + } + + public SluiceServer createServer() { + return new SluiceServer(this); + } + + /** + * Initializes a client or server Netty Channel Pipeline which encodes/decodes messages and + * has a {@link SluiceChannelHandler} to handle request or response messages. + * + * @return Returns the created SluiceChannelHandler, which includes a SluiceClient that can be + * used to communicate on this channel. The SluiceClient is directly associated with a + * ChannelHandler to ensure all users of the same channel get the same SluiceClient object. + */ + public SluiceChannelHandler initializePipeline(SocketChannel channel) { + try { + SluiceChannelHandler channelHandler = createChannelHandler(channel); + channel.pipeline() + .addLast("encoder", encoder) + .addLast("frameDecoder", NettyUtils.createFrameDecoder()) + .addLast("decoder", decoder) + // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this + // would require more logic to guarantee if this were not part of the same event loop. + .addLast("handler", channelHandler); + return channelHandler; + } catch (RuntimeException e) { + logger.error("Error while initializing Netty pipeline", e); + throw e; + } + } + + /** + * Creates the server- and client-side handler which is used to handle both RequestMessages and + * ResponseMessages. The channel is expected to have been successfully created, though certain + * properties (such as the remoteAddress()) may not be available yet. 
+ */ + private SluiceChannelHandler createChannelHandler(Channel channel) { + SluiceResponseHandler responseHandler = new SluiceResponseHandler(channel); + SluiceClient client = new SluiceClient(channel, responseHandler); + SluiceRequestHandler requestHandler = new SluiceRequestHandler(channel, client, streamManager, + rpcHandler); + return new SluiceChannelHandler(client, responseHandler, requestHandler); + } + + public SluiceConfig getConf() { return conf; } +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/SluiceClient.java b/network/common/src/main/java/org/apache/spark/network/client/SluiceClient.java index 1f7d3b0234e38..d6d97981eebd6 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/SluiceClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/SluiceClient.java @@ -19,7 +19,10 @@ import java.io.Closeable; import java.util.UUID; +import java.util.concurrent.TimeUnit; +import com.google.common.base.Preconditions; +import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import org.slf4j.Logger; @@ -28,6 +31,7 @@ import org.apache.spark.network.protocol.StreamChunkId; import org.apache.spark.network.protocol.request.ChunkFetchRequest; import org.apache.spark.network.protocol.request.RpcRequest; +import org.apache.spark.network.util.NettyUtils; /** * Client for fetching consecutive chunks of a pre-negotiated stream. This API is intended to allow @@ -50,7 +54,7 @@ * may be used for multiple streams, but any given stream must be restricted to a single client, * in order to avoid out-of-order responses. * - * NB: This class is used to make requests to the server, while {@link SluiceClientHandler} is + * NB: This class is used to make requests to the server, while {@link SluiceResponseHandler} is * responsible for handling responses from the server. * * Concurrency: thread safe and can be called from multiple threads. 
@@ -58,24 +62,16 @@ public class SluiceClient implements Closeable { private final Logger logger = LoggerFactory.getLogger(SluiceClient.class); - private final ChannelFuture cf; - private final SluiceClientHandler handler; + private final Channel channel; + private final SluiceResponseHandler handler; - private final String serverAddr; - - SluiceClient(ChannelFuture cf, SluiceClientHandler handler) { - this.cf = cf; - this.handler = handler; - - if (cf != null && cf.channel() != null && cf.channel().remoteAddress() != null) { - serverAddr = cf.channel().remoteAddress().toString(); - } else { - serverAddr = ""; - } + public SluiceClient(Channel channel, SluiceResponseHandler handler) { + this.channel = Preconditions.checkNotNull(channel); + this.handler = Preconditions.checkNotNull(handler); } public boolean isActive() { - return cf.channel().isActive(); + return channel.isOpen() || channel.isRegistered() || channel.isActive(); } /** @@ -97,28 +93,27 @@ public void fetchChunk( long streamId, final int chunkIndex, final ChunkReceivedCallback callback) { + final String serverAddr = NettyUtils.getRemoteAddress(channel); final long startTime = System.currentTimeMillis(); logger.debug("Sending fetch chunk request {} to {}", chunkIndex, serverAddr); final StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex); handler.addFetchRequest(streamChunkId, callback); - cf.channel().writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener( + channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener( new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { if (future.isSuccess()) { long timeTaken = System.currentTimeMillis() - startTime; logger.debug("Sending request {} to {} took {} ms", streamChunkId, serverAddr, - timeTaken); + timeTaken); } else { - // Fail all blocks. String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId, - serverAddr, future.cause().getMessage()); + serverAddr, future.cause()); logger.error(errorMsg, future.cause()); - future.cause().printStackTrace(); handler.removeFetchRequest(streamChunkId); - callback.onFailure(chunkIndex, new RuntimeException(errorMsg)); + callback.onFailure(chunkIndex, new RuntimeException(errorMsg, future.cause())); } } }); @@ -129,13 +124,14 @@ public void operationComplete(ChannelFuture future) throws Exception { * with the server's response or upon any failure. */ public void sendRpc(byte[] message, final RpcResponseCallback callback) { + final String serverAddr = NettyUtils.getRemoteAddress(channel); final long startTime = System.currentTimeMillis(); logger.debug("Sending RPC to {}", serverAddr); final long tag = UUID.randomUUID().getLeastSignificantBits(); handler.addRpcRequest(tag, callback); - cf.channel().writeAndFlush(new RpcRequest(tag, message)).addListener( + channel.writeAndFlush(new RpcRequest(tag, message)).addListener( new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { @@ -143,12 +139,11 @@ public void operationComplete(ChannelFuture future) throws Exception { long timeTaken = System.currentTimeMillis() - startTime; logger.debug("Sending request {} to {} took {} ms", tag, serverAddr, timeTaken); } else { - // Fail all blocks. 
- String errorMsg = String.format("Failed to send request %s to %s: %s", tag, - serverAddr, future.cause().getMessage()); + String errorMsg = String.format("Failed to send RPC %s to %s: %s", tag, + serverAddr, future.cause()); logger.error(errorMsg, future.cause()); handler.removeRpcRequest(tag); - callback.onFailure(new RuntimeException(errorMsg)); + callback.onFailure(new RuntimeException(errorMsg, future.cause())); } } }); @@ -156,6 +151,7 @@ public void operationComplete(ChannelFuture future) throws Exception { @Override public void close() { - cf.channel().close(); + // close is a local operation and should finish with milliseconds; timeout just to be safe + channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS); } } diff --git a/network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java index 17491dc3f8720..5de998ef6ed55 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java @@ -21,7 +21,6 @@ import java.lang.reflect.Field; import java.net.InetSocketAddress; import java.net.SocketAddress; -import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeoutException; @@ -37,8 +36,10 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.spark.network.protocol.request.ClientRequestEncoder; -import org.apache.spark.network.protocol.response.ServerResponseDecoder; +import org.apache.spark.network.SluiceContext; +import org.apache.spark.network.protocol.response.MessageDecoder; +import org.apache.spark.network.protocol.response.MessageEncoder; +import org.apache.spark.network.server.SluiceChannelHandler; import org.apache.spark.network.util.IOMode; import org.apache.spark.network.util.NettyUtils; import org.apache.spark.network.util.SluiceConfig; @@ -53,19 +54,17 @@ public class SluiceClientFactory implements Closeable { private final Logger logger = LoggerFactory.getLogger(SluiceClientFactory.class); + private final SluiceContext context; private final SluiceConfig conf; - private final Map connectionPool; - private final ClientRequestEncoder encoder; - private final ServerResponseDecoder decoder; + private final ConcurrentHashMap connectionPool; private final Class socketChannelClass; private final EventLoopGroup workerGroup; - public SluiceClientFactory(SluiceConfig conf) { - this.conf = conf; + public SluiceClientFactory(SluiceContext context) { + this.context = context; + this.conf = context.getConf(); this.connectionPool = new ConcurrentHashMap(); - this.encoder = new ClientRequestEncoder(); - this.decoder = new ServerResponseDecoder(); IOMode ioMode = IOMode.valueOf(conf.ioMode()); this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode); @@ -82,18 +81,18 @@ public SluiceClientFactory(SluiceConfig conf) { public SluiceClient createClient(String remoteHost, int remotePort) throws TimeoutException { // Get connection from the connection pool first. // If it is not found or not active, create a new one. 
- InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); + final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); SluiceClient cachedClient = connectionPool.get(address); if (cachedClient != null && cachedClient.isActive()) { + System.out.println("Reusing cached client: " + cachedClient); return cachedClient; + } else if (cachedClient != null) { + connectionPool.remove(address, cachedClient); // Remove inactive clients. } + System.out.println("Creating new client: " + cachedClient); logger.debug("Creating new connection to " + address); - // There is a chance two threads are creating two different clients connecting to the same host. - // But that's probably ok, as long as the caller hangs on to their client for a single stream. - final SluiceClientHandler handler = new SluiceClientHandler(); - Bootstrap bootstrap = new Bootstrap(); bootstrap.group(workerGroup) .channel(socketChannelClass) @@ -108,11 +107,14 @@ public SluiceClient createClient(String remoteHost, int remotePort) throws Timeo bootstrap.handler(new ChannelInitializer() { @Override public void initChannel(SocketChannel ch) { - ch.pipeline() - .addLast("clientRequestEncoder", encoder) - .addLast("frameDecoder", NettyUtils.createFrameDecoder()) - .addLast("serverResponseDecoder", decoder) - .addLast("handler", handler); + SluiceChannelHandler channelHandler = context.initializePipeline(ch); + SluiceClient oldClient = connectionPool.putIfAbsent(address, channelHandler.getClient()); + if (oldClient != null) { + logger.debug("Two clients were created concurrently, second one will be disposed."); + ch.close(); + // Note: this type of failure is still considered a success by Netty, and thus the + // ChannelFuture will complete successfully. + } } }); @@ -120,11 +122,18 @@ public void initChannel(SocketChannel ch) { ChannelFuture cf = bootstrap.connect(address); if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) { throw new TimeoutException( - String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs())); + String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs())); } - SluiceClient client = new SluiceClient(cf, handler); - connectionPool.put(address, client); + SluiceClient client = connectionPool.get(address); + if (client == null) { + // The only way we should be able to reach here is if the client we created started out + // in the "inactive" state, and someone else simultaneously tried to create another client to + // the same server. This is an error condition, as the first client failed to connect. + throw new IllegalStateException("Client was unset! 
Must have been immediately inactive."); + } else if (!client.isActive()) { + throw new IllegalStateException("Failed to create active client."); + } return client; } diff --git a/network/common/src/main/java/org/apache/spark/network/client/SluiceClientHandler.java b/network/common/src/main/java/org/apache/spark/network/client/SluiceResponseHandler.java similarity index 75% rename from network/common/src/main/java/org/apache/spark/network/client/SluiceClientHandler.java rename to network/common/src/main/java/org/apache/spark/network/client/SluiceResponseHandler.java index ed20b032931c3..9fbd487da86a7 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/SluiceClientHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/SluiceResponseHandler.java @@ -17,37 +17,43 @@ package org.apache.spark.network.client; -import java.net.SocketAddress; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import com.google.common.annotations.VisibleForTesting; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.Channel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.protocol.response.ResponseMessage; import org.apache.spark.network.protocol.StreamChunkId; import org.apache.spark.network.protocol.response.ChunkFetchFailure; import org.apache.spark.network.protocol.response.ChunkFetchSuccess; import org.apache.spark.network.protocol.response.RpcFailure; import org.apache.spark.network.protocol.response.RpcResponse; -import org.apache.spark.network.protocol.response.ServerResponse; +import org.apache.spark.network.server.MessageHandler; +import org.apache.spark.network.util.NettyUtils; /** - * Handler that processes server responses, in response to requests issued from [[SluiceClient]]. + * Handler that processes server responses, in response to requests issued from a [[SluiceClient]]. * It works by tracking the list of outstanding requests (and their callbacks). * * Concurrency: thread safe and can be called from multiple threads. */ -public class SluiceClientHandler extends SimpleChannelInboundHandler { - private final Logger logger = LoggerFactory.getLogger(SluiceClientHandler.class); +public class SluiceResponseHandler extends MessageHandler { + private final Logger logger = LoggerFactory.getLogger(SluiceResponseHandler.class); - private final Map outstandingFetches = - new ConcurrentHashMap(); + private final Channel channel; - private final Map outstandingRpcs = - new ConcurrentHashMap(); + private final Map outstandingFetches; + + private final Map outstandingRpcs; + + public SluiceResponseHandler(Channel channel) { + this.channel = channel; + this.outstandingFetches = new ConcurrentHashMap(); + this.outstandingRpcs = new ConcurrentHashMap(); + } public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) { outstandingFetches.put(streamChunkId, callback); @@ -73,41 +79,36 @@ private void failOutstandingRequests(Throwable cause) { for (Map.Entry entry : outstandingFetches.entrySet()) { entry.getValue().onFailure(entry.getKey().chunkIndex, cause); } - // TODO(rxin): Maybe we need to synchronize the access? Otherwise we could clear new requests - // as well. But I guess that is ok given the caller will fail as soon as any requests fail. + // It's OK if new fetches appear, as they will fail immediately. 
outstandingFetches.clear(); } @Override - public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { + public void channelUnregistered() { if (outstandingFetches.size() > 0) { - SocketAddress remoteAddress = ctx.channel().remoteAddress(); + String remoteAddress = NettyUtils.getRemoteAddress(channel); logger.error("Still have {} requests outstanding when contention from {} is closed", outstandingFetches.size(), remoteAddress); failOutstandingRequests(new RuntimeException("Connection from " + remoteAddress + " closed")); } - super.channelUnregistered(ctx); } @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + public void exceptionCaught(Throwable cause) { if (outstandingFetches.size() > 0) { - logger.error(String.format("Exception in connection from %s: %s", - ctx.channel().remoteAddress(), cause.getMessage()), cause); failOutstandingRequests(cause); } - ctx.close(); } @Override - public void channelRead0(ChannelHandlerContext ctx, ServerResponse message) { - String server = ctx.channel().remoteAddress().toString(); + public void handle(ResponseMessage message) { + String remoteAddress = NettyUtils.getRemoteAddress(channel); if (message instanceof ChunkFetchSuccess) { ChunkFetchSuccess resp = (ChunkFetchSuccess) message; ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId); if (listener == null) { logger.warn("Got a response for block {} from {} but it is not outstanding", - resp.streamChunkId, server); + resp.streamChunkId, remoteAddress); resp.buffer.release(); } else { outstandingFetches.remove(resp.streamChunkId); @@ -119,7 +120,7 @@ public void channelRead0(ChannelHandlerContext ctx, ServerResponse message) { ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId); if (listener == null) { logger.warn("Got a response for block {} from {} ({}) but it is not outstanding", - resp.streamChunkId, server, resp.errorString); + resp.streamChunkId, remoteAddress, resp.errorString); } else { outstandingFetches.remove(resp.streamChunkId); listener.onFailure(resp.streamChunkId.chunkIndex, @@ -130,7 +131,7 @@ public void channelRead0(ChannelHandlerContext ctx, ServerResponse message) { RpcResponseCallback listener = outstandingRpcs.get(resp.tag); if (listener == null) { logger.warn("Got a response for RPC {} from {} ({} bytes) but it is not outstanding", - resp.tag, server, resp.response.length); + resp.tag, remoteAddress, resp.response.length); } else { outstandingRpcs.remove(resp.tag); listener.onSuccess(resp.response); @@ -140,11 +141,13 @@ public void channelRead0(ChannelHandlerContext ctx, ServerResponse message) { RpcResponseCallback listener = outstandingRpcs.get(resp.tag); if (listener == null) { logger.warn("Got a response for RPC {} from {} ({}) but it is not outstanding", - resp.tag, server, resp.errorString); + resp.tag, remoteAddress, resp.errorString); } else { outstandingRpcs.remove(resp.tag); listener.onFailure(new RuntimeException(resp.errorString)); } + } else { + throw new IllegalStateException("Unknown response type: " + message.type()); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java similarity index 67% rename from network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequest.java rename to network/common/src/main/java/org/apache/spark/network/protocol/Message.java index db075c44b4cda..6731b3f53ae82 100644 
--- a/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequest.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java @@ -15,28 +15,24 @@ * limitations under the License. */ -package org.apache.spark.network.protocol.request; +package org.apache.spark.network.protocol; import io.netty.buffer.ByteBuf; -import org.apache.spark.network.protocol.Encodable; - /** Messages from the client to the server. */ -public interface ClientRequest extends Encodable { +public interface Message extends Encodable { /** Used to identify this request type. */ Type type(); - /** - * Preceding every serialized ClientRequest is the type, which allows us to deserialize - * the request. - */ + /** Preceding every serialized Message is its type, which allows us to deserialize it. */ public static enum Type implements Encodable { - ChunkFetchRequest(0), RpcRequest(1); + ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2), + RpcRequest(3), RpcResponse(4), RpcFailure(5); private final byte id; private Type(int id) { - assert id < 128 : "Cannot have more than 128 request types"; + assert id < 128 : "Cannot have more than 128 message types"; this.id = (byte) id; } @@ -48,10 +44,14 @@ private Type(int id) { public static Type decode(ByteBuf buf) { byte id = buf.readByte(); - switch(id) { + switch (id) { case 0: return ChunkFetchRequest; - case 1: return RpcRequest; - default: throw new IllegalArgumentException("Unknown request type: " + id); + case 1: return ChunkFetchSuccess; + case 2: return ChunkFetchFailure; + case 3: return RpcRequest; + case 4: return RpcResponse; + case 5: return RpcFailure; + default: throw new IllegalArgumentException("Unknown message type: " + id); } } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/ChunkFetchRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/request/ChunkFetchRequest.java index a79eb363cf58c..99cbb8777a873 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/request/ChunkFetchRequest.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/request/ChunkFetchRequest.java @@ -24,9 +24,9 @@ /** * Request to fetch a sequence of a single chunk of a stream. This will correspond to a single - * {@link org.apache.spark.network.protocol.response.ServerResponse} (either success or failure). + * {@link org.apache.spark.network.protocol.response.ResponseMessage} (either success or failure). */ -public final class ChunkFetchRequest implements ClientRequest { +public final class ChunkFetchRequest implements RequestMessage { public final StreamChunkId streamChunkId; public ChunkFetchRequest(StreamChunkId streamChunkId) { diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestDecoder.java deleted file mode 100644 index a937da4cecae0..0000000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestDecoder.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol.request; - -import java.util.List; - -import io.netty.buffer.ByteBuf; -import io.netty.channel.ChannelHandler; -import io.netty.channel.ChannelHandlerContext; -import io.netty.handler.codec.MessageToMessageDecoder; - -/** - * Decoder in the server side to decode client requests. - * This decoder is stateless so it is safe to be shared by multiple threads. - * - * This assumes the inbound messages have been processed by a frame decoder created by - * {@link org.apache.spark.network.util.NettyUtils#createFrameDecoder()}. - */ -@ChannelHandler.Sharable -public final class ClientRequestDecoder extends MessageToMessageDecoder { - - @Override - public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { - ClientRequest.Type msgType = ClientRequest.Type.decode(in); - ClientRequest decoded = decode(msgType, in); - assert decoded.type() == msgType; - assert in.readableBytes() == 0; - out.add(decoded); - } - - private ClientRequest decode(ClientRequest.Type msgType, ByteBuf in) { - switch (msgType) { - case ChunkFetchRequest: - return ChunkFetchRequest.decode(in); - - case RpcRequest: - return RpcRequest.decode(in); - - default: throw new IllegalArgumentException("Unexpected message type: " + msgType); - } - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestEncoder.java deleted file mode 100644 index bcff4a0a25568..0000000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestEncoder.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol.request; - -import java.util.List; - -import io.netty.buffer.ByteBuf; -import io.netty.channel.ChannelHandler; -import io.netty.channel.ChannelHandlerContext; -import io.netty.handler.codec.MessageToMessageEncoder; - -/** - * Encoder for {@link ClientRequest} used in client side. - * - * This encoder is stateless so it is safe to be shared by multiple threads. 
- */ -@ChannelHandler.Sharable -public final class ClientRequestEncoder extends MessageToMessageEncoder { - @Override - public void encode(ChannelHandlerContext ctx, ClientRequest in, List out) { - ClientRequest.Type msgType = in.type(); - // Write 8 bytes for the frame's length, followed by the request type and request itself. - int frameLength = 8 + msgType.encodedLength() + in.encodedLength(); - ByteBuf buf = ctx.alloc().buffer(frameLength); - buf.writeLong(frameLength); - msgType.encode(buf); - in.encode(buf); - assert buf.writableBytes() == 0; - out.add(buf); - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/RequestMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/request/RequestMessage.java new file mode 100644 index 0000000000000..58abce25d9a2a --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/request/RequestMessage.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.request; + +import org.apache.spark.network.protocol.Message; + +/** Messages from the client to the server. */ +public interface RequestMessage extends Message { + // token interface +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/request/RpcRequest.java index 126370330f723..810da7a689c13 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/request/RpcRequest.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/request/RpcRequest.java @@ -24,10 +24,10 @@ /** * A generic RPC which is handled by a remote {@link org.apache.spark.network.server.RpcHandler}. - * This will correspond to a single {@link org.apache.spark.network.protocol.response.ServerResponse} - * (either success or failure). + * This will correspond to a single + * {@link org.apache.spark.network.protocol.response.ResponseMessage} (either success or failure). */ -public final class RpcRequest implements ClientRequest { +public final class RpcRequest implements RequestMessage { /** Tag is used to link an RPC request with its response. 
*/ public final long tag; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchFailure.java index 3a57d71b4f3ea..18ed4d95bba4c 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchFailure.java @@ -17,6 +17,7 @@ package org.apache.spark.network.protocol.response; +import com.google.common.base.Charsets; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; @@ -26,7 +27,7 @@ * Response to {@link org.apache.spark.network.protocol.request.ChunkFetchRequest} when there is an * error fetching the chunk. */ -public final class ChunkFetchFailure implements ServerResponse { +public final class ChunkFetchFailure implements ResponseMessage { public final StreamChunkId streamChunkId; public final String errorString; @@ -40,13 +41,13 @@ public ChunkFetchFailure(StreamChunkId streamChunkId, String errorString) { @Override public int encodedLength() { - return streamChunkId.encodedLength() + 4 + errorString.getBytes().length; + return streamChunkId.encodedLength() + 4 + errorString.getBytes(Charsets.UTF_8).length; } @Override public void encode(ByteBuf buf) { streamChunkId.encode(buf); - byte[] errorBytes = errorString.getBytes(); + byte[] errorBytes = errorString.getBytes(Charsets.UTF_8); buf.writeInt(errorBytes.length); buf.writeBytes(errorBytes); } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchSuccess.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchSuccess.java index 874dc4f5940cf..6bc26a64b9945 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchSuccess.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchSuccess.java @@ -32,7 +32,7 @@ * may be written by Netty in a more efficient manner (i.e., zero-copy write). * Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer. */ -public final class ChunkFetchSuccess implements ServerResponse { +public final class ChunkFetchSuccess implements ResponseMessage { public final StreamChunkId streamChunkId; public final ManagedBuffer buffer; @@ -49,7 +49,7 @@ public int encodedLength() { return streamChunkId.encodedLength(); } - /** Encoding does NOT include buffer itself. See {@link ServerResponseEncoder}. */ + /** Encoding does NOT include buffer itself. See {@link MessageEncoder}. 
*/ @Override public void encode(ByteBuf buf) { streamChunkId.encode(buf); diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/MessageDecoder.java similarity index 70% rename from network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseDecoder.java rename to network/common/src/main/java/org/apache/spark/network/protocol/response/MessageDecoder.java index e06198284e620..3ae80305803eb 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseDecoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/MessageDecoder.java @@ -23,30 +23,44 @@ import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.MessageToMessageDecoder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.protocol.Message; +import org.apache.spark.network.protocol.request.ChunkFetchRequest; +import org.apache.spark.network.protocol.request.RpcRequest; /** * Decoder used by the client side to encode server-to-client responses. * This encoder is stateless so it is safe to be shared by multiple threads. */ @ChannelHandler.Sharable -public final class ServerResponseDecoder extends MessageToMessageDecoder { +public final class MessageDecoder extends MessageToMessageDecoder { + private final Logger logger = LoggerFactory.getLogger(MessageDecoder.class); @Override public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { - ServerResponse.Type msgType = ServerResponse.Type.decode(in); - ServerResponse decoded = decode(msgType, in); + Message.Type msgType = Message.Type.decode(in); + Message decoded = decode(msgType, in); assert decoded.type() == msgType; + logger.debug("Received message " + msgType + ": " + decoded); out.add(decoded); } - private ServerResponse decode(ServerResponse.Type msgType, ByteBuf in) { + private Message decode(Message.Type msgType, ByteBuf in) { switch (msgType) { + case ChunkFetchRequest: + return ChunkFetchRequest.decode(in); + case ChunkFetchSuccess: return ChunkFetchSuccess.decode(in); case ChunkFetchFailure: return ChunkFetchFailure.decode(in); + case RpcRequest: + return RpcRequest.decode(in); + case RpcResponse: return RpcResponse.decode(in); diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/MessageEncoder.java similarity index 78% rename from network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseEncoder.java rename to network/common/src/main/java/org/apache/spark/network/protocol/response/MessageEncoder.java index 069f42463a8fe..5ca8de42a6429 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseEncoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/MessageEncoder.java @@ -26,17 +26,25 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.protocol.Message; + /** * Encoder used by the server side to encode server-to-client responses. * This encoder is stateless so it is safe to be shared by multiple threads. 
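As an aside, the frame layout shared by this encoder/decoder pair (an 8-byte frame length that includes itself, a 1-byte message type, then the message body) can be exercised end to end with Netty's EmbeddedChannel. This is an illustrative sketch only, not part of the patch: the class name FramingExample is hypothetical, while the codec and message types are the ones introduced here.

    import io.netty.channel.embedded.EmbeddedChannel;
    import org.apache.spark.network.protocol.response.MessageDecoder;
    import org.apache.spark.network.protocol.response.MessageEncoder;
    import org.apache.spark.network.protocol.response.RpcFailure;
    import org.apache.spark.network.util.NettyUtils;

    public class FramingExample {
      public static void main(String[] args) {
        // "Server" side: encode an RpcFailure into a single framed ByteBuf.
        EmbeddedChannel serverOut = new EmbeddedChannel(new MessageEncoder());
        serverOut.writeOutbound(new RpcFailure(7, "oops"));

        // "Client" side: strip the 8-byte frame length, then decode type and body.
        EmbeddedChannel clientIn =
            new EmbeddedChannel(NettyUtils.createFrameDecoder(), new MessageDecoder());
        clientIn.writeInbound(serverOut.readOutbound());

        RpcFailure decoded = clientIn.readInbound();  // tag == 7, errorString == "oops"
        assert decoded.tag == 7;
      }
    }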
*/ @ChannelHandler.Sharable -public final class ServerResponseEncoder extends MessageToMessageEncoder { +public final class MessageEncoder extends MessageToMessageEncoder { - private final Logger logger = LoggerFactory.getLogger(ServerResponseEncoder.class); + private final Logger logger = LoggerFactory.getLogger(MessageEncoder.class); + /*** + * Encodes a Message by invoking its encode() method. For non-data messages, we will add one + * ByteBuf to 'out' containing the total frame length, the message type, and the message itself. + * In the case of a ChunkFetchSuccess, we will also add the ManagedBuffer corresponding to the + * data to 'out', in order to enable zero-copy transfer. + */ @Override - public void encode(ChannelHandlerContext ctx, ServerResponse in, List out) { + public void encode(ChannelHandlerContext ctx, Message in, List out) { Object body = null; long bodyLength = 0; @@ -56,7 +64,7 @@ public void encode(ChannelHandlerContext ctx, ServerResponse in, List ou } } - ServerResponse.Type msgType = in.type(); + Message.Type msgType = in.type(); // All messages have the frame length, message type, and message itself. int headerLength = 8 + msgType.encodedLength() + in.encodedLength(); long frameLength = headerLength + bodyLength; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ResponseMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/ResponseMessage.java new file mode 100644 index 0000000000000..8f545e91d1d8e --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/ResponseMessage.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.response; + +import org.apache.spark.network.protocol.Message; + +/** Messages from the server to the client. */ +public interface ResponseMessage extends Message { + // token interface +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcFailure.java index 274920b28bced..6b71da5708c58 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcFailure.java @@ -17,11 +17,12 @@ package org.apache.spark.network.protocol.response; +import com.google.common.base.Charsets; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; /** Response to {@link org.apache.spark.network.protocol.request.RpcRequest} for a failed RPC. 
*/ -public final class RpcFailure implements ServerResponse { +public final class RpcFailure implements ResponseMessage { public final long tag; public final String errorString; @@ -35,13 +36,13 @@ public RpcFailure(long tag, String errorString) { @Override public int encodedLength() { - return 8 + 4 + errorString.getBytes().length; + return 8 + 4 + errorString.getBytes(Charsets.UTF_8).length; } @Override public void encode(ByteBuf buf) { buf.writeLong(tag); - byte[] errorBytes = errorString.getBytes(); + byte[] errorBytes = errorString.getBytes(Charsets.UTF_8); buf.writeInt(errorBytes.length); buf.writeBytes(errorBytes); } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcResponse.java index 0c6f8acdcdc4b..40623ce31c666 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcResponse.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcResponse.java @@ -23,7 +23,7 @@ import io.netty.buffer.ByteBuf; /** Response to {@link org.apache.spark.network.protocol.request.RpcRequest} for a successful RPC. */ -public final class RpcResponse implements ServerResponse { +public final class RpcResponse implements ResponseMessage { public final long tag; public final byte[] response; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponse.java deleted file mode 100644 index 335f9e8ea69f9..0000000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponse.java +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol.response; - -import io.netty.buffer.ByteBuf; - -import org.apache.spark.network.protocol.Encodable; - -/** - * Messages from server to client (usually in response to some - * {@link org.apache.spark.network.protocol.request.ClientRequest}. - */ -public interface ServerResponse extends Encodable { - /** Used to identify this response type. */ - Type type(); - - /** - * Preceding every serialized ServerResponse is the type, which allows us to deserialize - * the response. 
- */ - public static enum Type implements Encodable { - ChunkFetchSuccess(0), ChunkFetchFailure(1), RpcResponse(2), RpcFailure(3); - - private final byte id; - - private Type(int id) { - assert id < 128 : "Cannot have more than 128 response types"; - this.id = (byte) id; - } - - public byte id() { return id; } - - @Override public int encodedLength() { return 1; } - - @Override public void encode(ByteBuf buf) { buf.writeByte(id); } - - public static Type decode(ByteBuf buf) { - byte id = buf.readByte(); - switch(id) { - case 0: return ChunkFetchSuccess; - case 1: return ChunkFetchFailure; - case 2: return RpcResponse; - case 3: return RpcFailure; - default: throw new IllegalArgumentException("Unknown response type: " + id); - } - } - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java index 04814d9a88c4a..d93607a7c31ea 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java +++ b/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java @@ -23,6 +23,9 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import org.apache.spark.network.buffer.ManagedBuffer; /** @@ -30,6 +33,8 @@ * fetched as chunks by the client. */ public class DefaultStreamManager extends StreamManager { + private final Logger logger = LoggerFactory.getLogger(DefaultStreamManager.class); + private final AtomicLong nextStreamId; private final Map streams; @@ -61,7 +66,14 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { "Requested chunk index beyond end %s", chunkIndex)); } state.curChunk += 1; - return state.buffers.next(); + ManagedBuffer nextChunk = state.buffers.next(); + + if (!state.buffers.hasNext()) { + logger.trace("Removing stream id {}", streamId); + streams.remove(streamId); + } + + return nextChunk; } @Override diff --git a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java new file mode 100644 index 0000000000000..b80c15106ecbd --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import org.apache.spark.network.protocol.Message; + +/** + * Handles either request or response messages coming off of Netty. A MessageHandler instance + * is associated with a single Netty Channel (though it may have multiple clients on the same + * Channel.) 
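To make that contract concrete, a minimal sketch of a handler built on this abstraction follows (the class name LoggingResponseHandler is hypothetical; the generic parameter follows the handle(T) signature defined just below).

    import org.apache.spark.network.protocol.response.ResponseMessage;
    import org.apache.spark.network.server.MessageHandler;

    // One instance per channel; handles one direction of traffic (here, responses).
    class LoggingResponseHandler extends MessageHandler<ResponseMessage> {
      @Override
      public void handle(ResponseMessage message) {
        System.out.println("Received response: " + message);
      }

      @Override
      public void exceptionCaught(Throwable cause) {
        cause.printStackTrace();
      }

      @Override
      public void channelUnregistered() {
        System.out.println("Channel closed; dropping any per-channel state");
      }
    }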
+ */ +public abstract class MessageHandler { + /** Handles the receipt of a single message. */ + public abstract void handle(T message); + + /** Invoked when an exception was caught on the Channel. */ + public abstract void exceptionCaught(Throwable cause); + + /** Invoked when the channel this MessageHandler is on has been unregistered. */ + public abstract void channelUnregistered(); +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java index abfbe66d875e8..5700cc83bd9c8 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -18,6 +18,7 @@ package org.apache.spark.network.server; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.SluiceClient; /** * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.SluiceClient}s. @@ -26,6 +27,12 @@ public interface RpcHandler { /** * Receive a single RPC message. Any exception thrown while in this method will be sent back to * the client in string form as a standard RPC failure. + * + * @param client A channel client which enables the handler to make requests back to the sender + * of this RPC. + * @param message The serialized bytes of the RPC. + * @param callback Callback which should be invoked exactly once upon success or failure of the + * RPC. */ - void receive(byte[] message, RpcResponseCallback callback); + void receive(SluiceClient client, byte[] message, RpcResponseCallback callback); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/SluiceChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/SluiceChannelHandler.java new file mode 100644 index 0000000000000..d5a91ec1b6c28 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/SluiceChannelHandler.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.server; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.SluiceClient; +import org.apache.spark.network.client.SluiceResponseHandler; +import org.apache.spark.network.protocol.Message; +import org.apache.spark.network.protocol.request.RequestMessage; +import org.apache.spark.network.protocol.response.ResponseMessage; +import org.apache.spark.network.util.NettyUtils; + +/** + * A handler which is used for delegating requests to the + * {@link org.apache.spark.network.server.SluiceRequestHandler} and responses to the + * {@link org.apache.spark.network.client.SluiceResponseHandler}. + * + * All channels created in Sluice are bidirectional. When the Client initiates a Netty Channel + * with a RequestMessage (which gets handled by the Server's RequestHandler), the Server will + * produce a ResponseMessage (handled by the Client's ResponseHandler). However, the Server also + * gets a handle on the same Channel, so it may then begin to send RequestMessages to the Client. + * This means that the Client also needs a RequestHandler and the Server needs a ResponseHandler, + * for the Client's responses to the Server's requests. + */ +public class SluiceChannelHandler extends SimpleChannelInboundHandler { + private final Logger logger = LoggerFactory.getLogger(SluiceChannelHandler.class); + + private final SluiceClient client; + private final SluiceResponseHandler responseHandler; + private final SluiceRequestHandler requestHandler; + + public SluiceChannelHandler( + SluiceClient client, + SluiceResponseHandler responseHandler, + SluiceRequestHandler requestHandler) { + this.client = client; + this.responseHandler = responseHandler; + this.requestHandler = requestHandler; + } + + public SluiceClient getClient() { + return client; + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + logger.warn("Exception in connection from " + NettyUtils.getRemoteAddress(ctx.channel()), + cause); + requestHandler.exceptionCaught(cause); + responseHandler.exceptionCaught(cause); + ctx.close(); + } + + @Override + public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { + requestHandler.channelUnregistered(); + responseHandler.channelUnregistered(); + super.channelUnregistered(ctx); + } + + @Override + public void channelRead0(ChannelHandlerContext ctx, Message request) { + if (request instanceof RequestMessage) { + requestHandler.handle((RequestMessage) request); + } else { + responseHandler.handle((ResponseMessage) request); + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/SluiceServerHandler.java b/network/common/src/main/java/org/apache/spark/network/server/SluiceRequestHandler.java similarity index 65% rename from network/common/src/main/java/org/apache/spark/network/server/SluiceServerHandler.java rename to network/common/src/main/java/org/apache/spark/network/server/SluiceRequestHandler.java index fad72fbfc711b..5f5111e0a7638 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/SluiceServerHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/SluiceRequestHandler.java @@ -21,33 +21,40 @@ import com.google.common.base.Throwables; import com.google.common.collect.Sets; +import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import 
io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.SimpleChannelInboundHandler; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.SluiceClient; import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.request.RequestMessage; import org.apache.spark.network.protocol.request.ChunkFetchRequest; -import org.apache.spark.network.protocol.request.ClientRequest; import org.apache.spark.network.protocol.request.RpcRequest; import org.apache.spark.network.protocol.response.ChunkFetchFailure; import org.apache.spark.network.protocol.response.ChunkFetchSuccess; import org.apache.spark.network.protocol.response.RpcFailure; import org.apache.spark.network.protocol.response.RpcResponse; +import org.apache.spark.network.util.NettyUtils; /** - * A handler that processes requests from clients and writes chunk data back. Each handler keeps - * track of which streams have been fetched via this channel, in order to clean them up if the - * channel is terminated (see #channelUnregistered). + * A handler that processes requests from clients and writes chunk data back. Each handler is + * attached to a single Netty channel, and keeps track of which streams have been fetched via this + * channel, in order to clean them up if the channel is terminated (see #channelUnregistered). * * The messages should have been processed by the pipeline setup by {@link SluiceServer}. */ -public class SluiceServerHandler extends SimpleChannelInboundHandler { - private final Logger logger = LoggerFactory.getLogger(SluiceServerHandler.class); +public class SluiceRequestHandler extends MessageHandler { + private final Logger logger = LoggerFactory.getLogger(SluiceRequestHandler.class); + + /** The Netty channel that this handler is associated with. */ + private final Channel channel; + + /** Client on the same channel allowing us to talk back to the requester. */ + private final SluiceClient reverseClient; /** Returns each chunk part of a stream. */ private final StreamManager streamManager; @@ -58,22 +65,24 @@ public class SluiceServerHandler extends SimpleChannelInboundHandler streamIds; - public SluiceServerHandler(StreamManager streamManager, RpcHandler rpcHandler) { + public SluiceRequestHandler( + Channel channel, + SluiceClient reverseClient, + StreamManager streamManager, + RpcHandler rpcHandler) { + this.channel = channel; + this.reverseClient = reverseClient; this.streamManager = streamManager; this.rpcHandler = rpcHandler; this.streamIds = Sets.newHashSet(); } @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - logger.error("Exception in connection from " + ctx.channel().remoteAddress(), cause); - ctx.close(); - super.exceptionCaught(ctx, cause); + public void exceptionCaught(Throwable cause) { } @Override - public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { - super.channelUnregistered(ctx); + public void channelUnregistered() { // Inform the StreamManager that these streams will no longer be read from. 
for (long streamId : streamIds) { streamManager.connectionTerminated(streamId); @@ -81,18 +90,18 @@ public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { } @Override - public void channelRead0(ChannelHandlerContext ctx, ClientRequest request) { + public void handle(RequestMessage request) { if (request instanceof ChunkFetchRequest) { - processFetchRequest(ctx, (ChunkFetchRequest) request); + processFetchRequest((ChunkFetchRequest) request); } else if (request instanceof RpcRequest) { - processRpcRequest(ctx, (RpcRequest) request); + processRpcRequest((RpcRequest) request); } else { throw new IllegalArgumentException("Unknown request type: " + request); } } - private void processFetchRequest(final ChannelHandlerContext ctx, final ChunkFetchRequest req) { - final String client = ctx.channel().remoteAddress().toString(); + private void processFetchRequest(final ChunkFetchRequest req) { + final String client = NettyUtils.getRemoteAddress(channel); streamIds.add(req.streamChunkId.streamId); logger.trace("Received req from {} to fetch block {}", client, req.streamChunkId); @@ -103,29 +112,29 @@ private void processFetchRequest(final ChannelHandlerContext ctx, final ChunkFet } catch (Exception e) { logger.error(String.format( "Error opening block %s for request from %s", req.streamChunkId, client), e); - respond(ctx, new ChunkFetchFailure(req.streamChunkId, Throwables.getStackTraceAsString(e))); + respond(new ChunkFetchFailure(req.streamChunkId, Throwables.getStackTraceAsString(e))); return; } - respond(ctx, new ChunkFetchSuccess(req.streamChunkId, buf)); + respond(new ChunkFetchSuccess(req.streamChunkId, buf)); } - private void processRpcRequest(final ChannelHandlerContext ctx, final RpcRequest req) { + private void processRpcRequest(final RpcRequest req) { try { - rpcHandler.receive(req.message, new RpcResponseCallback() { + rpcHandler.receive(reverseClient, req.message, new RpcResponseCallback() { @Override public void onSuccess(byte[] response) { - respond(ctx, new RpcResponse(req.tag, response)); + respond(new RpcResponse(req.tag, response)); } @Override public void onFailure(Throwable e) { - respond(ctx, new RpcFailure(req.tag, Throwables.getStackTraceAsString(e))); + respond(new RpcFailure(req.tag, Throwables.getStackTraceAsString(e))); } }); } catch (Exception e) { logger.error("Error while invoking RpcHandler#receive() on RPC tag " + req.tag, e); - respond(ctx, new RpcFailure(req.tag, Throwables.getStackTraceAsString(e))); + respond(new RpcFailure(req.tag, Throwables.getStackTraceAsString(e))); } } @@ -133,9 +142,9 @@ public void onFailure(Throwable e) { * Responds to a single message with some Encodable object. If a failure occurs while sending, * it will be logged and the channel closed. 
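For reference, a minimal sketch of an RpcHandler that fits the receive(client, message, callback) contract invoked by processRpcRequest above: it simply echoes the request bytes back. The class name EchoRpcHandler is hypothetical; the callback is expected to be invoked exactly once, and any exception thrown here is turned into an RpcFailure by the request handler.

    import org.apache.spark.network.client.RpcResponseCallback;
    import org.apache.spark.network.client.SluiceClient;
    import org.apache.spark.network.server.RpcHandler;

    public class EchoRpcHandler implements RpcHandler {
      @Override
      public void receive(SluiceClient client, byte[] message, RpcResponseCallback callback) {
        // The client argument could be used to issue requests back to the sender;
        // here we only echo the payload as the RPC response.
        callback.onSuccess(message);
      }
    }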
*/ - private void respond(final ChannelHandlerContext ctx, final Encodable result) { - final String remoteAddress = ctx.channel().remoteAddress().toString(); - ctx.writeAndFlush(result).addListener( + private void respond(final Encodable result) { + final String remoteAddress = channel.remoteAddress().toString(); + channel.writeAndFlush(result).addListener( new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { @@ -144,7 +153,7 @@ public void operationComplete(ChannelFuture future) throws Exception { } else { logger.error(String.format("Error sending result %s to %s; closing connection", result, remoteAddress), future.cause()); - ctx.close(); + channel.close(); } } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/SluiceServer.java b/network/common/src/main/java/org/apache/spark/network/server/SluiceServer.java index aa81271024156..965db536a2782 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/SluiceServer.java +++ b/network/common/src/main/java/org/apache/spark/network/server/SluiceServer.java @@ -19,6 +19,7 @@ import java.io.Closeable; import java.net.InetSocketAddress; +import java.util.concurrent.TimeUnit; import io.netty.bootstrap.ServerBootstrap; import io.netty.buffer.PooledByteBufAllocator; @@ -30,8 +31,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.spark.network.protocol.request.ClientRequestDecoder; -import org.apache.spark.network.protocol.response.ServerResponseEncoder; +import org.apache.spark.network.SluiceContext; import org.apache.spark.network.util.IOMode; import org.apache.spark.network.util.NettyUtils; import org.apache.spark.network.util.SluiceConfig; @@ -42,18 +42,16 @@ public class SluiceServer implements Closeable { private final Logger logger = LoggerFactory.getLogger(SluiceServer.class); + private final SluiceContext context; private final SluiceConfig conf; - private final StreamManager streamManager; - private final RpcHandler rpcHandler; private ServerBootstrap bootstrap; private ChannelFuture channelFuture; private int port; - public SluiceServer(SluiceConfig conf, StreamManager streamManager, RpcHandler rpcHandler) { - this.conf = conf; - this.streamManager = streamManager; - this.rpcHandler = rpcHandler; + public SluiceServer(SluiceContext context) { + this.context = context; + this.conf = context.getConf(); init(); } @@ -86,16 +84,9 @@ private void init() { } bootstrap.childHandler(new ChannelInitializer() { - @Override protected void initChannel(SocketChannel ch) throws Exception { - ch.pipeline() - .addLast("frameDecoder", NettyUtils.createFrameDecoder()) - .addLast("clientRequestDecoder", new ClientRequestDecoder()) - .addLast("serverResponseEncoder", new ServerResponseEncoder()) - // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this - // would require more logic to guarantee if this were not part of the same event loop. 
- .addLast("handler", new SluiceServerHandler(streamManager, rpcHandler)); + context.initializePipeline(ch); } }); @@ -109,7 +100,8 @@ protected void initChannel(SocketChannel ch) throws Exception { @Override public void close() { if (channelFuture != null) { - channelFuture.channel().close().awaitUninterruptibly(); + // close is a local operation and should finish with milliseconds; timeout just to be safe + channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS); channelFuture = null; } if (bootstrap != null && bootstrap.group() != null) { diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java index 2e07f5a270cb9..47b74b229fdec 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java +++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java @@ -21,7 +21,7 @@ /** * The StreamManager is used to fetch individual chunks from a stream. This is used in - * {@link SluiceServerHandler} in order to respond to fetchChunk() requests. Creation of the + * {@link SluiceRequestHandler} in order to respond to fetchChunk() requests. Creation of the * stream is outside the scope of Sluice, but a given stream is guaranteed to be read by only one * client connection, meaning that getChunk() for a particular stream will be called serially and * that once the connection associated with the stream is closed, that stream will never be used diff --git a/network/common/src/main/java/org/apache/spark/network/util/IOMode.java b/network/common/src/main/java/org/apache/spark/network/util/IOMode.java index 91cb3e0e6f8f6..c0aa12c81ba64 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/IOMode.java +++ b/network/common/src/main/java/org/apache/spark/network/util/IOMode.java @@ -19,7 +19,7 @@ /** * Selector for which form of low-level IO we should use. - * NIO is always available, while EPOLL is only available on certain machines. + * NIO is always available, while EPOLL is only available on Linux. * AUTO is used to select EPOLL if it's available, or NIO otherwise. */ public enum IOMode { diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index fafdcad04aeb6..32ba3f5b07f7a 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -18,13 +18,21 @@ package org.apache.spark.network.util; import java.io.Closeable; +import java.io.IOException; import com.google.common.io.Closeables; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class JavaUtils { + private static final Logger logger = LoggerFactory.getLogger(JavaUtils.class); + /** Closes the given object, ignoring IOExceptions. 
*/ - @SuppressWarnings("deprecation") - public static void closeQuietly(Closeable closable) { - Closeables.closeQuietly(closable); + public static void closeQuietly(Closeable closeable) { + try { + closeable.close(); + } catch (IOException e) { + logger.error("IOException should not have been thrown.", e); + } } } diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java index 3d20dc9e1c1cd..a925c05469d3c 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -44,11 +44,11 @@ public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String } ThreadFactory threadFactory = new ThreadFactoryBuilder() - .setDaemon(true) - .setNameFormat(threadPrefix + "-%d") - .build(); + .setDaemon(true) + .setNameFormat(threadPrefix + "-%d") + .build(); - switch(mode) { + switch (mode) { case NIO: return new NioEventLoopGroup(numThreads, threadFactory); case EPOLL: @@ -63,7 +63,7 @@ public static Class getClientChannelClass(IOMode mode) { if (mode == IOMode.AUTO) { mode = autoselectMode(); } - switch(mode) { + switch (mode) { case NIO: return NioSocketChannel.class; case EPOLL: @@ -78,7 +78,7 @@ public static Class getServerChannelClass(IOMode mode) if (mode == IOMode.AUTO) { mode = autoselectMode(); } - switch(mode) { + switch (mode) { case NIO: return NioServerSocketChannel.class; case EPOLL: @@ -101,9 +101,16 @@ public static ByteToMessageDecoder createFrameDecoder() { return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 8, -8, 8); } + /** Returns the remote address on the channel or "" if none exists. */ + public static String getRemoteAddress(Channel channel) { + if (channel != null && channel.remoteAddress() != null) { + return channel.remoteAddress().toString(); + } + return ""; + } + /** Returns EPOLL if it's available on this system, NIO otherwise. */ private static IOMode autoselectMode() { return Epoll.isAvailable() ? 
IOMode.EPOLL : IOMode.NIO; } } - diff --git a/network/common/src/test/java/org/apache/spark/network/IntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/IntegrationSuite.java index d20528558cae1..d38f6db99c09b 100644 --- a/network/common/src/test/java/org/apache/spark/network/IntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/IntegrationSuite.java @@ -94,8 +94,9 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { } } }; - server = new SluiceServer(conf, streamManager, new NoOpRpcHandler()); - clientFactory = new SluiceClientFactory(conf); + SluiceContext context = new SluiceContext(conf, streamManager, new NoOpRpcHandler()); + server = context.createServer(); + clientFactory = context.createClientFactory(); } @AfterClass @@ -118,6 +119,7 @@ public void releaseBuffers() { } private FetchResult fetchChunks(List chunkIndices) throws Exception { + System.out.println("----------------------------------------------------------------"); SluiceClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); final Semaphore sem = new Semaphore(0); @@ -170,6 +172,14 @@ public void fetchFileChunk() throws Exception { res.releaseBuffers(); } + @Test + public void fetchNonExistentChunk() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(12345)); + assertTrue(res.successChunks.isEmpty()); + assertEquals(res.failedChunks, Sets.newHashSet(12345)); + assertTrue(res.buffers.isEmpty()); + } + @Test public void fetchBothChunks() throws Exception { FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX)); @@ -179,14 +189,6 @@ public void fetchBothChunks() throws Exception { res.releaseBuffers(); } - @Test - public void fetchNonExistentChunk() throws Exception { - FetchResult res = fetchChunks(Lists.newArrayList(12345)); - assertTrue(res.successChunks.isEmpty()); - assertEquals(res.failedChunks, Sets.newHashSet(12345)); - assertTrue(res.buffers.isEmpty()); - } - @Test public void fetchChunkAndNonExistent() throws Exception { FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, 12345)); diff --git a/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java b/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java index af35709319957..ccfb7576afadb 100644 --- a/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java +++ b/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java @@ -17,10 +17,11 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.client.SluiceClient; public class NoOpRpcHandler implements RpcHandler { @Override - public void receive(byte[] message, RpcResponseCallback callback) { + public void receive(SluiceClient client, byte[] message, RpcResponseCallback callback) { callback.onSuccess(new byte[0]); } } diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java index cf74a9d8993fe..d2476e7f2ac22 100644 --- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -22,25 +22,22 @@ import static org.junit.Assert.assertEquals; +import org.apache.spark.network.protocol.Message; import org.apache.spark.network.protocol.StreamChunkId; import 
org.apache.spark.network.protocol.request.ChunkFetchRequest; -import org.apache.spark.network.protocol.request.ClientRequest; -import org.apache.spark.network.protocol.request.ClientRequestDecoder; -import org.apache.spark.network.protocol.request.ClientRequestEncoder; import org.apache.spark.network.protocol.response.ChunkFetchFailure; import org.apache.spark.network.protocol.response.ChunkFetchSuccess; -import org.apache.spark.network.protocol.response.ServerResponse; -import org.apache.spark.network.protocol.response.ServerResponseDecoder; -import org.apache.spark.network.protocol.response.ServerResponseEncoder; +import org.apache.spark.network.protocol.response.MessageDecoder; +import org.apache.spark.network.protocol.response.MessageEncoder; import org.apache.spark.network.util.NettyUtils; public class ProtocolSuite { - private void testServerToClient(ServerResponse msg) { - EmbeddedChannel serverChannel = new EmbeddedChannel(new ServerResponseEncoder()); + private void testServerToClient(Message msg) { + EmbeddedChannel serverChannel = new EmbeddedChannel(new MessageEncoder()); serverChannel.writeOutbound(msg); EmbeddedChannel clientChannel = new EmbeddedChannel( - NettyUtils.createFrameDecoder(), new ServerResponseDecoder()); + NettyUtils.createFrameDecoder(), new MessageDecoder()); while (!serverChannel.outboundMessages().isEmpty()) { clientChannel.writeInbound(serverChannel.readOutbound()); @@ -50,12 +47,12 @@ private void testServerToClient(ServerResponse msg) { assertEquals(msg, clientChannel.readInbound()); } - private void testClientToServer(ClientRequest msg) { - EmbeddedChannel clientChannel = new EmbeddedChannel(new ClientRequestEncoder()); + private void testClientToServer(Message msg) { + EmbeddedChannel clientChannel = new EmbeddedChannel(new MessageEncoder()); clientChannel.writeOutbound(msg); EmbeddedChannel serverChannel = new EmbeddedChannel( - NettyUtils.createFrameDecoder(), new ClientRequestDecoder()); + NettyUtils.createFrameDecoder(), new MessageDecoder()); while (!clientChannel.outboundMessages().isEmpty()) { serverChannel.writeInbound(clientChannel.readOutbound()); diff --git a/network/common/src/test/java/org/apache/spark/network/SluiceClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/SluiceClientFactorySuite.java index e6b59b9ad8e5c..219d6cc998bd7 100644 --- a/network/common/src/test/java/org/apache/spark/network/SluiceClientFactorySuite.java +++ b/network/common/src/test/java/org/apache/spark/network/SluiceClientFactorySuite.java @@ -38,6 +38,7 @@ public class SluiceClientFactorySuite { private SluiceConfig conf; + private SluiceContext context; private SluiceServer server1; private SluiceServer server2; @@ -46,8 +47,9 @@ public void setUp() { conf = new SluiceConfig(new DefaultConfigProvider()); StreamManager streamManager = new DefaultStreamManager(); RpcHandler rpcHandler = new NoOpRpcHandler(); - server1 = new SluiceServer(conf, streamManager, rpcHandler); - server2 = new SluiceServer(conf, streamManager, rpcHandler); + context = new SluiceContext(conf, streamManager, rpcHandler); + server1 = context.createServer(); + server2 = context.createServer(); } @After @@ -58,7 +60,7 @@ public void tearDown() { @Test public void createAndReuseBlockClients() throws TimeoutException { - SluiceClientFactory factory = new SluiceClientFactory(conf); + SluiceClientFactory factory = context.createClientFactory(); SluiceClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); SluiceClient c2 = 
factory.createClient(TestUtils.getLocalHost(), server1.getPort()); SluiceClient c3 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); @@ -71,7 +73,7 @@ public void createAndReuseBlockClients() throws TimeoutException { @Test public void neverReturnInactiveClients() throws Exception { - SluiceClientFactory factory = new SluiceClientFactory(conf); + SluiceClientFactory factory = context.createClientFactory(); SluiceClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); c1.close(); @@ -89,7 +91,7 @@ public void neverReturnInactiveClients() throws Exception { @Test public void closeBlockClientsWithFactory() throws TimeoutException { - SluiceClientFactory factory = new SluiceClientFactory(conf); + SluiceClientFactory factory = context.createClientFactory(); SluiceClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); SluiceClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); assertTrue(c1.isActive()); diff --git a/network/common/src/test/java/org/apache/spark/network/SluiceClientHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/SluiceClientHandlerSuite.java index cab0597fb948a..c665f2313c589 100644 --- a/network/common/src/test/java/org/apache/spark/network/SluiceClientHandlerSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/SluiceClientHandlerSuite.java @@ -18,17 +18,17 @@ package org.apache.spark.network; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.channel.local.LocalChannel; import org.junit.Test; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.*; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; -import org.apache.spark.network.client.SluiceClientHandler; +import org.apache.spark.network.client.SluiceResponseHandler; import org.apache.spark.network.protocol.StreamChunkId; import org.apache.spark.network.protocol.response.ChunkFetchFailure; import org.apache.spark.network.protocol.response.ChunkFetchSuccess; @@ -38,53 +38,45 @@ public class SluiceClientHandlerSuite { public void handleSuccessfulFetch() { StreamChunkId streamChunkId = new StreamChunkId(1, 0); - SluiceClientHandler handler = new SluiceClientHandler(); + SluiceResponseHandler handler = new SluiceResponseHandler(new LocalChannel()); ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); handler.addFetchRequest(streamChunkId, callback); assertEquals(1, handler.numOutstandingRequests()); - EmbeddedChannel channel = new EmbeddedChannel(handler); - - channel.writeInbound(new ChunkFetchSuccess(streamChunkId, new TestManagedBuffer(123))); + handler.handle(new ChunkFetchSuccess(streamChunkId, new TestManagedBuffer(123))); verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any()); assertEquals(0, handler.numOutstandingRequests()); - assertFalse(channel.finish()); } @Test public void handleFailedFetch() { StreamChunkId streamChunkId = new StreamChunkId(1, 0); - SluiceClientHandler handler = new SluiceClientHandler(); + SluiceResponseHandler handler = new SluiceResponseHandler(new LocalChannel()); ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); handler.addFetchRequest(streamChunkId, callback); assertEquals(1, handler.numOutstandingRequests()); - EmbeddedChannel channel = new EmbeddedChannel(handler); - 
channel.writeInbound(new ChunkFetchFailure(streamChunkId, "some error msg")); + handler.handle(new ChunkFetchFailure(streamChunkId, "some error msg")); verify(callback, times(1)).onFailure(eq(0), (Throwable) any()); assertEquals(0, handler.numOutstandingRequests()); - assertFalse(channel.finish()); } @Test public void clearAllOutstandingRequests() { - SluiceClientHandler handler = new SluiceClientHandler(); + SluiceResponseHandler handler = new SluiceResponseHandler(new LocalChannel()); ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); handler.addFetchRequest(new StreamChunkId(1, 0), callback); handler.addFetchRequest(new StreamChunkId(1, 1), callback); handler.addFetchRequest(new StreamChunkId(1, 2), callback); assertEquals(3, handler.numOutstandingRequests()); - EmbeddedChannel channel = new EmbeddedChannel(handler); - - channel.writeInbound(new ChunkFetchSuccess(new StreamChunkId(1, 0), new TestManagedBuffer(12))); - channel.pipeline().fireExceptionCaught(new Exception("duh duh duhhhh")); + handler.handle(new ChunkFetchSuccess(new StreamChunkId(1, 0), new TestManagedBuffer(12))); + handler.exceptionCaught(new Exception("duh duh duhhhh")); // should fail both b2 and b3 verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any()); verify(callback, times(1)).onFailure(eq(1), (Throwable) any()); verify(callback, times(1)).onFailure(eq(2), (Throwable) any()); assertEquals(0, handler.numOutstandingRequests()); - assertFalse(channel.finish()); } } From 7b7a26cf2c109373dc52dc9d33383da81184eaee Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Thu, 16 Oct 2014 20:20:05 -0700 Subject: [PATCH 36/46] Fix Nio compile issue --- .../org/apache/spark/network/nio/NioBlockTransferService.scala | 1 - .../org/apache/spark/network/client/SluiceClientFactory.java | 2 -- .../test/java/org/apache/spark/network/IntegrationSuite.java | 1 - .../org/apache/spark/streaming/scheduler/ReceiverTracker.scala | 2 +- 4 files changed, 1 insertion(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index aa27aaf5d8c91..1b76b938f14f5 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -157,7 +157,6 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa case e: Exception => logError("Exception handling buffer message", e) Some(Message.createErrorMessage(e, msg.id)) - } } case otherMessage: Any => diff --git a/network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java index 5de998ef6ed55..638034f93cd65 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java @@ -84,13 +84,11 @@ public SluiceClient createClient(String remoteHost, int remotePort) throws Timeo final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); SluiceClient cachedClient = connectionPool.get(address); if (cachedClient != null && cachedClient.isActive()) { - System.out.println("Reusing cached client: " + cachedClient); return cachedClient; } else if (cachedClient != null) { connectionPool.remove(address, cachedClient); // Remove inactive clients. 
} - System.out.println("Creating new client: " + cachedClient); logger.debug("Creating new connection to " + address); Bootstrap bootstrap = new Bootstrap(); diff --git a/network/common/src/test/java/org/apache/spark/network/IntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/IntegrationSuite.java index d38f6db99c09b..81f6afbc6fb52 100644 --- a/network/common/src/test/java/org/apache/spark/network/IntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/IntegrationSuite.java @@ -119,7 +119,6 @@ public void releaseBuffers() { } private FetchResult fetchChunks(List chunkIndices) throws Exception { - System.out.println("----------------------------------------------------------------"); SluiceClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); final Semaphore sem = new Semaphore(0); diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 190373e0cb5f2..7149dbc12a365 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -122,7 +122,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { sender: ActorRef ) { if (!receiverInputStreamMap.contains(streamId)) { - throw new Exception("Register received for unexpected type " + streamId) + throw new Exception("Register received for unexpected id " + streamId) } receiverInfo(streamId) = ReceiverInfo( streamId, s"${typ}-${streamId}", receiverActor, true, host) From d236dfdd2c6ac21d6eb1ce4e356be0f69aa6eb24 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Fri, 17 Oct 2014 00:10:51 -0700 Subject: [PATCH 37/46] Remove no-op serializer :) --- .../apache/spark/serializer/Serializer.scala | 45 ------------------- 1 file changed, 45 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index 4024dea31845c..ca6e971d227fb 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -142,48 +142,3 @@ abstract class DeserializationStream { } } } - - -class NoOpReadSerializer(conf: SparkConf) extends Serializer with Serializable { - override def newInstance(): SerializerInstance = { - new NoOpReadSerializerInstance() - } -} - -private[spark] class NoOpReadSerializerInstance() - extends SerializerInstance { - - override def serialize[T: ClassTag](t: T): ByteBuffer = { - val bos = new ByteArrayOutputStream() - val out = serializeStream(bos) - out.writeObject(t) - out.close() - ByteBuffer.wrap(bos.toByteArray) - } - - override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { - null.asInstanceOf[T] - } - - override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { - null.asInstanceOf[T] - } - - override def serializeStream(s: OutputStream): SerializationStream = { - new JavaSerializationStream(s, 100) - } - - override def deserializeStream(s: InputStream): DeserializationStream = { - new NoOpDeserializationStream(s, Utils.getContextOrSparkClassLoader) - } - - def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = { - new NoOpDeserializationStream(s, loader) - } -} - -private[spark] class NoOpDeserializationStream(in: InputStream, loader: ClassLoader) - 
extends DeserializationStream { - def readObject[T: ClassTag](): T = throw new EOFException() - def close() { } -} From 9da0bc11383587c21f6306bb0a2e9fb4b86fbb88 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Fri, 17 Oct 2014 10:53:49 -0700 Subject: [PATCH 38/46] Add RPC unit tests --- .../spark/network/client/SluiceClient.java | 2 +- .../network/client/SluiceResponseHandler.java | 3 +- .../protocol/response/ChunkFetchFailure.java | 2 +- .../network/protocol/response/RpcFailure.java | 2 +- ...e.java => ChunkFetchIntegrationSuite.java} | 2 +- .../apache/spark/network/NoOpRpcHandler.java | 1 + .../apache/spark/network/ProtocolSuite.java | 25 ++- .../spark/network/RpcIntegrationSuite.java | 176 ++++++++++++++++++ ...e.java => SluiceResponseHandlerSuite.java} | 36 +++- 9 files changed, 233 insertions(+), 16 deletions(-) rename network/common/src/test/java/org/apache/spark/network/{IntegrationSuite.java => ChunkFetchIntegrationSuite.java} (99%) create mode 100644 network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java rename network/common/src/test/java/org/apache/spark/network/{SluiceClientHandlerSuite.java => SluiceResponseHandlerSuite.java} (71%) diff --git a/network/common/src/main/java/org/apache/spark/network/client/SluiceClient.java b/network/common/src/main/java/org/apache/spark/network/client/SluiceClient.java index d6d97981eebd6..88bf365cc5e2d 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/SluiceClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/SluiceClient.java @@ -71,7 +71,7 @@ public SluiceClient(Channel channel, SluiceResponseHandler handler) { } public boolean isActive() { - return channel.isOpen() || channel.isRegistered() || channel.isActive(); + return channel.isOpen() || channel.isActive(); } /** diff --git a/network/common/src/main/java/org/apache/spark/network/client/SluiceResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/SluiceResponseHandler.java index 9fbd487da86a7..83ee1b5ef8102 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/SluiceResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/SluiceResponseHandler.java @@ -151,8 +151,9 @@ public void handle(ResponseMessage message) { } } + /** Returns total number of outstanding requests (fetch requests + rpcs) */ @VisibleForTesting public int numOutstandingRequests() { - return outstandingFetches.size(); + return outstandingFetches.size() + outstandingRpcs.size(); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchFailure.java index 18ed4d95bba4c..cb3cbcd0a53ca 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchFailure.java @@ -57,7 +57,7 @@ public static ChunkFetchFailure decode(ByteBuf buf) { int numErrorStringBytes = buf.readInt(); byte[] errorBytes = new byte[numErrorStringBytes]; buf.readBytes(errorBytes); - return new ChunkFetchFailure(streamChunkId, new String(errorBytes)); + return new ChunkFetchFailure(streamChunkId, new String(errorBytes, Charsets.UTF_8)); } @Override diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcFailure.java 
index 6b71da5708c58..1f161f7957543 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcFailure.java @@ -52,7 +52,7 @@ public static RpcFailure decode(ByteBuf buf) { int numErrorStringBytes = buf.readInt(); byte[] errorBytes = new byte[numErrorStringBytes]; buf.readBytes(errorBytes); - return new RpcFailure(tag, new String(errorBytes)); + return new RpcFailure(tag, new String(errorBytes, Charsets.UTF_8)); } @Override diff --git a/network/common/src/test/java/org/apache/spark/network/IntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java similarity index 99% rename from network/common/src/test/java/org/apache/spark/network/IntegrationSuite.java rename to network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index 81f6afbc6fb52..f7f53e2df4a49 100644 --- a/network/common/src/test/java/org/apache/spark/network/IntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -48,7 +48,7 @@ import org.apache.spark.network.util.DefaultConfigProvider; import org.apache.spark.network.util.SluiceConfig; -public class IntegrationSuite { +public class ChunkFetchIntegrationSuite { static final long STREAM_ID = 1; static final int BUFFER_CHUNK_INDEX = 0; static final int FILE_CHUNK_INDEX = 1; diff --git a/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java b/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java index ccfb7576afadb..e7bad051c6200 100644 --- a/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java +++ b/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java @@ -19,6 +19,7 @@ import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.client.SluiceClient; +/** Test RpcHandler which always returns a zero-sized success. 
*/ public class NoOpRpcHandler implements RpcHandler { @Override public void receive(SluiceClient client, byte[] message, RpcResponseCallback callback) { diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java index d2476e7f2ac22..9f20496f75f82 100644 --- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -25,10 +25,13 @@ import org.apache.spark.network.protocol.Message; import org.apache.spark.network.protocol.StreamChunkId; import org.apache.spark.network.protocol.request.ChunkFetchRequest; +import org.apache.spark.network.protocol.request.RpcRequest; import org.apache.spark.network.protocol.response.ChunkFetchFailure; import org.apache.spark.network.protocol.response.ChunkFetchSuccess; import org.apache.spark.network.protocol.response.MessageDecoder; import org.apache.spark.network.protocol.response.MessageEncoder; +import org.apache.spark.network.protocol.response.RpcFailure; +import org.apache.spark.network.protocol.response.RpcResponse; import org.apache.spark.network.util.NettyUtils; public class ProtocolSuite { @@ -63,19 +66,21 @@ private void testClientToServer(Message msg) { } @Test - public void s2cChunkFetchSuccess() { - testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(10))); - testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0))); + public void requests() { + testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2))); + testClientToServer(new RpcRequest(12345, new byte[0])); + testClientToServer(new RpcRequest(12345, new byte[100])); } @Test - public void s2cBlockFetchFailure() { + public void responses() { + testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(10))); + testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0))); testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "this is an error")); testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "")); + testServerToClient(new RpcResponse(12345, new byte[0])); + testServerToClient(new RpcResponse(12345, new byte[1000])); + testServerToClient(new RpcFailure(0, "this is an error")); + testServerToClient(new RpcFailure(0, "")); } - - @Test - public void c2sChunkFetchRequest() { - testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2))); - } -} +} \ No newline at end of file diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java new file mode 100644 index 0000000000000..a909e4032d608 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import java.util.Collections; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Set; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; + +import com.google.common.base.Charsets; +import com.google.common.collect.Sets; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.SluiceClient; +import org.apache.spark.network.client.SluiceClientFactory; +import org.apache.spark.network.server.DefaultStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.SluiceServer; +import org.apache.spark.network.util.DefaultConfigProvider; +import org.apache.spark.network.util.SluiceConfig; + +public class RpcIntegrationSuite { + static SluiceServer server; + static SluiceClientFactory clientFactory; + static RpcHandler rpcHandler; + + @BeforeClass + public static void setUp() throws Exception { + SluiceConfig conf = new SluiceConfig(new DefaultConfigProvider()); + rpcHandler = new RpcHandler() { + @Override + public void receive(SluiceClient client, byte[] message, RpcResponseCallback callback) { + String msg = new String(message, Charsets.UTF_8); + String[] parts = msg.split("/"); + if (parts[0].equals("hello")) { + callback.onSuccess(("Hello, " + parts[1] + "!").getBytes(Charsets.UTF_8)); + } else if (parts[0].equals("return error")) { + callback.onFailure(new RuntimeException("Returned: " + parts[1])); + } else if (parts[0].equals("throw error")) { + throw new RuntimeException("Thrown: " + parts[1]); + } + } + }; + SluiceContext context = new SluiceContext(conf, new DefaultStreamManager(), rpcHandler); + server = context.createServer(); + clientFactory = context.createClientFactory(); + } + + @AfterClass + public static void tearDown() { + server.close(); + clientFactory.close(); + } + + class RpcResult { + public Set successMessages; + public Set errorMessages; + } + + private RpcResult sendRPC(String ... 
commands) throws Exception { + SluiceClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + final Semaphore sem = new Semaphore(0); + + final RpcResult res = new RpcResult(); + res.successMessages = Collections.synchronizedSet(new HashSet()); + res.errorMessages = Collections.synchronizedSet(new HashSet()); + + RpcResponseCallback callback = new RpcResponseCallback() { + @Override + public void onSuccess(byte[] message) { + res.successMessages.add(new String(message, Charsets.UTF_8)); + sem.release(); + } + + @Override + public void onFailure(Throwable e) { + res.errorMessages.add(e.getMessage()); + sem.release(); + } + }; + + for (String command : commands) { + client.sendRpc(command.getBytes(Charsets.UTF_8), callback); + } + + if (!sem.tryAcquire(commands.length, 5, TimeUnit.SECONDS)) { + fail("Timeout getting response from the server"); + } + client.close(); + return res; + } + + @Test + public void singleRPC() throws Exception { + RpcResult res = sendRPC("hello/Aaron"); + assertEquals(res.successMessages, Sets.newHashSet("Hello, Aaron!")); + assertTrue(res.errorMessages.isEmpty()); + } + + @Test + public void doubleRPC() throws Exception { + RpcResult res = sendRPC("hello/Aaron", "hello/Reynold"); + assertEquals(res.successMessages, Sets.newHashSet("Hello, Aaron!", "Hello, Reynold!")); + assertTrue(res.errorMessages.isEmpty()); + } + + @Test + public void returnErrorRPC() throws Exception { + RpcResult res = sendRPC("return error/OK"); + assertTrue(res.successMessages.isEmpty()); + assertErrorsContain(res.errorMessages, Sets.newHashSet("Returned: OK")); + } + + @Test + public void throwErrorRPC() throws Exception { + RpcResult res = sendRPC("throw error/uh-oh"); + assertTrue(res.successMessages.isEmpty()); + assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: uh-oh")); + } + + @Test + public void doubleTrouble() throws Exception { + RpcResult res = sendRPC("return error/OK", "throw error/uh-oh"); + assertTrue(res.successMessages.isEmpty()); + assertErrorsContain(res.errorMessages, Sets.newHashSet("Returned: OK", "Thrown: uh-oh")); + } + + @Test + public void sendSuccessAndFailure() throws Exception { + RpcResult res = sendRPC("hello/Bob", "throw error/the", "hello/Builder", "return error/!"); + assertEquals(res.successMessages, Sets.newHashSet("Hello, Bob!", "Hello, Builder!")); + assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: the", "Returned: !")); + } + + private void assertErrorsContain(Set errors, Set contains) { + assertEquals(contains.size(), errors.size()); + + Set remainingErrors = Sets.newHashSet(errors); + for (String contain : contains) { + Iterator it = remainingErrors.iterator(); + boolean foundMatch = false; + while (it.hasNext()) { + if (it.next().contains(contain)) { + it.remove(); + foundMatch = true; + break; + } + } + assertTrue("Could not find error containing " + contain + "; errors: " + errors, foundMatch); + } + + assertTrue(remainingErrors.isEmpty()); + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/SluiceClientHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/SluiceResponseHandlerSuite.java similarity index 71% rename from network/common/src/test/java/org/apache/spark/network/SluiceClientHandlerSuite.java rename to network/common/src/test/java/org/apache/spark/network/SluiceResponseHandlerSuite.java index c665f2313c589..3138c5d21a85f 100644 --- a/network/common/src/test/java/org/apache/spark/network/SluiceClientHandlerSuite.java +++ 
b/network/common/src/test/java/org/apache/spark/network/SluiceResponseHandlerSuite.java @@ -28,12 +28,15 @@ import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.SluiceResponseHandler; import org.apache.spark.network.protocol.StreamChunkId; import org.apache.spark.network.protocol.response.ChunkFetchFailure; import org.apache.spark.network.protocol.response.ChunkFetchSuccess; +import org.apache.spark.network.protocol.response.RpcFailure; +import org.apache.spark.network.protocol.response.RpcResponse; -public class SluiceClientHandlerSuite { +public class SluiceResponseHandlerSuite { @Test public void handleSuccessfulFetch() { StreamChunkId streamChunkId = new StreamChunkId(1, 0); @@ -79,4 +82,35 @@ public void clearAllOutstandingRequests() { verify(callback, times(1)).onFailure(eq(2), (Throwable) any()); assertEquals(0, handler.numOutstandingRequests()); } + + @Test + public void handleSuccessfulRPC() { + SluiceResponseHandler handler = new SluiceResponseHandler(new LocalChannel()); + RpcResponseCallback callback = mock(RpcResponseCallback.class); + handler.addRpcRequest(12345, callback); + assertEquals(1, handler.numOutstandingRequests()); + + handler.handle(new RpcResponse(54321, new byte[7])); // should be ignored + assertEquals(1, handler.numOutstandingRequests()); + + byte[] arr = new byte[10]; + handler.handle(new RpcResponse(12345, arr)); + verify(callback, times(1)).onSuccess(eq(arr)); + assertEquals(0, handler.numOutstandingRequests()); + } + + @Test + public void handleFailedRPC() { + SluiceResponseHandler handler = new SluiceResponseHandler(new LocalChannel()); + RpcResponseCallback callback = mock(RpcResponseCallback.class); + handler.addRpcRequest(12345, callback); + assertEquals(1, handler.numOutstandingRequests()); + + handler.handle(new RpcFailure(54321, "uh-oh!")); // should be ignored + assertEquals(1, handler.numOutstandingRequests()); + + handler.handle(new RpcFailure(12345, "oh no")); + verify(callback, times(1)).onFailure((Throwable) any()); + assertEquals(0, handler.numOutstandingRequests()); + } } From ccd49595e8d0a730489e577b1152ad67027a5687 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Fri, 17 Oct 2014 15:23:15 -0700 Subject: [PATCH 39/46] Don't throw exception if client immediately fails This seems to cause an exception in `DistributedSuite#recover from repeated node failures during shuffle-reduce`. I guess that the exception caused by having an invalid client is handled in a different way than the client creation throwing an exception. --- .../org/apache/spark/network/client/SluiceClientFactory.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java index 638034f93cd65..4c345fb4eb92d 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java @@ -129,8 +129,6 @@ public void initChannel(SocketChannel ch) { // in the "inactive" state, and someone else simultaneously tried to create another client to // the same server. This is an error condition, as the first client failed to connect. throw new IllegalStateException("Client was unset! 
Must have been immediately inactive."); - } else if (!client.isActive()) { - throw new IllegalStateException("Failed to create active client."); } return client; } From e5675a4b919452c1eb68dd3a885f5d7f747e6014 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Fri, 17 Oct 2014 18:09:04 -0700 Subject: [PATCH 40/46] Fail outstanding RPCs as well --- .../network/client/SluiceResponseHandler.java | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/network/common/src/main/java/org/apache/spark/network/client/SluiceResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/SluiceResponseHandler.java index 83ee1b5ef8102..e6ebe7159f1bd 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/SluiceResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/SluiceResponseHandler.java @@ -79,23 +79,31 @@ private void failOutstandingRequests(Throwable cause) { for (Map.Entry entry : outstandingFetches.entrySet()) { entry.getValue().onFailure(entry.getKey().chunkIndex, cause); } + for (Map.Entry entry : outstandingRpcs.entrySet()) { + entry.getValue().onFailure(cause); + } + // It's OK if new fetches appear, as they will fail immediately. outstandingFetches.clear(); + outstandingRpcs.clear(); } @Override public void channelUnregistered() { - if (outstandingFetches.size() > 0) { + if (numOutstandingRequests() > 0) { String remoteAddress = NettyUtils.getRemoteAddress(channel); - logger.error("Still have {} requests outstanding when contention from {} is closed", - outstandingFetches.size(), remoteAddress); + logger.error("Still have {} requests outstanding when connection from {} is closed", + numOutstandingRequests(), remoteAddress); failOutstandingRequests(new RuntimeException("Connection from " + remoteAddress + " closed")); } } @Override public void exceptionCaught(Throwable cause) { - if (outstandingFetches.size() > 0) { + if (numOutstandingRequests() > 0) { + String remoteAddress = NettyUtils.getRemoteAddress(channel); + logger.error("Still have {} requests outstanding when connection from {} is closed", + numOutstandingRequests(), remoteAddress); failOutstandingRequests(cause); } } From 322dfc1ac158bef1e3f23d85333233dc5c62b8f3 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Sun, 26 Oct 2014 22:37:17 -0700 Subject: [PATCH 41/46] Address Reynold's comments, including major rename --- .../network/netty/NettyBlockFetcher.scala | 55 ++++++++-------- .../network/netty/NettyBlockRpcServer.scala | 4 +- .../netty/NettyBlockTransferService.scala | 22 +++---- ...uiceContext.java => TransportContext.java} | 65 ++++++++++--------- ...SluiceClient.java => TransportClient.java} | 26 ++++---- ...ctory.java => TransportClientFactory.java} | 40 ++++++------ ...ler.java => TransportResponseHandler.java} | 10 +-- .../spark/network/server/RpcHandler.java | 6 +- .../spark/network/server/StreamManager.java | 10 +-- ...ndler.java => TransportClientHandler.java} | 37 ++++++----- ...dler.java => TransportRequestHandler.java} | 20 +++--- ...SluiceServer.java => TransportServer.java} | 14 ++-- .../spark/network/util/ConfigProvider.java | 2 +- .../org/apache/spark/network/util/IOMode.java | 2 +- .../apache/spark/network/util/NettyUtils.java | 14 ---- ...java => SystemPropertyConfigProvider.java} | 2 +- .../{SluiceConfig.java => TransportConf.java} | 6 +- .../network/ChunkFetchIntegrationSuite.java | 20 +++--- .../apache/spark/network/NoOpRpcHandler.java | 4 +- .../apache/spark/network/ProtocolSuite.java | 2 +- 
.../spark/network/RpcIntegrationSuite.java | 22 +++---- ....java => TransportClientFactorySuite.java} | 44 ++++++------- ...ava => TransportResponseHandlerSuite.java} | 15 ++--- 23 files changed, 215 insertions(+), 227 deletions(-) rename network/common/src/main/java/org/apache/spark/network/{SluiceContext.java => TransportContext.java} (54%) rename network/common/src/main/java/org/apache/spark/network/client/{SluiceClient.java => TransportClient.java} (88%) rename network/common/src/main/java/org/apache/spark/network/client/{SluiceClientFactory.java => TransportClientFactory.java} (81%) rename network/common/src/main/java/org/apache/spark/network/client/{SluiceResponseHandler.java => TransportResponseHandler.java} (95%) rename network/common/src/main/java/org/apache/spark/network/server/{SluiceChannelHandler.java => TransportClientHandler.java} (67%) rename network/common/src/main/java/org/apache/spark/network/server/{SluiceRequestHandler.java => TransportRequestHandler.java} (92%) rename network/common/src/main/java/org/apache/spark/network/server/{SluiceServer.java => TransportServer.java} (90%) rename network/common/src/main/java/org/apache/spark/network/util/{DefaultConfigProvider.java => SystemPropertyConfigProvider.java} (94%) rename network/common/src/main/java/org/apache/spark/network/util/{SluiceConfig.java => TransportConf.java} (94%) rename network/common/src/test/java/org/apache/spark/network/{SluiceClientFactorySuite.java => TransportClientFactorySuite.java} (62%) rename network/common/src/test/java/org/apache/spark/network/{SluiceResponseHandlerSuite.java => TransportResponseHandlerSuite.java} (88%) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala index a03e7c39428ee..344d17e7bf661 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala @@ -25,7 +25,7 @@ import org.apache.spark.network.BlockFetchingListener import org.apache.spark.network.netty.NettyMessages._ import org.apache.spark.serializer.{JavaSerializer, Serializer} import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.network.client.{RpcResponseCallback, ChunkReceivedCallback, SluiceClient} +import org.apache.spark.network.client.{RpcResponseCallback, ChunkReceivedCallback, TransportClient} import org.apache.spark.storage.BlockId import org.apache.spark.util.Utils @@ -39,18 +39,18 @@ import org.apache.spark.util.Utils */ class NettyBlockFetcher( serializer: Serializer, - client: SluiceClient, + client: TransportClient, blockIds: Seq[String], listener: BlockFetchingListener) extends Logging { require(blockIds.nonEmpty) - val ser = serializer.newInstance() + private val ser = serializer.newInstance() - var streamHandle: ShuffleStreamHandle = _ + private var streamHandle: ShuffleStreamHandle = _ - val chunkCallback = new ChunkReceivedCallback { + private val chunkCallback = new ChunkReceivedCallback { // On receipt of a chunk, pass it upwards as a block. def onSuccess(chunkIndex: Int, buffer: ManagedBuffer): Unit = Utils.logUncaughtExceptions { listener.onBlockFetchSuccess(blockIds(chunkIndex), buffer) @@ -64,29 +64,32 @@ class NettyBlockFetcher( } } - // Send the RPC to open the given set of blocks. This will return a ShuffleStreamHandle. 
- client.sendRpc(ser.serialize(OpenBlocks(blockIds.map(BlockId.apply))).array(), - new RpcResponseCallback { - override def onSuccess(response: Array[Byte]): Unit = { - try { - streamHandle = ser.deserialize[ShuffleStreamHandle](ByteBuffer.wrap(response)) - logTrace(s"Successfully opened block set: $streamHandle! Preparing to fetch chunks.") + /** Begins the fetching process, calling the listener with every block fetched. */ + def start(): Unit = { + // Send the RPC to open the given set of blocks. This will return a ShuffleStreamHandle. + client.sendRpc(ser.serialize(OpenBlocks(blockIds.map(BlockId.apply))).array(), + new RpcResponseCallback { + override def onSuccess(response: Array[Byte]): Unit = { + try { + streamHandle = ser.deserialize[ShuffleStreamHandle](ByteBuffer.wrap(response)) + logTrace(s"Successfully opened block set: $streamHandle! Preparing to fetch chunks.") - // Immediately request all chunks -- we expect that the total size of the request is - // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]]. - for (i <- 0 until streamHandle.numChunks) { - client.fetchChunk(streamHandle.streamId, i, chunkCallback) + // Immediately request all chunks -- we expect that the total size of the request is + // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]]. + for (i <- 0 until streamHandle.numChunks) { + client.fetchChunk(streamHandle.streamId, i, chunkCallback) + } + } catch { + case e: Exception => + logError("Failed while starting block fetches", e) + blockIds.foreach(blockId => Utils.tryLog(listener.onBlockFetchFailure(blockId, e))) } - } catch { - case e: Exception => - logError("Failed while starting block fetches", e) - blockIds.foreach(listener.onBlockFetchFailure(_, e)) } - } - override def onFailure(e: Throwable): Unit = { - logError("Failed while starting block fetches") - blockIds.foreach(listener.onBlockFetchFailure(_, e)) - } - }) + override def onFailure(e: Throwable): Unit = { + logError("Failed while starting block fetches") + blockIds.foreach(blockId => Utils.tryLog(listener.onBlockFetchFailure(blockId, e))) + } + }) + } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 9206237256e0b..02c657e1d61b5 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -23,7 +23,7 @@ import org.apache.spark.Logging import org.apache.spark.network.BlockDataManager import org.apache.spark.serializer.Serializer import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer} -import org.apache.spark.network.client.{SluiceClient, RpcResponseCallback} +import org.apache.spark.network.client.{TransportClient, RpcResponseCallback} import org.apache.spark.network.server.{DefaultStreamManager, RpcHandler} import org.apache.spark.storage.{StorageLevel, BlockId} @@ -53,7 +53,7 @@ class NettyBlockRpcServer( import NettyMessages._ override def receive( - client: SluiceClient, + client: TransportClient, messageBytes: Array[Byte], responseContext: RpcResponseCallback): Unit = { val ser = serializer.newInstance() diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 6145c86c65617..501a2d123d456 100644 --- 
a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -22,10 +22,10 @@ import scala.concurrent.{Promise, Future} import org.apache.spark.SparkConf import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.network.client.{RpcResponseCallback, SluiceClient, SluiceClientFactory} +import org.apache.spark.network.client.{RpcResponseCallback, TransportClient, TransportClientFactory} import org.apache.spark.network.netty.NettyMessages.UploadBlock import org.apache.spark.network.server._ -import org.apache.spark.network.util.{ConfigProvider, SluiceConfig} +import org.apache.spark.network.util.{ConfigProvider, TransportConf} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.Utils @@ -37,20 +37,20 @@ class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService { // TODO: Don't use Java serialization, use a more cross-version compatible serialization format. val serializer = new JavaSerializer(conf) - // Create a SluiceConfig using SparkConf. - private[this] val sluiceConf = new SluiceConfig( + // Create a TransportConfig using SparkConf. + private[this] val transportConf = new TransportConf( new ConfigProvider { override def get(name: String) = conf.get(name) }) - private[this] var sluiceContext: SluiceContext = _ - private[this] var server: SluiceServer = _ - private[this] var clientFactory: SluiceClientFactory = _ + private[this] var transportContext: TransportContext = _ + private[this] var server: TransportServer = _ + private[this] var clientFactory: TransportClientFactory = _ override def init(blockDataManager: BlockDataManager): Unit = { val streamManager = new DefaultStreamManager val rpcHandler = new NettyBlockRpcServer(serializer, streamManager, blockDataManager) - sluiceContext = new SluiceContext(sluiceConf, streamManager, rpcHandler) - clientFactory = sluiceContext.createClientFactory() - server = sluiceContext.createServer() + transportContext = new TransportContext(transportConf, streamManager, rpcHandler) + clientFactory = transportContext.createClientFactory() + server = transportContext.createServer() } override def fetchBlocks( @@ -59,7 +59,7 @@ class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService { blockIds: Seq[String], listener: BlockFetchingListener): Unit = { val client = clientFactory.createClient(hostname, port) - new NettyBlockFetcher(serializer, client, blockIds, listener) + new NettyBlockFetcher(serializer, client, blockIds, listener).start() } override def hostName: String = Utils.localHostName() diff --git a/network/common/src/main/java/org/apache/spark/network/SluiceContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java similarity index 54% rename from network/common/src/main/java/org/apache/spark/network/SluiceContext.java rename to network/common/src/main/java/org/apache/spark/network/TransportContext.java index 7845ceb8b7d06..da0decac7e064 100644 --- a/network/common/src/main/java/org/apache/spark/network/SluiceContext.java +++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java @@ -22,38 +22,38 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.spark.network.client.SluiceClient; -import org.apache.spark.network.client.SluiceClientFactory; -import 
org.apache.spark.network.client.SluiceResponseHandler; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.client.TransportResponseHandler; import org.apache.spark.network.protocol.response.MessageDecoder; import org.apache.spark.network.protocol.response.MessageEncoder; import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.SluiceChannelHandler; -import org.apache.spark.network.server.SluiceRequestHandler; -import org.apache.spark.network.server.SluiceServer; +import org.apache.spark.network.server.TransportClientHandler; +import org.apache.spark.network.server.TransportRequestHandler; +import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.util.NettyUtils; -import org.apache.spark.network.util.SluiceConfig; +import org.apache.spark.network.util.TransportConf; /** - * Contains the context to create a {@link SluiceServer}, {@link SluiceClientFactory}, and to setup - * Netty Channel pipelines with a {@link SluiceChannelHandler}. + * Contains the context to create a {@link TransportServer}, {@link TransportClientFactory}, and to + * setup Netty Channel pipelines with a {@link TransportClientHandler}. * - * The SluiceServer and SluiceClientFactory both create a SluiceChannelHandler for each channel. - * As each SluiceChannelHandler contains a SluiceClient, this enables server processes to send - * messages back to the client on an existing channel. + * The TransportServer and TransportClientFactory both create a TransportChannelHandler for each + * channel. As each TransportChannelHandler contains a TransportClient, this enables server + * processes to send messages back to the client on an existing channel. */ -public class SluiceContext { - private final Logger logger = LoggerFactory.getLogger(SluiceContext.class); +public class TransportContext { + private final Logger logger = LoggerFactory.getLogger(TransportContext.class); - private final SluiceConfig conf; + private final TransportConf conf; private final StreamManager streamManager; private final RpcHandler rpcHandler; private final MessageEncoder encoder; private final MessageDecoder decoder; - public SluiceContext(SluiceConfig conf, StreamManager streamManager, RpcHandler rpcHandler) { + public TransportContext(TransportConf conf, StreamManager streamManager, RpcHandler rpcHandler) { this.conf = conf; this.streamManager = streamManager; this.rpcHandler = rpcHandler; @@ -61,25 +61,26 @@ public SluiceContext(SluiceConfig conf, StreamManager streamManager, RpcHandler this.decoder = new MessageDecoder(); } - public SluiceClientFactory createClientFactory() { - return new SluiceClientFactory(this); + public TransportClientFactory createClientFactory() { + return new TransportClientFactory(this); } - public SluiceServer createServer() { - return new SluiceServer(this); + public TransportServer createServer() { + return new TransportServer(this); } /** * Initializes a client or server Netty Channel Pipeline which encodes/decodes messages and - * has a {@link SluiceChannelHandler} to handle request or response messages. + * has a {@link org.apache.spark.network.server.TransportClientHandler} to handle request or + * response messages. * - * @return Returns the created SluiceChannelHandler, which includes a SluiceClient that can be - * used to communicate on this channel. 
The SluiceClient is directly associated with a - * ChannelHandler to ensure all users of the same channel get the same SluiceClient object. + * @return Returns the created TransportChannelHandler, which includes a TransportClient that can + * be used to communicate on this channel. The TransportClient is directly associated with a + * ChannelHandler to ensure all users of the same channel get the same TransportClient object. */ - public SluiceChannelHandler initializePipeline(SocketChannel channel) { + public TransportClientHandler initializePipeline(SocketChannel channel) { try { - SluiceChannelHandler channelHandler = createChannelHandler(channel); + TransportClientHandler channelHandler = createChannelHandler(channel); channel.pipeline() .addLast("encoder", encoder) .addLast("frameDecoder", NettyUtils.createFrameDecoder()) @@ -99,13 +100,13 @@ public SluiceChannelHandler initializePipeline(SocketChannel channel) { * ResponseMessages. The channel is expected to have been successfully created, though certain * properties (such as the remoteAddress()) may not be available yet. */ - private SluiceChannelHandler createChannelHandler(Channel channel) { - SluiceResponseHandler responseHandler = new SluiceResponseHandler(channel); - SluiceClient client = new SluiceClient(channel, responseHandler); - SluiceRequestHandler requestHandler = new SluiceRequestHandler(channel, client, streamManager, + private TransportClientHandler createChannelHandler(Channel channel) { + TransportResponseHandler responseHandler = new TransportResponseHandler(channel); + TransportClient client = new TransportClient(channel, responseHandler); + TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client, streamManager, rpcHandler); - return new SluiceChannelHandler(client, responseHandler, requestHandler); + return new TransportClientHandler(client, responseHandler, requestHandler); } - public SluiceConfig getConf() { return conf; } + public TransportConf getConf() { return conf; } } diff --git a/network/common/src/main/java/org/apache/spark/network/client/SluiceClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java similarity index 88% rename from network/common/src/main/java/org/apache/spark/network/client/SluiceClient.java rename to network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index 88bf365cc5e2d..75e26cb7e60c1 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/SluiceClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -39,9 +39,9 @@ * hundreds of KB to a few MB. * * Note that while this client deals with the fetching of chunks from a stream (i.e., data plane), - * the actual setup of the streams is done outside the scope of Sluice. The convenience method - * "sendRPC" is provided to enable control plane communication between the client and server to - * perform this setup. + * the actual setup of the streams is done outside the scope of the transport layer. The convenience + * method "sendRPC" is provided to enable control plane communication between the client and server + * to perform this setup. * * For example, a typical workflow might be: * client.sendRPC(new OpenFile("/foo")) --> returns StreamId = 100 @@ -50,22 +50,22 @@ * ... * client.sendRPC(new CloseStream(100)) * - * Construct an instance of SluiceClient using {@link SluiceClientFactory}. 
A single SluiceClient - * may be used for multiple streams, but any given stream must be restricted to a single client, - * in order to avoid out-of-order responses. + * Construct an instance of TransportClient using {@link TransportClientFactory}. A single + * TransportClient may be used for multiple streams, but any given stream must be restricted to a + * single client, in order to avoid out-of-order responses. * - * NB: This class is used to make requests to the server, while {@link SluiceResponseHandler} is + * NB: This class is used to make requests to the server, while {@link TransportResponseHandler} is * responsible for handling responses from the server. * * Concurrency: thread safe and can be called from multiple threads. */ -public class SluiceClient implements Closeable { - private final Logger logger = LoggerFactory.getLogger(SluiceClient.class); +public class TransportClient implements Closeable { + private final Logger logger = LoggerFactory.getLogger(TransportClient.class); private final Channel channel; - private final SluiceResponseHandler handler; + private final TransportResponseHandler handler; - public SluiceClient(Channel channel, SluiceResponseHandler handler) { + public TransportClient(Channel channel, TransportResponseHandler handler) { this.channel = Preconditions.checkNotNull(channel); this.handler = Preconditions.checkNotNull(handler); } @@ -81,8 +81,8 @@ public boolean isActive() { * some streams may not support this. * * Multiple fetchChunk requests may be outstanding simultaneously, and the chunks are guaranteed - * to be returned in the same order that they were requested, assuming only a single SluiceClient - * is used to fetch the chunks. + * to be returned in the same order that they were requested, assuming only a single + * TransportClient is used to fetch the chunks. * * @param streamId Identifier that refers to a stream in the remote StreamManager. This should * be agreed upon by client and server beforehand. diff --git a/network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java similarity index 81% rename from network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java rename to network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 4c345fb4eb92d..c351858bfe30d 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -36,35 +36,33 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.spark.network.SluiceContext; -import org.apache.spark.network.protocol.response.MessageDecoder; -import org.apache.spark.network.protocol.response.MessageEncoder; -import org.apache.spark.network.server.SluiceChannelHandler; +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.server.TransportClientHandler; import org.apache.spark.network.util.IOMode; import org.apache.spark.network.util.NettyUtils; -import org.apache.spark.network.util.SluiceConfig; +import org.apache.spark.network.util.TransportConf; /** - * Factory for creating {@link SluiceClient}s by using createClient. + * Factory for creating {@link TransportClient}s by using createClient. * * The factory maintains a connection pool to other hosts and should return the same - * {@link SluiceClient} for the same remote host. 
It also shares a single worker thread pool for - * all {@link SluiceClient}s. + * {@link TransportClient} for the same remote host. It also shares a single worker thread pool for + * all {@link TransportClient}s. */ -public class SluiceClientFactory implements Closeable { - private final Logger logger = LoggerFactory.getLogger(SluiceClientFactory.class); +public class TransportClientFactory implements Closeable { + private final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class); - private final SluiceContext context; - private final SluiceConfig conf; - private final ConcurrentHashMap connectionPool; + private final TransportContext context; + private final TransportConf conf; + private final ConcurrentHashMap connectionPool; private final Class socketChannelClass; private final EventLoopGroup workerGroup; - public SluiceClientFactory(SluiceContext context) { + public TransportClientFactory(TransportContext context) { this.context = context; this.conf = context.getConf(); - this.connectionPool = new ConcurrentHashMap(); + this.connectionPool = new ConcurrentHashMap(); IOMode ioMode = IOMode.valueOf(conf.ioMode()); this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode); @@ -78,11 +76,11 @@ public SluiceClientFactory(SluiceContext context) { * * Concurrency: This method is safe to call from multiple threads. */ - public SluiceClient createClient(String remoteHost, int remotePort) throws TimeoutException { + public TransportClient createClient(String remoteHost, int remotePort) throws TimeoutException { // Get connection from the connection pool first. // If it is not found or not active, create a new one. final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); - SluiceClient cachedClient = connectionPool.get(address); + TransportClient cachedClient = connectionPool.get(address); if (cachedClient != null && cachedClient.isActive()) { return cachedClient; } else if (cachedClient != null) { @@ -105,8 +103,8 @@ public SluiceClient createClient(String remoteHost, int remotePort) throws Timeo bootstrap.handler(new ChannelInitializer() { @Override public void initChannel(SocketChannel ch) { - SluiceChannelHandler channelHandler = context.initializePipeline(ch); - SluiceClient oldClient = connectionPool.putIfAbsent(address, channelHandler.getClient()); + TransportClientHandler channelHandler = context.initializePipeline(ch); + TransportClient oldClient = connectionPool.putIfAbsent(address, channelHandler.getClient()); if (oldClient != null) { logger.debug("Two clients were created concurrently, second one will be disposed."); ch.close(); @@ -123,7 +121,7 @@ public void initChannel(SocketChannel ch) { String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs())); } - SluiceClient client = connectionPool.get(address); + TransportClient client = connectionPool.get(address); if (client == null) { // The only way we should be able to reach here is if the client we created started out // in the "inactive" state, and someone else simultaneously tried to create another client to @@ -136,7 +134,7 @@ public void initChannel(SocketChannel ch) { /** Close all connections in the connection pool, and shutdown the worker thread pool. 
*/ @Override public void close() { - for (SluiceClient client : connectionPool.values()) { + for (TransportClient client : connectionPool.values()) { client.close(); } connectionPool.clear(); diff --git a/network/common/src/main/java/org/apache/spark/network/client/SluiceResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java similarity index 95% rename from network/common/src/main/java/org/apache/spark/network/client/SluiceResponseHandler.java rename to network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index e6ebe7159f1bd..187b20d27656b 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/SluiceResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -35,13 +35,13 @@ import org.apache.spark.network.util.NettyUtils; /** - * Handler that processes server responses, in response to requests issued from a [[SluiceClient]]. - * It works by tracking the list of outstanding requests (and their callbacks). + * Handler that processes server responses, in response to requests issued from a + * [[TransportClient]]. It works by tracking the list of outstanding requests (and their callbacks). * * Concurrency: thread safe and can be called from multiple threads. */ -public class SluiceResponseHandler extends MessageHandler { - private final Logger logger = LoggerFactory.getLogger(SluiceResponseHandler.class); +public class TransportResponseHandler extends MessageHandler { + private final Logger logger = LoggerFactory.getLogger(TransportResponseHandler.class); private final Channel channel; @@ -49,7 +49,7 @@ public class SluiceResponseHandler extends MessageHandler { private final Map outstandingRpcs; - public SluiceResponseHandler(Channel channel) { + public TransportResponseHandler(Channel channel) { this.channel = channel; this.outstandingFetches = new ConcurrentHashMap(); this.outstandingRpcs = new ConcurrentHashMap(); diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java index 5700cc83bd9c8..f54a696b8ff79 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -18,10 +18,10 @@ package org.apache.spark.network.server; import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.SluiceClient; +import org.apache.spark.network.client.TransportClient; /** - * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.SluiceClient}s. + * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s. */ public interface RpcHandler { /** @@ -34,5 +34,5 @@ public interface RpcHandler { * @param callback Callback which should be invoked exactly once upon success or failure of the * RPC. 
*/ - void receive(SluiceClient client, byte[] message, RpcResponseCallback callback); + void receive(TransportClient client, byte[] message, RpcResponseCallback callback); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java index 47b74b229fdec..5a9a14a180c10 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java +++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java @@ -21,11 +21,11 @@ /** * The StreamManager is used to fetch individual chunks from a stream. This is used in - * {@link SluiceRequestHandler} in order to respond to fetchChunk() requests. Creation of the - * stream is outside the scope of Sluice, but a given stream is guaranteed to be read by only one - * client connection, meaning that getChunk() for a particular stream will be called serially and - * that once the connection associated with the stream is closed, that stream will never be used - * again. + * {@link TransportRequestHandler} in order to respond to fetchChunk() requests. Creation of the + * stream is outside the scope of the transport layer, but a given stream is guaranteed to be read + * by only one client connection, meaning that getChunk() for a particular stream will be called + * serially and that once the connection associated with the stream is closed, that stream will + * never be used again. */ public abstract class StreamManager { /** diff --git a/network/common/src/main/java/org/apache/spark/network/server/SluiceChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportClientHandler.java similarity index 67% rename from network/common/src/main/java/org/apache/spark/network/server/SluiceChannelHandler.java rename to network/common/src/main/java/org/apache/spark/network/server/TransportClientHandler.java index d5a91ec1b6c28..08cc1b1f95de6 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/SluiceChannelHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportClientHandler.java @@ -22,8 +22,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.spark.network.client.SluiceClient; -import org.apache.spark.network.client.SluiceResponseHandler; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportResponseHandler; import org.apache.spark.network.protocol.Message; import org.apache.spark.network.protocol.request.RequestMessage; import org.apache.spark.network.protocol.response.ResponseMessage; @@ -31,33 +31,34 @@ /** * A handler which is used for delegating requests to the - * {@link org.apache.spark.network.server.SluiceRequestHandler} and responses to the - * {@link org.apache.spark.network.client.SluiceResponseHandler}. + * {@link TransportRequestHandler} and responses to the + * {@link org.apache.spark.network.client.TransportResponseHandler}. * - * All channels created in Sluice are bidirectional. When the Client initiates a Netty Channel - * with a RequestMessage (which gets handled by the Server's RequestHandler), the Server will - * produce a ResponseMessage (handled by the Client's ResponseHandler). However, the Server also - * gets a handle on the same Channel, so it may then begin to send RequestMessages to the Client. + * All channels created in the transport layer are bidirectional. 
When the Client initiates a Netty + * Channel with a RequestMessage (which gets handled by the Server's RequestHandler), the Server + * will produce a ResponseMessage (handled by the Client's ResponseHandler). However, the Server + * also gets a handle on the same Channel, so it may then begin to send RequestMessages to the + * Client. * This means that the Client also needs a RequestHandler and the Server needs a ResponseHandler, * for the Client's responses to the Server's requests. */ -public class SluiceChannelHandler extends SimpleChannelInboundHandler { - private final Logger logger = LoggerFactory.getLogger(SluiceChannelHandler.class); +public class TransportClientHandler extends SimpleChannelInboundHandler { + private final Logger logger = LoggerFactory.getLogger(TransportClientHandler.class); - private final SluiceClient client; - private final SluiceResponseHandler responseHandler; - private final SluiceRequestHandler requestHandler; + private final TransportClient client; + private final TransportResponseHandler responseHandler; + private final TransportRequestHandler requestHandler; - public SluiceChannelHandler( - SluiceClient client, - SluiceResponseHandler responseHandler, - SluiceRequestHandler requestHandler) { + public TransportClientHandler( + TransportClient client, + TransportResponseHandler responseHandler, + TransportRequestHandler requestHandler) { this.client = client; this.responseHandler = responseHandler; this.requestHandler = requestHandler; } - public SluiceClient getClient() { + public TransportClient getClient() { return client; } diff --git a/network/common/src/main/java/org/apache/spark/network/server/SluiceRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java similarity index 92% rename from network/common/src/main/java/org/apache/spark/network/server/SluiceRequestHandler.java rename to network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 5f5111e0a7638..08a2a3ec52f8b 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/SluiceRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -29,7 +29,7 @@ import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.SluiceClient; +import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.protocol.Encodable; import org.apache.spark.network.protocol.request.RequestMessage; import org.apache.spark.network.protocol.request.ChunkFetchRequest; @@ -45,16 +45,16 @@ * attached to a single Netty channel, and keeps track of which streams have been fetched via this * channel, in order to clean them up if the channel is terminated (see #channelUnregistered). * - * The messages should have been processed by the pipeline setup by {@link SluiceServer}. + * The messages should have been processed by the pipeline setup by {@link TransportServer}. */ -public class SluiceRequestHandler extends MessageHandler { - private final Logger logger = LoggerFactory.getLogger(SluiceRequestHandler.class); +public class TransportRequestHandler extends MessageHandler { + private final Logger logger = LoggerFactory.getLogger(TransportRequestHandler.class); /** The Netty channel that this handler is associated with. */ private final Channel channel; /** Client on the same channel allowing us to talk back to the requester. 
*/ - private final SluiceClient reverseClient; + private final TransportClient reverseClient; /** Returns each chunk part of a stream. */ private final StreamManager streamManager; @@ -65,11 +65,11 @@ public class SluiceRequestHandler extends MessageHandler { /** List of all stream ids that have been read on this handler, used for cleanup. */ private final Set streamIds; - public SluiceRequestHandler( - Channel channel, - SluiceClient reverseClient, - StreamManager streamManager, - RpcHandler rpcHandler) { + public TransportRequestHandler( + Channel channel, + TransportClient reverseClient, + StreamManager streamManager, + RpcHandler rpcHandler) { this.channel = channel; this.reverseClient = reverseClient; this.streamManager = streamManager; diff --git a/network/common/src/main/java/org/apache/spark/network/server/SluiceServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java similarity index 90% rename from network/common/src/main/java/org/apache/spark/network/server/SluiceServer.java rename to network/common/src/main/java/org/apache/spark/network/server/TransportServer.java index 965db536a2782..973fb05f57944 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/SluiceServer.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -31,25 +31,25 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.spark.network.SluiceContext; +import org.apache.spark.network.TransportContext; import org.apache.spark.network.util.IOMode; import org.apache.spark.network.util.NettyUtils; -import org.apache.spark.network.util.SluiceConfig; +import org.apache.spark.network.util.TransportConf; /** * Server for the efficient, low-level streaming service. */ -public class SluiceServer implements Closeable { - private final Logger logger = LoggerFactory.getLogger(SluiceServer.class); +public class TransportServer implements Closeable { + private final Logger logger = LoggerFactory.getLogger(TransportServer.class); - private final SluiceContext context; - private final SluiceConfig conf; + private final TransportContext context; + private final TransportConf conf; private ServerBootstrap bootstrap; private ChannelFuture channelFuture; private int port; - public SluiceServer(SluiceContext context) { + public TransportServer(TransportContext context) { this.context = context; this.conf = context.getConf(); diff --git a/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java b/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java index 2dc0e248ae835..d944d9da1c7f8 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java +++ b/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java @@ -20,7 +20,7 @@ import java.util.NoSuchElementException; /** - * Provides a mechanism for constructing a {@link SluiceConfig} using some sort of configuration. + * Provides a mechanism for constructing a {@link TransportConf} using some sort of configuration. */ public abstract class ConfigProvider { /** Obtains the value of the given config, throws NoSuchElementException if it doesn't exist. 
*/ diff --git a/network/common/src/main/java/org/apache/spark/network/util/IOMode.java b/network/common/src/main/java/org/apache/spark/network/util/IOMode.java index c0aa12c81ba64..6b208d95bbfbc 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/IOMode.java +++ b/network/common/src/main/java/org/apache/spark/network/util/IOMode.java @@ -23,5 +23,5 @@ * AUTO is used to select EPOLL if it's available, or NIO otherwise. */ public enum IOMode { - NIO, EPOLL, AUTO + NIO, EPOLL } diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java index a925c05469d3c..b1872341198e0 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -39,9 +39,6 @@ public class NettyUtils { /** Creates a Netty EventLoopGroup based on the IOMode. */ public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) { - if (mode == IOMode.AUTO) { - mode = autoselectMode(); - } ThreadFactory threadFactory = new ThreadFactoryBuilder() .setDaemon(true) @@ -60,9 +57,6 @@ public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String /** Returns the correct (client) SocketChannel class based on IOMode. */ public static Class getClientChannelClass(IOMode mode) { - if (mode == IOMode.AUTO) { - mode = autoselectMode(); - } switch (mode) { case NIO: return NioSocketChannel.class; @@ -75,9 +69,6 @@ public static Class getClientChannelClass(IOMode mode) { /** Returns the correct ServerSocketChannel class based on IOMode. */ public static Class getServerChannelClass(IOMode mode) { - if (mode == IOMode.AUTO) { - mode = autoselectMode(); - } switch (mode) { case NIO: return NioServerSocketChannel.class; @@ -108,9 +99,4 @@ public static String getRemoteAddress(Channel channel) { } return ""; } - - /** Returns EPOLL if it's available on this system, NIO otherwise. */ - private static IOMode autoselectMode() { - return Epoll.isAvailable() ? IOMode.EPOLL : IOMode.NIO; - } } diff --git a/network/common/src/main/java/org/apache/spark/network/util/DefaultConfigProvider.java b/network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java similarity index 94% rename from network/common/src/main/java/org/apache/spark/network/util/DefaultConfigProvider.java rename to network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java index cef88c0091eff..f15ec8d294258 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/DefaultConfigProvider.java +++ b/network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java @@ -20,7 +20,7 @@ import java.util.NoSuchElementException; /** Uses System properties to obtain config values. 
*/ -public class DefaultConfigProvider extends ConfigProvider { +public class SystemPropertyConfigProvider extends ConfigProvider { @Override public String get(String name) { String value = System.getProperty(name); diff --git a/network/common/src/main/java/org/apache/spark/network/util/SluiceConfig.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java similarity index 94% rename from network/common/src/main/java/org/apache/spark/network/util/SluiceConfig.java rename to network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index 26fa3229c4721..80f65d98032da 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/SluiceConfig.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -20,17 +20,17 @@ /** * A central location that tracks all the settings we expose to users. */ -public class SluiceConfig { +public class TransportConf { private final ConfigProvider conf; - public SluiceConfig(ConfigProvider conf) { + public TransportConf(ConfigProvider conf) { this.conf = conf; } /** Port the server listens on. Default to a random port. */ public int serverPort() { return conf.getInt("spark.shuffle.io.port", 0); } - /** IO mode: nio, epoll, or auto (try epoll first and then nio). */ + /** IO mode: nio or epoll */ public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); } /** Connect timeout in secs. Default 120 secs. */ diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index f7f53e2df4a49..00ed7b527abd5 100644 --- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -41,20 +41,20 @@ import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; -import org.apache.spark.network.client.SluiceClient; -import org.apache.spark.network.client.SluiceClientFactory; -import org.apache.spark.network.server.SluiceServer; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.util.DefaultConfigProvider; -import org.apache.spark.network.util.SluiceConfig; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; public class ChunkFetchIntegrationSuite { static final long STREAM_ID = 1; static final int BUFFER_CHUNK_INDEX = 0; static final int FILE_CHUNK_INDEX = 1; - static SluiceServer server; - static SluiceClientFactory clientFactory; + static TransportServer server; + static TransportClientFactory clientFactory; static StreamManager streamManager; static File testFile; @@ -80,7 +80,7 @@ public static void setUp() throws Exception { fp.close(); fileChunk = new FileSegmentManagedBuffer(testFile, 10, testFile.length() - 25); - SluiceConfig conf = new SluiceConfig(new DefaultConfigProvider()); + TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); streamManager = new StreamManager() { @Override public ManagedBuffer getChunk(long streamId, int chunkIndex) { @@ -94,7 +94,7 @@ public ManagedBuffer getChunk(long 
streamId, int chunkIndex) { } } }; - SluiceContext context = new SluiceContext(conf, streamManager, new NoOpRpcHandler()); + TransportContext context = new TransportContext(conf, streamManager, new NoOpRpcHandler()); server = context.createServer(); clientFactory = context.createClientFactory(); } @@ -119,7 +119,7 @@ public void releaseBuffers() { } private FetchResult fetchChunks(List chunkIndices) throws Exception { - SluiceClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); final Semaphore sem = new Semaphore(0); final FetchResult res = new FetchResult(); diff --git a/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java b/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java index e7bad051c6200..7aa37efc582e4 100644 --- a/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java +++ b/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java @@ -16,13 +16,13 @@ */ import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.client.SluiceClient; /** Test RpcHandler which always returns a zero-sized success. */ public class NoOpRpcHandler implements RpcHandler { @Override - public void receive(SluiceClient client, byte[] message, RpcResponseCallback callback) { + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { callback.onSuccess(new byte[0]); } } diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java index 9f20496f75f82..6932760c44fee 100644 --- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -83,4 +83,4 @@ public void responses() { testServerToClient(new RpcFailure(0, "this is an error")); testServerToClient(new RpcFailure(0, "")); } -} \ No newline at end of file +} diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index a909e4032d608..19ce9c6a8d826 100644 --- a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -33,25 +33,25 @@ import static org.junit.Assert.*; import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.SluiceClient; -import org.apache.spark.network.client.SluiceClientFactory; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.server.DefaultStreamManager; import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.SluiceServer; -import org.apache.spark.network.util.DefaultConfigProvider; -import org.apache.spark.network.util.SluiceConfig; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; public class RpcIntegrationSuite { - static SluiceServer server; - static SluiceClientFactory clientFactory; + static TransportServer server; + static 
TransportClientFactory clientFactory; static RpcHandler rpcHandler; @BeforeClass public static void setUp() throws Exception { - SluiceConfig conf = new SluiceConfig(new DefaultConfigProvider()); + TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); rpcHandler = new RpcHandler() { @Override - public void receive(SluiceClient client, byte[] message, RpcResponseCallback callback) { + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { String msg = new String(message, Charsets.UTF_8); String[] parts = msg.split("/"); if (parts[0].equals("hello")) { @@ -63,7 +63,7 @@ public void receive(SluiceClient client, byte[] message, RpcResponseCallback cal } } }; - SluiceContext context = new SluiceContext(conf, new DefaultStreamManager(), rpcHandler); + TransportContext context = new TransportContext(conf, new DefaultStreamManager(), rpcHandler); server = context.createServer(); clientFactory = context.createClientFactory(); } @@ -80,7 +80,7 @@ class RpcResult { } private RpcResult sendRPC(String ... commands) throws Exception { - SluiceClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); final Semaphore sem = new Semaphore(0); final RpcResult res = new RpcResult(); diff --git a/network/common/src/test/java/org/apache/spark/network/SluiceClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java similarity index 62% rename from network/common/src/test/java/org/apache/spark/network/SluiceClientFactorySuite.java rename to network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index 219d6cc998bd7..f76b4bc55182d 100644 --- a/network/common/src/test/java/org/apache/spark/network/SluiceClientFactorySuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -26,28 +26,28 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; -import org.apache.spark.network.client.SluiceClient; -import org.apache.spark.network.client.SluiceClientFactory; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.server.DefaultStreamManager; import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.SluiceServer; +import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.util.DefaultConfigProvider; +import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.JavaUtils; -import org.apache.spark.network.util.SluiceConfig; +import org.apache.spark.network.util.TransportConf; -public class SluiceClientFactorySuite { - private SluiceConfig conf; - private SluiceContext context; - private SluiceServer server1; - private SluiceServer server2; +public class TransportClientFactorySuite { + private TransportConf conf; + private TransportContext context; + private TransportServer server1; + private TransportServer server2; @Before public void setUp() { - conf = new SluiceConfig(new DefaultConfigProvider()); + conf = new TransportConf(new SystemPropertyConfigProvider()); StreamManager streamManager = new DefaultStreamManager(); RpcHandler rpcHandler = new NoOpRpcHandler(); - context = new SluiceContext(conf, streamManager, 
rpcHandler); + context = new TransportContext(conf, streamManager, rpcHandler); server1 = context.createServer(); server2 = context.createServer(); } @@ -60,10 +60,10 @@ public void tearDown() { @Test public void createAndReuseBlockClients() throws TimeoutException { - SluiceClientFactory factory = context.createClientFactory(); - SluiceClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); - SluiceClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); - SluiceClient c3 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); + TransportClientFactory factory = context.createClientFactory(); + TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + TransportClient c3 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); assertTrue(c1.isActive()); assertTrue(c3.isActive()); assertTrue(c1 == c2); @@ -73,8 +73,8 @@ public void createAndReuseBlockClients() throws TimeoutException { @Test public void neverReturnInactiveClients() throws Exception { - SluiceClientFactory factory = context.createClientFactory(); - SluiceClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + TransportClientFactory factory = context.createClientFactory(); + TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); c1.close(); long start = System.currentTimeMillis(); @@ -83,7 +83,7 @@ public void neverReturnInactiveClients() throws Exception { } assertFalse(c1.isActive()); - SluiceClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); assertFalse(c1 == c2); assertTrue(c2.isActive()); factory.close(); @@ -91,9 +91,9 @@ public void neverReturnInactiveClients() throws Exception { @Test public void closeBlockClientsWithFactory() throws TimeoutException { - SluiceClientFactory factory = context.createClientFactory(); - SluiceClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); - SluiceClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); + TransportClientFactory factory = context.createClientFactory(); + TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); assertTrue(c1.isActive()); assertTrue(c2.isActive()); factory.close(); diff --git a/network/common/src/test/java/org/apache/spark/network/SluiceResponseHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java similarity index 88% rename from network/common/src/test/java/org/apache/spark/network/SluiceResponseHandlerSuite.java rename to network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java index 3138c5d21a85f..6e360e96099f4 100644 --- a/network/common/src/test/java/org/apache/spark/network/SluiceResponseHandlerSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.network; -import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.local.LocalChannel; import org.junit.Test; @@ -29,19 +28,19 @@ import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; import 
org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.SluiceResponseHandler; +import org.apache.spark.network.client.TransportResponseHandler; import org.apache.spark.network.protocol.StreamChunkId; import org.apache.spark.network.protocol.response.ChunkFetchFailure; import org.apache.spark.network.protocol.response.ChunkFetchSuccess; import org.apache.spark.network.protocol.response.RpcFailure; import org.apache.spark.network.protocol.response.RpcResponse; -public class SluiceResponseHandlerSuite { +public class TransportResponseHandlerSuite { @Test public void handleSuccessfulFetch() { StreamChunkId streamChunkId = new StreamChunkId(1, 0); - SluiceResponseHandler handler = new SluiceResponseHandler(new LocalChannel()); + TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); handler.addFetchRequest(streamChunkId, callback); assertEquals(1, handler.numOutstandingRequests()); @@ -54,7 +53,7 @@ public void handleSuccessfulFetch() { @Test public void handleFailedFetch() { StreamChunkId streamChunkId = new StreamChunkId(1, 0); - SluiceResponseHandler handler = new SluiceResponseHandler(new LocalChannel()); + TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); handler.addFetchRequest(streamChunkId, callback); assertEquals(1, handler.numOutstandingRequests()); @@ -66,7 +65,7 @@ public void handleFailedFetch() { @Test public void clearAllOutstandingRequests() { - SluiceResponseHandler handler = new SluiceResponseHandler(new LocalChannel()); + TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); handler.addFetchRequest(new StreamChunkId(1, 0), callback); handler.addFetchRequest(new StreamChunkId(1, 1), callback); @@ -85,7 +84,7 @@ public void clearAllOutstandingRequests() { @Test public void handleSuccessfulRPC() { - SluiceResponseHandler handler = new SluiceResponseHandler(new LocalChannel()); + TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); RpcResponseCallback callback = mock(RpcResponseCallback.class); handler.addRpcRequest(12345, callback); assertEquals(1, handler.numOutstandingRequests()); @@ -101,7 +100,7 @@ public void handleSuccessfulRPC() { @Test public void handleFailedRPC() { - SluiceResponseHandler handler = new SluiceResponseHandler(new LocalChannel()); + TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); RpcResponseCallback callback = mock(RpcResponseCallback.class); handler.addRpcRequest(12345, callback); assertEquals(1, handler.numOutstandingRequests()); From 14e37f7ce36f42e07fbe4f5382c0674754f523c9 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Tue, 28 Oct 2014 15:56:40 -0700 Subject: [PATCH 42/46] Address Reynold's comments --- .../network/netty/NettyBlockFetcher.scala | 2 +- .../shuffle/FileShuffleBlockManager.scala | 4 +- .../storage/ShuffleBlockFetcherIterator.scala | 7 ++-- .../spark/network/TransportContext.java | 27 +++++++----- .../buffer/FileSegmentManagedBuffer.java | 3 +- .../spark/network/buffer/ManagedBuffer.java | 3 +- .../network/buffer/NettyManagedBuffer.java | 4 +- .../network/buffer/NioManagedBuffer.java | 2 +- .../client/ChunkFetchFailureException.java | 10 +---- .../spark/network/client/TransportClient.java | 22 +++++----- 
.../client/TransportClientFactory.java | 40 ++++++++++-------- .../client/TransportResponseHandler.java | 42 +++++++++---------- .../{response => }/ChunkFetchFailure.java | 7 +--- .../{request => }/ChunkFetchRequest.java | 6 +-- .../{response => }/ChunkFetchSuccess.java | 8 ++-- .../spark/network/protocol/Encodable.java | 6 +++ .../spark/network/protocol/Message.java | 2 +- .../{response => }/MessageDecoder.java | 8 +--- .../{response => }/MessageEncoder.java | 4 +- .../{request => }/RequestMessage.java | 2 +- .../{response => }/ResponseMessage.java | 2 +- .../protocol/{response => }/RpcFailure.java | 20 ++++----- .../protocol/{request => }/RpcRequest.java | 22 +++++----- .../protocol/{response => }/RpcResponse.java | 20 ++++----- .../network/server/DefaultStreamManager.java | 15 ++++--- ...dler.java => TransportChannelHandler.java} | 27 +++++++----- .../server/TransportRequestHandler.java | 30 ++++++------- .../spark/network/server/TransportServer.java | 9 +++- .../network/ChunkFetchIntegrationSuite.java | 1 - .../apache/spark/network/ProtocolSuite.java | 16 +++---- .../spark/network/RpcIntegrationSuite.java | 1 - .../SystemPropertyConfigProvider.java | 4 +- .../spark/network/TestManagedBuffer.java | 4 +- .../network/TransportClientFactorySuite.java | 1 - .../TransportResponseHandlerSuite.java | 8 ++-- 35 files changed, 203 insertions(+), 186 deletions(-) rename network/common/src/main/java/org/apache/spark/network/protocol/{response => }/ChunkFetchFailure.java (91%) rename network/common/src/main/java/org/apache/spark/network/protocol/{request => }/ChunkFetchRequest.java (90%) rename network/common/src/main/java/org/apache/spark/network/protocol/{response => }/ChunkFetchSuccess.java (88%) rename network/common/src/main/java/org/apache/spark/network/protocol/{response => }/MessageDecoder.java (88%) rename network/common/src/main/java/org/apache/spark/network/protocol/{response => }/MessageEncoder.java (96%) rename network/common/src/main/java/org/apache/spark/network/protocol/{request => }/RequestMessage.java (95%) rename network/common/src/main/java/org/apache/spark/network/protocol/{response => }/ResponseMessage.java (94%) rename network/common/src/main/java/org/apache/spark/network/protocol/{response => }/RpcFailure.java (79%) rename network/common/src/main/java/org/apache/spark/network/protocol/{request => }/RpcRequest.java (78%) rename network/common/src/main/java/org/apache/spark/network/protocol/{response => }/RpcResponse.java (79%) rename network/common/src/main/java/org/apache/spark/network/server/{TransportClientHandler.java => TransportChannelHandler.java} (79%) rename network/common/src/{main/java/org/apache/spark/network/util => test/java/org/apache/spark/network}/SystemPropertyConfigProvider.java (92%) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala index 344d17e7bf661..8c5ffd8da6bbb 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala @@ -87,7 +87,7 @@ class NettyBlockFetcher( } override def onFailure(e: Throwable): Unit = { - logError("Failed while starting block fetches") + logError("Failed while starting block fetches", e) blockIds.foreach(blockId => Utils.tryLog(listener.onBlockFetchFailure(blockId, e))) } }) diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala 
b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index c35aa2481ad03..1fb5b2c4546bd 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -24,14 +24,14 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConversions._ +import org.apache.spark.{Logging, SparkConf, SparkEnv} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.FileShuffleBlockManager.ShuffleFileGroup import org.apache.spark.storage._ -import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} -import org.apache.spark.{Logging, SparkConf, SparkEnv} +import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} /** A group of writers for a ShuffleMapTask, one writer per reducer. */ private[spark] trait ShuffleWriterGroup { diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 23313fe9271fd..0d6f3bf003a9d 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -21,12 +21,11 @@ import java.util.concurrent.LinkedBlockingQueue import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} +import org.apache.spark.{Logging, TaskContext} import org.apache.spark.network.{BlockFetchingListener, BlockTransferService} -import org.apache.spark.serializer.Serializer import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.serializer.Serializer import org.apache.spark.util.{CompletionIterator, Utils} -import org.apache.spark.{Logging, TaskContext} - /** * An iterator that fetches multiple blocks. 
For local blocks, it fetches from the local block @@ -285,7 +284,7 @@ final class ShuffleBlockFetcherIterator( val iteratorOpt: Option[Iterator[Any]] = if (result.failed) { None } else { - val is = blockManager.wrapForCompression(result.blockId, result.buf.inputStream()) + val is = blockManager.wrapForCompression(result.blockId, result.buf.createInputStream()) val iter = serializer.newInstance().deserializeStream(is).asIterator Some(CompletionIterator[Any, Iterator[Any]](iter, { // Once the iterator is exhausted, release the buffer and set currentResult to null diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java index da0decac7e064..854aa6685f85f 100644 --- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java @@ -25,10 +25,10 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.client.TransportResponseHandler; -import org.apache.spark.network.protocol.response.MessageDecoder; -import org.apache.spark.network.protocol.response.MessageEncoder; +import org.apache.spark.network.protocol.MessageDecoder; +import org.apache.spark.network.protocol.MessageEncoder; import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.TransportClientHandler; +import org.apache.spark.network.server.TransportChannelHandler; import org.apache.spark.network.server.TransportRequestHandler; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.StreamManager; @@ -37,7 +37,12 @@ /** * Contains the context to create a {@link TransportServer}, {@link TransportClientFactory}, and to - * setup Netty Channel pipelines with a {@link TransportClientHandler}. + * setup Netty Channel pipelines with a {@link org.apache.spark.network.server.TransportChannelHandler}. + * + * There are two communication protocols that the TransportClient provides, control-plane RPCs and + * data-plane "chunk fetching". The handling of the RPCs is performed outside of the scope of the + * TransportContext (i.e., by a user-provided handler), and it is responsible for setting up streams + * which can be streamed through the data plane in chunks using zero-copy IO. * * The TransportServer and TransportClientFactory both create a TransportChannelHandler for each * channel. As each TransportChannelHandler contains a TransportClient, this enables server @@ -71,16 +76,16 @@ public TransportServer createServer() { /** * Initializes a client or server Netty Channel Pipeline which encodes/decodes messages and - * has a {@link org.apache.spark.network.server.TransportClientHandler} to handle request or + * has a {@link org.apache.spark.network.server.TransportChannelHandler} to handle request or * response messages. * * @return Returns the created TransportChannelHandler, which includes a TransportClient that can * be used to communicate on this channel. The TransportClient is directly associated with a * ChannelHandler to ensure all users of the same channel get the same TransportClient object. 
*/ - public TransportClientHandler initializePipeline(SocketChannel channel) { + public TransportChannelHandler initializePipeline(SocketChannel channel) { try { - TransportClientHandler channelHandler = createChannelHandler(channel); + TransportChannelHandler channelHandler = createChannelHandler(channel); channel.pipeline() .addLast("encoder", encoder) .addLast("frameDecoder", NettyUtils.createFrameDecoder()) @@ -100,12 +105,12 @@ public TransportClientHandler initializePipeline(SocketChannel channel) { * ResponseMessages. The channel is expected to have been successfully created, though certain * properties (such as the remoteAddress()) may not be available yet. */ - private TransportClientHandler createChannelHandler(Channel channel) { + private TransportChannelHandler createChannelHandler(Channel channel) { TransportResponseHandler responseHandler = new TransportResponseHandler(channel); TransportClient client = new TransportClient(channel, responseHandler); - TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client, streamManager, - rpcHandler); - return new TransportClientHandler(client, responseHandler, requestHandler); + TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client, + streamManager, rpcHandler); + return new TransportChannelHandler(client, responseHandler, requestHandler); } public TransportConf getConf() { return conf; } diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index 224f1e6c515ea..a02f692a674b2 100644 --- a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -40,6 +40,7 @@ public final class FileSegmentManagedBuffer extends ManagedBuffer { * Memory mapping is expensive and can destabilize the JVM (SPARK-1145, SPARK-3889). * Avoid unless there's a good reason not to. */ + // TODO: Make this configurable private static final long MIN_MEMORY_MAP_BYTES = 2 * 1024 * 1024; private final File file; @@ -88,7 +89,7 @@ public ByteBuffer nioByteBuffer() throws IOException { } @Override - public InputStream inputStream() throws IOException { + public InputStream createInputStream() throws IOException { FileInputStream is = null; try { is = new FileInputStream(file); diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java index 1735f5540c61b..a415db593a788 100644 --- a/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java +++ b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java @@ -43,6 +43,7 @@ public abstract class ManagedBuffer { * Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the * returned ByteBuffer should not affect the content of this buffer. */ + // TODO: Deprecate this, usage may require expensive memory mapping or allocation. public abstract ByteBuffer nioByteBuffer() throws IOException; /** @@ -50,7 +51,7 @@ public abstract class ManagedBuffer { * necessarily check for the length of bytes read, so the caller is responsible for making sure * it does not go over the limit. 
*/ - public abstract InputStream inputStream() throws IOException; + public abstract InputStream createInputStream() throws IOException; /** * Increment the reference count by one if applicable. diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java index d928980423f1f..c806bfa45bef3 100644 --- a/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java +++ b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java @@ -46,7 +46,7 @@ public ByteBuffer nioByteBuffer() throws IOException { } @Override - public InputStream inputStream() throws IOException { + public InputStream createInputStream() throws IOException { return new ByteBufInputStream(buf); } @@ -64,7 +64,7 @@ public ManagedBuffer release() { @Override public Object convertToNetty() throws IOException { - return buf; + return buf.duplicate(); } @Override diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java index 3953ef89fbf88..f55b884bc45ce 100644 --- a/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java +++ b/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java @@ -46,7 +46,7 @@ public ByteBuffer nioByteBuffer() throws IOException { } @Override - public InputStream inputStream() throws IOException { + public InputStream createInputStream() throws IOException { return new ByteBufInputStream(Unpooled.wrappedBuffer(buf)); } diff --git a/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java b/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java index 40a1fe67b1c5b..1fbdcd6780785 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java +++ b/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java @@ -21,17 +21,11 @@ * General exception caused by a remote exception while fetching a chunk. 
*/ public class ChunkFetchFailureException extends RuntimeException { - private final int chunkIndex; - - public ChunkFetchFailureException(int chunkIndex, String errorMsg, Throwable cause) { + public ChunkFetchFailureException(String errorMsg, Throwable cause) { super(errorMsg, cause); - this.chunkIndex = chunkIndex; } - public ChunkFetchFailureException(int chunkIndex, String errorMsg) { + public ChunkFetchFailureException(String errorMsg) { super(errorMsg); - this.chunkIndex = chunkIndex; } - - public int getChunkIndex() { return chunkIndex; } } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index 75e26cb7e60c1..b1732fcde21f1 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -28,9 +28,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.protocol.ChunkFetchRequest; +import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.StreamChunkId; -import org.apache.spark.network.protocol.request.ChunkFetchRequest; -import org.apache.spark.network.protocol.request.RpcRequest; import org.apache.spark.network.util.NettyUtils; /** @@ -106,7 +106,7 @@ public void fetchChunk( public void operationComplete(ChannelFuture future) throws Exception { if (future.isSuccess()) { long timeTaken = System.currentTimeMillis() - startTime; - logger.debug("Sending request {} to {} took {} ms", streamChunkId, serverAddr, + logger.trace("Sending request {} to {} took {} ms", streamChunkId, serverAddr, timeTaken); } else { String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId, @@ -114,6 +114,7 @@ public void operationComplete(ChannelFuture future) throws Exception { logger.error(errorMsg, future.cause()); handler.removeFetchRequest(streamChunkId); callback.onFailure(chunkIndex, new RuntimeException(errorMsg, future.cause())); + channel.close(); } } }); @@ -126,24 +127,25 @@ public void operationComplete(ChannelFuture future) throws Exception { public void sendRpc(byte[] message, final RpcResponseCallback callback) { final String serverAddr = NettyUtils.getRemoteAddress(channel); final long startTime = System.currentTimeMillis(); - logger.debug("Sending RPC to {}", serverAddr); + logger.trace("Sending RPC to {}", serverAddr); - final long tag = UUID.randomUUID().getLeastSignificantBits(); - handler.addRpcRequest(tag, callback); + final long requestId = UUID.randomUUID().getLeastSignificantBits(); + handler.addRpcRequest(requestId, callback); - channel.writeAndFlush(new RpcRequest(tag, message)).addListener( + channel.writeAndFlush(new RpcRequest(requestId, message)).addListener( new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { if (future.isSuccess()) { long timeTaken = System.currentTimeMillis() - startTime; - logger.debug("Sending request {} to {} took {} ms", tag, serverAddr, timeTaken); + logger.trace("Sending request {} to {} took {} ms", requestId, serverAddr, timeTaken); } else { - String errorMsg = String.format("Failed to send RPC %s to %s: %s", tag, + String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId, serverAddr, future.cause()); logger.error(errorMsg, future.cause()); - handler.removeRpcRequest(tag); + handler.removeRpcRequest(requestId); 
callback.onFailure(new RuntimeException(errorMsg, future.cause())); + channel.close(); } } }); diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index c351858bfe30d..10eb9ef7a025f 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -23,6 +23,7 @@ import java.net.SocketAddress; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReference; import io.netty.bootstrap.Bootstrap; import io.netty.buffer.PooledByteBufAllocator; @@ -37,7 +38,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.TransportContext; -import org.apache.spark.network.server.TransportClientHandler; +import org.apache.spark.network.server.TransportChannelHandler; import org.apache.spark.network.util.IOMode; import org.apache.spark.network.util.NettyUtils; import org.apache.spark.network.util.TransportConf; @@ -66,6 +67,7 @@ public TransportClientFactory(TransportContext context) { IOMode ioMode = IOMode.valueOf(conf.ioMode()); this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode); + // TODO: Make thread pool name configurable. this.workerGroup = NettyUtils.createEventLoop(ioMode, conf.clientThreads(), "shuffle-client"); } @@ -100,17 +102,13 @@ public TransportClient createClient(String remoteHost, int remotePort) throws Ti // Use pooled buffers to reduce temporary buffer allocation bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator()); + final AtomicReference client = new AtomicReference(); + bootstrap.handler(new ChannelInitializer() { @Override public void initChannel(SocketChannel ch) { - TransportClientHandler channelHandler = context.initializePipeline(ch); - TransportClient oldClient = connectionPool.putIfAbsent(address, channelHandler.getClient()); - if (oldClient != null) { - logger.debug("Two clients were created concurrently, second one will be disposed."); - ch.close(); - // Note: this type of failure is still considered a success by Netty, and thus the - // ChannelFuture will complete successfully. - } + TransportChannelHandler clientHandler = context.initializePipeline(ch); + client.set(clientHandler.getClient()); } }); @@ -119,23 +117,31 @@ public void initChannel(SocketChannel ch) { if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) { throw new TimeoutException( String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs())); + } else if (cf.cause() != null) { + throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause()); } - TransportClient client = connectionPool.get(address); - if (client == null) { - // The only way we should be able to reach here is if the client we created started out - // in the "inactive" state, and someone else simultaneously tried to create another client to - // the same server. This is an error condition, as the first client failed to connect. - throw new IllegalStateException("Client was unset! 
Must have been immediately inactive."); + // Successful connection + assert client.get() != null : "Channel future completed successfully with null client"; + TransportClient oldClient = connectionPool.putIfAbsent(address, client.get()); + if (oldClient == null) { + return client.get(); + } else { + logger.debug("Two clients were created concurrently, second one will be disposed."); + client.get().close(); + return oldClient; } - return client; } /** Close all connections in the connection pool, and shutdown the worker thread pool. */ @Override public void close() { for (TransportClient client : connectionPool.values()) { - client.close(); + try { + client.close(); + } catch (RuntimeException e) { + logger.warn("Ignoring exception during close", e); + } } connectionPool.clear(); diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 187b20d27656b..d8965590b34da 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -25,12 +25,12 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.spark.network.protocol.response.ResponseMessage; +import org.apache.spark.network.protocol.ChunkFetchFailure; +import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.ResponseMessage; +import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcResponse; import org.apache.spark.network.protocol.StreamChunkId; -import org.apache.spark.network.protocol.response.ChunkFetchFailure; -import org.apache.spark.network.protocol.response.ChunkFetchSuccess; -import org.apache.spark.network.protocol.response.RpcFailure; -import org.apache.spark.network.protocol.response.RpcResponse; import org.apache.spark.network.server.MessageHandler; import org.apache.spark.network.util.NettyUtils; @@ -63,12 +63,12 @@ public void removeFetchRequest(StreamChunkId streamChunkId) { outstandingFetches.remove(streamChunkId); } - public void addRpcRequest(long tag, RpcResponseCallback callback) { - outstandingRpcs.put(tag, callback); + public void addRpcRequest(long requestId, RpcResponseCallback callback) { + outstandingRpcs.put(requestId, callback); } - public void removeRpcRequest(long tag) { - outstandingRpcs.remove(tag); + public void removeRpcRequest(long requestId) { + outstandingRpcs.remove(requestId); } /** @@ -115,7 +115,7 @@ public void handle(ResponseMessage message) { ChunkFetchSuccess resp = (ChunkFetchSuccess) message; ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId); if (listener == null) { - logger.warn("Got a response for block {} from {} but it is not outstanding", + logger.warn("Ignoring response for block {} from {} since it is not outstanding", resp.streamChunkId, remoteAddress); resp.buffer.release(); } else { @@ -127,31 +127,31 @@ public void handle(ResponseMessage message) { ChunkFetchFailure resp = (ChunkFetchFailure) message; ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId); if (listener == null) { - logger.warn("Got a response for block {} from {} ({}) but it is not outstanding", + logger.warn("Ignoring response for block {} from {} ({}) since it is not outstanding", resp.streamChunkId, remoteAddress, resp.errorString); } else { 
outstandingFetches.remove(resp.streamChunkId); - listener.onFailure(resp.streamChunkId.chunkIndex, - new ChunkFetchFailureException(resp.streamChunkId.chunkIndex, resp.errorString)); + listener.onFailure(resp.streamChunkId.chunkIndex, new ChunkFetchFailureException( + "Failure while fetching " + resp.streamChunkId + ": " + resp.errorString)); } } else if (message instanceof RpcResponse) { RpcResponse resp = (RpcResponse) message; - RpcResponseCallback listener = outstandingRpcs.get(resp.tag); + RpcResponseCallback listener = outstandingRpcs.get(resp.requestId); if (listener == null) { - logger.warn("Got a response for RPC {} from {} ({} bytes) but it is not outstanding", - resp.tag, remoteAddress, resp.response.length); + logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding", + resp.requestId, remoteAddress, resp.response.length); } else { - outstandingRpcs.remove(resp.tag); + outstandingRpcs.remove(resp.requestId); listener.onSuccess(resp.response); } } else if (message instanceof RpcFailure) { RpcFailure resp = (RpcFailure) message; - RpcResponseCallback listener = outstandingRpcs.get(resp.tag); + RpcResponseCallback listener = outstandingRpcs.get(resp.requestId); if (listener == null) { - logger.warn("Got a response for RPC {} from {} ({}) but it is not outstanding", - resp.tag, remoteAddress, resp.errorString); + logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding", + resp.requestId, remoteAddress, resp.errorString); } else { - outstandingRpcs.remove(resp.tag); + outstandingRpcs.remove(resp.requestId); listener.onFailure(new RuntimeException(resp.errorString)); } } else { diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java similarity index 91% rename from network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchFailure.java rename to network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java index cb3cbcd0a53ca..152af98ced7ce 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java @@ -15,17 +15,14 @@ * limitations under the License. */ -package org.apache.spark.network.protocol.response; +package org.apache.spark.network.protocol; import com.google.common.base.Charsets; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; -import org.apache.spark.network.protocol.StreamChunkId; - /** - * Response to {@link org.apache.spark.network.protocol.request.ChunkFetchRequest} when there is an - * error fetching the chunk. + * Response to {@link ChunkFetchRequest} when there is an error fetching the chunk. 
*/ public final class ChunkFetchFailure implements ResponseMessage { public final StreamChunkId streamChunkId; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/ChunkFetchRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java similarity index 90% rename from network/common/src/main/java/org/apache/spark/network/protocol/request/ChunkFetchRequest.java rename to network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java index 99cbb8777a873..980947cf13f6b 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/request/ChunkFetchRequest.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java @@ -15,16 +15,14 @@ * limitations under the License. */ -package org.apache.spark.network.protocol.request; +package org.apache.spark.network.protocol; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; -import org.apache.spark.network.protocol.StreamChunkId; - /** * Request to fetch a sequence of a single chunk of a stream. This will correspond to a single - * {@link org.apache.spark.network.protocol.response.ResponseMessage} (either success or failure). + * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure). */ public final class ChunkFetchRequest implements RequestMessage { public final StreamChunkId streamChunkId; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchSuccess.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java similarity index 88% rename from network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchSuccess.java rename to network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java index 6bc26a64b9945..ff4936470c697 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchSuccess.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java @@ -15,18 +15,16 @@ * limitations under the License. */ -package org.apache.spark.network.protocol.response; +package org.apache.spark.network.protocol; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NettyManagedBuffer; -import org.apache.spark.network.protocol.StreamChunkId; /** - * Response to {@link org.apache.spark.network.protocol.request.ChunkFetchRequest} when a chunk - * exists and has been successfully fetched. + * Response to {@link ChunkFetchRequest} when a chunk exists and has been successfully fetched. * * Note that the server-side encoding of this messages does NOT include the buffer itself, as this * may be written by Netty in a more efficient manner (i.e., zero-copy write). @@ -49,7 +47,7 @@ public int encodedLength() { return streamChunkId.encodedLength(); } - /** Encoding does NOT include buffer itself. See {@link MessageEncoder}. */ + /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. 
*/ @Override public void encode(ByteBuf buf) { streamChunkId.encode(buf); diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java b/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java index 363ea5ecfa936..b4e299471b41a 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java @@ -22,6 +22,12 @@ /** * Interface for an object which can be encoded into a ByteBuf. Multiple Encodable objects are * stored in a single, pre-allocated ByteBuf, so Encodables must also provide their length. + * + * Encodable objects should provide a static "decode(ByteBuf)" method which is invoked by + * {@link MessageDecoder}. During decoding, if the object uses the ByteBuf as its data (rather than + * just copying data from it), then you must retain() the ByteBuf. + * + * Additionally, when adding a new Encodable Message, add it to {@link Message.Type}. */ public interface Encodable { /** Number of bytes of the encoded form of this object. */ diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java index 6731b3f53ae82..d568370125fd4 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java @@ -19,7 +19,7 @@ import io.netty.buffer.ByteBuf; -/** Messages from the client to the server. */ +/** An on-the-wire transmittable message. */ public interface Message extends Encodable { /** Used to identify this request type. */ Type type(); diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/MessageDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java similarity index 88% rename from network/common/src/main/java/org/apache/spark/network/protocol/response/MessageDecoder.java rename to network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index 3ae80305803eb..81f8d7f96350f 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/response/MessageDecoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.network.protocol.response; +package org.apache.spark.network.protocol; import java.util.List; @@ -26,10 +26,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.spark.network.protocol.Message; -import org.apache.spark.network.protocol.request.ChunkFetchRequest; -import org.apache.spark.network.protocol.request.RpcRequest; - /** * Decoder used by the client side to encode server-to-client responses. * This encoder is stateless so it is safe to be shared by multiple threads. 
@@ -43,7 +39,7 @@ public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { Message.Type msgType = Message.Type.decode(in); Message decoded = decode(msgType, in); assert decoded.type() == msgType; - logger.debug("Received message " + msgType + ": " + decoded); + logger.trace("Received message " + msgType + ": " + decoded); out.add(decoded); } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java similarity index 96% rename from network/common/src/main/java/org/apache/spark/network/protocol/response/MessageEncoder.java rename to network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java index 5ca8de42a6429..4cb8becc3ed22 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/response/MessageEncoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.network.protocol.response; +package org.apache.spark.network.protocol; import java.util.List; @@ -26,8 +26,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.spark.network.protocol.Message; - /** * Encoder used by the server side to encode server-to-client responses. * This encoder is stateless so it is safe to be shared by multiple threads. diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/RequestMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java similarity index 95% rename from network/common/src/main/java/org/apache/spark/network/protocol/request/RequestMessage.java rename to network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java index 58abce25d9a2a..31b15bb17a327 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/request/RequestMessage.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.network.protocol.request; +package org.apache.spark.network.protocol; import org.apache.spark.network.protocol.Message; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ResponseMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java similarity index 94% rename from network/common/src/main/java/org/apache/spark/network/protocol/response/ResponseMessage.java rename to network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java index 8f545e91d1d8e..6edffd11cf1e2 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/response/ResponseMessage.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.network.protocol.response; +package org.apache.spark.network.protocol; import org.apache.spark.network.protocol.Message; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java similarity index 79% rename from network/common/src/main/java/org/apache/spark/network/protocol/response/RpcFailure.java rename to network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java index 1f161f7957543..e239d4ffbd29c 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java @@ -15,19 +15,19 @@ * limitations under the License. */ -package org.apache.spark.network.protocol.response; +package org.apache.spark.network.protocol; import com.google.common.base.Charsets; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; -/** Response to {@link org.apache.spark.network.protocol.request.RpcRequest} for a failed RPC. */ +/** Response to {@link RpcRequest} for a failed RPC. */ public final class RpcFailure implements ResponseMessage { - public final long tag; + public final long requestId; public final String errorString; - public RpcFailure(long tag, String errorString) { - this.tag = tag; + public RpcFailure(long requestId, String errorString) { + this.requestId = requestId; this.errorString = errorString; } @@ -41,25 +41,25 @@ public int encodedLength() { @Override public void encode(ByteBuf buf) { - buf.writeLong(tag); + buf.writeLong(requestId); byte[] errorBytes = errorString.getBytes(Charsets.UTF_8); buf.writeInt(errorBytes.length); buf.writeBytes(errorBytes); } public static RpcFailure decode(ByteBuf buf) { - long tag = buf.readLong(); + long requestId = buf.readLong(); int numErrorStringBytes = buf.readInt(); byte[] errorBytes = new byte[numErrorStringBytes]; buf.readBytes(errorBytes); - return new RpcFailure(tag, new String(errorBytes, Charsets.UTF_8)); + return new RpcFailure(requestId, new String(errorBytes, Charsets.UTF_8)); } @Override public boolean equals(Object other) { if (other instanceof RpcFailure) { RpcFailure o = (RpcFailure) other; - return tag == o.tag && errorString.equals(o.errorString); + return requestId == o.requestId && errorString.equals(o.errorString); } return false; } @@ -67,7 +67,7 @@ public boolean equals(Object other) { @Override public String toString() { return Objects.toStringHelper(this) - .add("tag", tag) + .add("requestId", requestId) .add("errorString", errorString) .toString(); } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java similarity index 78% rename from network/common/src/main/java/org/apache/spark/network/protocol/request/RpcRequest.java rename to network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java index 810da7a689c13..099e934ae018c 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/request/RpcRequest.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.network.protocol.request; +package org.apache.spark.network.protocol; import java.util.Arrays; @@ -25,17 +25,17 @@ /** * A generic RPC which is handled by a remote {@link org.apache.spark.network.server.RpcHandler}. * This will correspond to a single - * {@link org.apache.spark.network.protocol.response.ResponseMessage} (either success or failure). + * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure). */ public final class RpcRequest implements RequestMessage { - /** Tag is used to link an RPC request with its response. */ - public final long tag; + /** Used to link an RPC request with its response. */ + public final long requestId; /** Serialized message to send to remote RpcHandler. */ public final byte[] message; - public RpcRequest(long tag, byte[] message) { - this.tag = tag; + public RpcRequest(long requestId, byte[] message) { + this.requestId = requestId; this.message = message; } @@ -49,24 +49,24 @@ public int encodedLength() { @Override public void encode(ByteBuf buf) { - buf.writeLong(tag); + buf.writeLong(requestId); buf.writeInt(message.length); buf.writeBytes(message); } public static RpcRequest decode(ByteBuf buf) { - long tag = buf.readLong(); + long requestId = buf.readLong(); int messageLen = buf.readInt(); byte[] message = new byte[messageLen]; buf.readBytes(message); - return new RpcRequest(tag, message); + return new RpcRequest(requestId, message); } @Override public boolean equals(Object other) { if (other instanceof RpcRequest) { RpcRequest o = (RpcRequest) other; - return tag == o.tag && Arrays.equals(message, o.message); + return requestId == o.requestId && Arrays.equals(message, o.message); } return false; } @@ -74,7 +74,7 @@ public boolean equals(Object other) { @Override public String toString() { return Objects.toStringHelper(this) - .add("tag", tag) + .add("requestId", requestId) .add("message", message) .toString(); } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java similarity index 79% rename from network/common/src/main/java/org/apache/spark/network/protocol/response/RpcResponse.java rename to network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java index 40623ce31c666..ed479478325b6 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcResponse.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java @@ -15,20 +15,20 @@ * limitations under the License. */ -package org.apache.spark.network.protocol.response; +package org.apache.spark.network.protocol; import java.util.Arrays; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; -/** Response to {@link org.apache.spark.network.protocol.request.RpcRequest} for a successful RPC. */ +/** Response to {@link RpcRequest} for a successful RPC. 
*/ public final class RpcResponse implements ResponseMessage { - public final long tag; + public final long requestId; public final byte[] response; - public RpcResponse(long tag, byte[] response) { - this.tag = tag; + public RpcResponse(long requestId, byte[] response) { + this.requestId = requestId; this.response = response; } @@ -40,24 +40,24 @@ public RpcResponse(long tag, byte[] response) { @Override public void encode(ByteBuf buf) { - buf.writeLong(tag); + buf.writeLong(requestId); buf.writeInt(response.length); buf.writeBytes(response); } public static RpcResponse decode(ByteBuf buf) { - long tag = buf.readLong(); + long requestId = buf.readLong(); int responseLen = buf.readInt(); byte[] response = new byte[responseLen]; buf.readBytes(response); - return new RpcResponse(tag, response); + return new RpcResponse(requestId, response); } @Override public boolean equals(Object other) { if (other instanceof RpcResponse) { RpcResponse o = (RpcResponse) other; - return tag == o.tag && Arrays.equals(response, o.response); + return requestId == o.requestId && Arrays.equals(response, o.response); } return false; } @@ -65,7 +65,7 @@ public boolean equals(Object other) { @Override public String toString() { return Objects.toStringHelper(this) - .add("tag", tag) + .add("requestId", requestId) .add("response", response) .toString(); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java index d93607a7c31ea..9688705569634 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java +++ b/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java @@ -42,6 +42,8 @@ public class DefaultStreamManager extends StreamManager { private static class StreamState { final Iterator buffers; + // Used to keep track of the index of the buffer that the user has retrieved, just to ensure + // that the caller only requests each chunk one at a time, in order. int curChunk = 0; StreamState(Iterator buffers) { @@ -50,7 +52,8 @@ private static class StreamState { } public DefaultStreamManager() { - // Start with a random stream id to help identifying different streams. + // For debugging purposes, start with a random stream id to help identifying different streams. + // This does not need to be globally unique, only unique to this class. nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000); streams = new ConcurrentHashMap(); } @@ -87,13 +90,15 @@ public void connectionTerminated(long streamId) { } } + /** + * Registers a stream of ManagedBuffers which are served as individual chunks one at a time to + * callers. Each ManagedBuffer will be release()'d after it is transferred on the wire. If a + * client connection is closed before the iterator is fully drained, then the remaining buffers + * will all be release()'d. 
+ */ public long registerStream(Iterator buffers) { long myStreamId = nextStreamId.getAndIncrement(); streams.put(myStreamId, new StreamState(buffers)); return myStreamId; } - - public void unregisterStream(long streamId) { - streams.remove(streamId); - } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportClientHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java similarity index 79% rename from network/common/src/main/java/org/apache/spark/network/server/TransportClientHandler.java rename to network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index 08cc1b1f95de6..e491367fa4528 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportClientHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -25,14 +25,13 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportResponseHandler; import org.apache.spark.network.protocol.Message; -import org.apache.spark.network.protocol.request.RequestMessage; -import org.apache.spark.network.protocol.response.ResponseMessage; +import org.apache.spark.network.protocol.RequestMessage; +import org.apache.spark.network.protocol.ResponseMessage; import org.apache.spark.network.util.NettyUtils; /** - * A handler which is used for delegating requests to the - * {@link TransportRequestHandler} and responses to the - * {@link org.apache.spark.network.client.TransportResponseHandler}. + * The single Transport-level Channel handler which is used for delegating requests to the + * {@link TransportRequestHandler} and responses to the {@link TransportResponseHandler}. * * All channels created in the transport layer are bidirectional. When the Client initiates a Netty * Channel with a RequestMessage (which gets handled by the Server's RequestHandler), the Server @@ -42,14 +41,14 @@ * This means that the Client also needs a RequestHandler and the Server needs a ResponseHandler, * for the Client's responses to the Server's requests. 
*/ -public class TransportClientHandler extends SimpleChannelInboundHandler { - private final Logger logger = LoggerFactory.getLogger(TransportClientHandler.class); +public class TransportChannelHandler extends SimpleChannelInboundHandler { + private final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class); private final TransportClient client; private final TransportResponseHandler responseHandler; private final TransportRequestHandler requestHandler; - public TransportClientHandler( + public TransportChannelHandler( TransportClient client, TransportResponseHandler responseHandler, TransportRequestHandler requestHandler) { @@ -73,8 +72,16 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E @Override public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { - requestHandler.channelUnregistered(); - responseHandler.channelUnregistered(); + try { + requestHandler.channelUnregistered(); + } catch (RuntimeException e) { + logger.error("Exception from request handler while unregistering channel", e); + } + try { + responseHandler.channelUnregistered(); + } catch (RuntimeException e) { + logger.error("Exception from response handler while unregistering channel", e); + } super.channelUnregistered(ctx); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 08a2a3ec52f8b..352f865935b11 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -31,13 +31,13 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.protocol.Encodable; -import org.apache.spark.network.protocol.request.RequestMessage; -import org.apache.spark.network.protocol.request.ChunkFetchRequest; -import org.apache.spark.network.protocol.request.RpcRequest; -import org.apache.spark.network.protocol.response.ChunkFetchFailure; -import org.apache.spark.network.protocol.response.ChunkFetchSuccess; -import org.apache.spark.network.protocol.response.RpcFailure; -import org.apache.spark.network.protocol.response.RpcResponse; +import org.apache.spark.network.protocol.RequestMessage; +import org.apache.spark.network.protocol.ChunkFetchRequest; +import org.apache.spark.network.protocol.RpcRequest; +import org.apache.spark.network.protocol.ChunkFetchFailure; +import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcResponse; import org.apache.spark.network.util.NettyUtils; /** @@ -66,10 +66,10 @@ public class TransportRequestHandler extends MessageHandler { private final Set streamIds; public TransportRequestHandler( - Channel channel, - TransportClient reverseClient, - StreamManager streamManager, - RpcHandler rpcHandler) { + Channel channel, + TransportClient reverseClient, + StreamManager streamManager, + RpcHandler rpcHandler) { this.channel = channel; this.reverseClient = reverseClient; this.streamManager = streamManager; @@ -124,17 +124,17 @@ private void processRpcRequest(final RpcRequest req) { rpcHandler.receive(reverseClient, req.message, new RpcResponseCallback() { @Override public void onSuccess(byte[] response) { - respond(new RpcResponse(req.tag, response)); + respond(new 
RpcResponse(req.requestId, response)); } @Override public void onFailure(Throwable e) { - respond(new RpcFailure(req.tag, Throwables.getStackTraceAsString(e))); + respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); } }); } catch (Exception e) { - logger.error("Error while invoking RpcHandler#receive() on RPC tag " + req.tag, e); - respond(new RpcFailure(req.tag, Throwables.getStackTraceAsString(e))); + logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e); + respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java index 973fb05f57944..243070750d6e7 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -47,7 +47,7 @@ public class TransportServer implements Closeable { private ServerBootstrap bootstrap; private ChannelFuture channelFuture; - private int port; + private int port = -1; public TransportServer(TransportContext context) { this.context = context; @@ -56,7 +56,12 @@ public TransportServer(TransportContext context) { init(); } - public int getPort() { return port; } + public int getPort() { + if (port == -1) { + throw new IllegalStateException("Server not initialized"); + } + return port; + } private void init() { diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index 00ed7b527abd5..738dca9b6a9ee 100644 --- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -45,7 +45,6 @@ import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; public class ChunkFetchIntegrationSuite { diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java index 6932760c44fee..43dc0cf8c7194 100644 --- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -24,14 +24,14 @@ import org.apache.spark.network.protocol.Message; import org.apache.spark.network.protocol.StreamChunkId; -import org.apache.spark.network.protocol.request.ChunkFetchRequest; -import org.apache.spark.network.protocol.request.RpcRequest; -import org.apache.spark.network.protocol.response.ChunkFetchFailure; -import org.apache.spark.network.protocol.response.ChunkFetchSuccess; -import org.apache.spark.network.protocol.response.MessageDecoder; -import org.apache.spark.network.protocol.response.MessageEncoder; -import org.apache.spark.network.protocol.response.RpcFailure; -import org.apache.spark.network.protocol.response.RpcResponse; +import org.apache.spark.network.protocol.ChunkFetchRequest; +import org.apache.spark.network.protocol.ChunkFetchFailure; +import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.RpcRequest; 
+import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcResponse; +import org.apache.spark.network.protocol.MessageDecoder; +import org.apache.spark.network.protocol.MessageEncoder; import org.apache.spark.network.util.NettyUtils; public class ProtocolSuite { diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 19ce9c6a8d826..9f216dd2d722d 100644 --- a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -38,7 +38,6 @@ import org.apache.spark.network.server.DefaultStreamManager; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; -import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; public class RpcIntegrationSuite { diff --git a/network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java b/network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java similarity index 92% rename from network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java rename to network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java index f15ec8d294258..f4e0a2426a3d2 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java +++ b/network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java @@ -15,10 +15,12 @@ * limitations under the License. */ -package org.apache.spark.network.util; +package org.apache.spark.network; import java.util.NoSuchElementException; +import org.apache.spark.network.util.ConfigProvider; + /** Uses System properties to obtain config values. 
*/ public class SystemPropertyConfigProvider extends ConfigProvider { @Override diff --git a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java index 7e7554af70f42..38113a918f795 100644 --- a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java +++ b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java @@ -59,8 +59,8 @@ public ByteBuffer nioByteBuffer() throws IOException { } @Override - public InputStream inputStream() throws IOException { - return underlying.inputStream(); + public InputStream createInputStream() throws IOException { + return underlying.createInputStream(); } @Override diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index f76b4bc55182d..3ef964616f0c5 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -32,7 +32,6 @@ import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.TransportConf; diff --git a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java index 6e360e96099f4..17a03ebe88a93 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -29,11 +29,11 @@ import org.apache.spark.network.client.ChunkReceivedCallback; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportResponseHandler; +import org.apache.spark.network.protocol.ChunkFetchFailure; +import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcResponse; import org.apache.spark.network.protocol.StreamChunkId; -import org.apache.spark.network.protocol.response.ChunkFetchFailure; -import org.apache.spark.network.protocol.response.ChunkFetchSuccess; -import org.apache.spark.network.protocol.response.RpcFailure; -import org.apache.spark.network.protocol.response.RpcResponse; public class TransportResponseHandlerSuite { @Test From 2b0d1c064899429eb115d984308eb18eebe7c9e0 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Tue, 28 Oct 2014 16:17:00 -0700 Subject: [PATCH 43/46] 100ch --- .../spark/network/nio/NioBlockTransferService.scala | 8 ++++---- core/src/main/scala/org/apache/spark/util/Utils.scala | 1 - 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index 489e023c8fb17..11793ea92adb1 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -100,15 +100,15 @@ final class NioBlockTransferService(conf: SparkConf, 
securityManager: SecurityMa // SPARK-4064: In some cases(eg. Remote block was removed) blockMessageArray may be empty. if (blockMessageArray.isEmpty) { blockIds.foreach { id => - listener.onBlockFetchFailure(id, - new SparkException(s"Received empty message from $cmId")) + listener.onBlockFetchFailure(id, new SparkException(s"Received empty message from $cmId")) } } else { for (blockMessage: BlockMessage <- blockMessageArray) { - if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { + val msgType = blockMessage.getType + if (msgType != BlockMessage.TYPE_GOT_BLOCK) { if (blockMessage.getId != null) { listener.onBlockFetchFailure(blockMessage.getId.toString, - new SparkException(s"Unexpected message ${blockMessage.getType} received from $cmId")) + new SparkException(s"Unexpected message $msgType received from $cmId")) } } else { val blockId = blockMessage.getId diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 1e881da5114d3..0daab91143e47 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -43,7 +43,6 @@ import org.json4s._ import tachyon.client.{TachyonFile,TachyonFS} import org.apache.spark._ -import org.apache.spark.util.SparkUncaughtExceptionHandler import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} /** CallSite represents a place in user code. It can have a short and a long form. */ From 4a204b846a8ce2b1cfbab9ed1ec42e8a2f082184 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Tue, 28 Oct 2014 17:59:31 -0700 Subject: [PATCH 44/46] Fail block fetches if client connection fails --- .../network/netty/NettyBlockTransferService.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 501a2d123d456..38a3e945155e8 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -58,8 +58,14 @@ class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService { port: Int, blockIds: Seq[String], listener: BlockFetchingListener): Unit = { - val client = clientFactory.createClient(hostname, port) - new NettyBlockFetcher(serializer, client, blockIds, listener).start() + try { + val client = clientFactory.createClient(hostname, port) + new NettyBlockFetcher(serializer, client, blockIds, listener).start() + } catch { + case e: Exception => + logError("Exception while beginning fetchBlocks", e) + blockIds.foreach(listener.onBlockFetchFailure(_, e)) + } } override def hostName: String = Utils.localHostName() From d7be11b74f6e4ccede5f783742b88ecffeb19add Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Tue, 28 Oct 2014 19:02:35 -0700 Subject: [PATCH 45/46] Turn netty on by default --- core/src/main/scala/org/apache/spark/SparkEnv.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 51031443b0654..5620b6dcdca68 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -273,7 +273,6 @@ object SparkEnv extends Logging { val shuffleMemoryManager = new ShuffleMemoryManager(conf) - // TODO: 
This is only netty by default for initial testing -- it should not be merged as such!!! val blockTransferService = conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match { case "netty" => From cadfd28f116f0dbca11e580a23caf82060bcf922 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Tue, 28 Oct 2014 22:13:31 -0700 Subject: [PATCH 46/46] Turn netty off by default --- core/src/main/scala/org/apache/spark/SparkEnv.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 5620b6dcdca68..6a6dfda363974 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -274,7 +274,7 @@ object SparkEnv extends Logging { val shuffleMemoryManager = new ShuffleMemoryManager(conf) val blockTransferService = - conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match { + conf.get("spark.shuffle.blockTransferService", "nio").toLowerCase match { case "netty" => new NettyBlockTransferService(conf) case "nio" =>
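
The tag -> requestId rename in the RpcRequest, RpcResponse, and RpcFailure hunks above exists so that a client can pair each incoming response with the request that produced it. Below is a minimal, hypothetical Scala sketch of that correlation pattern; it is not Spark's TransportResponseHandler, and RpcCallback / RequestTracker are invented names used only for illustration:

    import java.util.concurrent.ConcurrentHashMap
    import java.util.concurrent.atomic.AtomicLong

    // Hypothetical callback interface, standing in for whatever the real client registers.
    trait RpcCallback {
      def onSuccess(response: Array[Byte]): Unit
      def onFailure(error: String): Unit
    }

    // Hypothetical tracker: hands out fresh request ids and remembers the callback
    // until a matching RpcResponse or RpcFailure arrives.
    class RequestTracker {
      private val nextRequestId = new AtomicLong(0)
      private val outstanding = new ConcurrentHashMap[Long, RpcCallback]()

      def register(callback: RpcCallback): Long = {
        val id = nextRequestId.getAndIncrement()
        outstanding.put(id, callback)
        id
      }

      // Called when a response is decoded: Left is an error string (RpcFailure),
      // Right is the serialized success payload (RpcResponse).
      def complete(requestId: Long, result: Either[String, Array[Byte]]): Unit = {
        val cb = outstanding.remove(requestId)
        if (cb != null) {
          result.fold(cb.onFailure, cb.onSuccess)
        }
      }
    }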
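
The final two commits flip the default value of spark.shuffle.blockTransferService between "netty" and "nio". As a hedged usage sketch, assuming only the standard SparkConf.set API and a hypothetical application name, a user who wants the Netty path regardless of the default could set the key explicitly:

    import org.apache.spark.SparkConf

    // Explicitly select the Netty-based transfer service; "nio" is the default
    // after the last commit in this series.
    val conf = new SparkConf()
      .setAppName("block-transfer-demo") // hypothetical app name
      .set("spark.shuffle.blockTransferService", "netty")

With this key set, the SparkEnv match shown in the hunk above constructs a NettyBlockTransferService; with "nio" it constructs the NIO-based service instead.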