From 83acf3794be1886bd8199d8a45f6ee12141db0e0 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 24 Nov 2015 17:12:59 -0800 Subject: [PATCH 1/2] [SPARK-12007] [network] Avoid copies in the network lib's RPC layer. This change seems large, but most of it is just replacing `byte[]` with `ByteBuffer` and `new byte[]` with `ByteBuffer.allocate()`, since it changes the network library's API. The following are parts of the code that actually have meaningful changes: - The Message implementations were changed to inherit from a new AbstractMessage that can optionally hold a reference to a body (in the form of a ManagedBuffer); this is similar to how ResponseWithBody worked before, except now it's not restricted to just responses. - The TransportFrameDecoder was pretty much rewritten to avoid copies as much as possible; it doesn't rely on CompositeByteBuf to accumulate incoming data anymore, since CompositeByteBuf has issues when slices are retained. The code now is able to create frames without having to resort to copying bytes except for a few bytes (containing the frame length) in very rare cases. - Some minor changes in the SASL layer to convert things back to `byte[]` since the JDK SASL API operates on those. --- .../mesos/MesosExternalShuffleService.scala | 3 +- .../network/netty/NettyBlockRpcServer.scala | 8 +- .../netty/NettyBlockTransferService.scala | 6 +- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 16 +- .../org/apache/spark/rpc/netty/Outbox.scala | 9 +- .../rpc/netty/NettyRpcHandlerSuite.scala | 3 +- .../network/client/RpcResponseCallback.java | 4 +- .../spark/network/client/TransportClient.java | 16 +- .../client/TransportResponseHandler.java | 16 +- .../network/protocol/AbstractMessage.java | 54 +++++++ ...Body.java => AbstractResponseMessage.java} | 16 +- .../network/protocol/ChunkFetchFailure.java | 2 +- .../network/protocol/ChunkFetchRequest.java | 2 +- .../network/protocol/ChunkFetchSuccess.java | 8 +- .../spark/network/protocol/Message.java | 11 +- .../network/protocol/MessageEncoder.java | 29 ++-- .../spark/network/protocol/OneWayMessage.java | 33 ++-- .../spark/network/protocol/RpcFailure.java | 2 +- .../spark/network/protocol/RpcRequest.java | 34 +++-- .../spark/network/protocol/RpcResponse.java | 39 +++-- .../spark/network/protocol/StreamFailure.java | 2 +- .../spark/network/protocol/StreamRequest.java | 2 +- .../network/protocol/StreamResponse.java | 19 +-- .../network/sasl/SaslClientBootstrap.java | 12 +- .../spark/network/sasl/SaslMessage.java | 31 ++-- .../spark/network/sasl/SaslRpcHandler.java | 26 +++- .../spark/network/server/MessageHandler.java | 2 +- .../spark/network/server/NoOpRpcHandler.java | 8 +- .../spark/network/server/RpcHandler.java | 8 +- .../server/TransportChannelHandler.java | 2 +- .../server/TransportRequestHandler.java | 15 +- .../apache/spark/network/util/JavaUtils.java | 48 ++++-- .../network/util/TransportFrameDecoder.java | 142 ++++++++++++------ .../network/ChunkFetchIntegrationSuite.java | 5 +- .../apache/spark/network/ProtocolSuite.java | 10 +- .../RequestTimeoutIntegrationSuite.java | 51 ++++--- .../spark/network/RpcIntegrationSuite.java | 26 ++-- .../org/apache/spark/network/StreamSuite.java | 5 +- .../TransportResponseHandlerSuite.java | 24 +-- .../spark/network/sasl/SparkSaslSuite.java | 43 ++++-- .../util/TransportFrameDecoderSuite.java | 23 ++- .../shuffle/ExternalShuffleBlockHandler.java | 9 +- .../shuffle/ExternalShuffleClient.java | 3 +- .../shuffle/OneForOneBlockFetcher.java | 7 +- .../mesos/MesosExternalShuffleClient.java | 5 +- .../protocol/BlockTransferMessage.java | 8 +- 46 files changed, 563 insertions(+), 284 deletions(-) create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java rename network/common/src/main/java/org/apache/spark/network/protocol/{ResponseWithBody.java => AbstractResponseMessage.java} (63%) diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala index 12337a940a414..8ffcfc0878a42 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.mesos import java.net.SocketAddress +import java.nio.ByteBuffer import scala.collection.mutable @@ -56,7 +57,7 @@ private[mesos] class MesosExternalShuffleBlockHandler(transportConf: TransportCo } } connectedApps(address) = appId - callback.onSuccess(new Array[Byte](0)) + callback.onSuccess(ByteBuffer.allocate(0)) case _ => super.handleMessage(message, client, callback) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 76968249fb625..df8c21fb837ed 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -47,9 +47,9 @@ class NettyBlockRpcServer( override def receive( client: TransportClient, - messageBytes: Array[Byte], + rpcMessage: ByteBuffer, responseContext: RpcResponseCallback): Unit = { - val message = BlockTransferMessage.Decoder.fromByteArray(messageBytes) + val message = BlockTransferMessage.Decoder.fromByteBuffer(rpcMessage) logTrace(s"Received request: $message") message match { @@ -58,7 +58,7 @@ class NettyBlockRpcServer( openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData) val streamId = streamManager.registerStream(appId, blocks.iterator.asJava) logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") - responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray) + responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteBuffer) case uploadBlock: UploadBlock => // StorageLevel is serialized as bytes using our JavaSerializer. @@ -66,7 +66,7 @@ class NettyBlockRpcServer( serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata)) val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData)) blockManager.putBlockData(BlockId(uploadBlock.blockId), data, level) - responseContext.onSuccess(new Array[Byte](0)) + responseContext.onSuccess(ByteBuffer.allocate(0)) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index b0694e3c6c8af..82c16e855b0c0 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -17,6 +17,8 @@ package org.apache.spark.network.netty +import java.nio.ByteBuffer + import scala.collection.JavaConverters._ import scala.concurrent.{Future, Promise} @@ -133,9 +135,9 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage data } - client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteArray, + client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteBuffer, new RpcResponseCallback { - override def onSuccess(response: Array[Byte]): Unit = { + override def onSuccess(response: ByteBuffer): Unit = { logTrace(s"Successfully uploaded block $blockId") result.success((): Unit) } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index c7d74fa1d9195..68c5f44145b0d 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -241,16 +241,14 @@ private[netty] class NettyRpcEnv( promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread) } - private[netty] def serialize(content: Any): Array[Byte] = { - val buffer = javaSerializerInstance.serialize(content) - java.util.Arrays.copyOfRange( - buffer.array(), buffer.arrayOffset + buffer.position, buffer.arrayOffset + buffer.limit) + private[netty] def serialize(content: Any): ByteBuffer = { + javaSerializerInstance.serialize(content) } - private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: Array[Byte]): T = { + private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: ByteBuffer): T = { NettyRpcEnv.currentClient.withValue(client) { deserialize { () => - javaSerializerInstance.deserialize[T](ByteBuffer.wrap(bytes)) + javaSerializerInstance.deserialize[T](bytes) } } } @@ -557,7 +555,7 @@ private[netty] class NettyRpcHandler( override def receive( client: TransportClient, - message: Array[Byte], + message: ByteBuffer, callback: RpcResponseCallback): Unit = { val messageToDispatch = internalReceive(client, message) dispatcher.postRemoteMessage(messageToDispatch, callback) @@ -565,12 +563,12 @@ private[netty] class NettyRpcHandler( override def receive( client: TransportClient, - message: Array[Byte]): Unit = { + message: ByteBuffer): Unit = { val messageToDispatch = internalReceive(client, message) dispatcher.postOneWayMessage(messageToDispatch) } - private def internalReceive(client: TransportClient, message: Array[Byte]): RequestMessage = { + private def internalReceive(client: TransportClient, message: ByteBuffer): RequestMessage = { val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] assert(addr != null) val clientAddr = RpcAddress(addr.getHostName, addr.getPort) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala index 36fdd00bbc4c2..2316ebe347bb7 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala @@ -17,6 +17,7 @@ package org.apache.spark.rpc.netty +import java.nio.ByteBuffer import java.util.concurrent.Callable import javax.annotation.concurrent.GuardedBy @@ -34,7 +35,7 @@ private[netty] sealed trait OutboxMessage { } -private[netty] case class OneWayOutboxMessage(content: Array[Byte]) extends OutboxMessage +private[netty] case class OneWayOutboxMessage(content: ByteBuffer) extends OutboxMessage with Logging { override def sendWith(client: TransportClient): Unit = { @@ -48,9 +49,9 @@ private[netty] case class OneWayOutboxMessage(content: Array[Byte]) extends Outb } private[netty] case class RpcOutboxMessage( - content: Array[Byte], + content: ByteBuffer, _onFailure: (Throwable) => Unit, - _onSuccess: (TransportClient, Array[Byte]) => Unit) + _onSuccess: (TransportClient, ByteBuffer) => Unit) extends OutboxMessage with RpcResponseCallback { private var client: TransportClient = _ @@ -70,7 +71,7 @@ private[netty] case class RpcOutboxMessage( _onFailure(e) } - override def onSuccess(response: Array[Byte]): Unit = { + override def onSuccess(response: ByteBuffer): Unit = { _onSuccess(client, response) } diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index 323184cdd9b6e..ebd6f700710bd 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.rpc.netty import java.net.InetSocketAddress +import java.nio.ByteBuffer import io.netty.channel.Channel import org.mockito.Mockito._ @@ -32,7 +33,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite { val env = mock(classOf[NettyRpcEnv]) val sm = mock(classOf[StreamManager]) - when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any())) + when(env.deserialize(any(classOf[TransportClient]), any(classOf[ByteBuffer]))(any())) .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null)) test("receive") { diff --git a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java index 6ec960d795420..47e93f9846fa6 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java +++ b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java @@ -17,13 +17,15 @@ package org.apache.spark.network.client; +import java.nio.ByteBuffer; + /** * Callback for the result of a single RPC. This will be invoked once with either success or * failure. */ public interface RpcResponseCallback { /** Successful serialized result from server. */ - void onSuccess(byte[] response); + void onSuccess(ByteBuffer response); /** Exception either propagated from server or raised on client side. */ void onFailure(Throwable e); diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index 8a58e7b24585b..c49ca4d5ee925 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -20,6 +20,7 @@ import java.io.Closeable; import java.io.IOException; import java.net.SocketAddress; +import java.nio.ByteBuffer; import java.util.UUID; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; @@ -36,6 +37,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.protocol.ChunkFetchRequest; import org.apache.spark.network.protocol.OneWayMessage; import org.apache.spark.network.protocol.RpcRequest; @@ -212,7 +214,7 @@ public void operationComplete(ChannelFuture future) throws Exception { * @param callback Callback to handle the RPC's reply. * @return The RPC's id. */ - public long sendRpc(byte[] message, final RpcResponseCallback callback) { + public long sendRpc(ByteBuffer message, final RpcResponseCallback callback) { final String serverAddr = NettyUtils.getRemoteAddress(channel); final long startTime = System.currentTimeMillis(); logger.trace("Sending RPC to {}", serverAddr); @@ -220,7 +222,7 @@ public long sendRpc(byte[] message, final RpcResponseCallback callback) { final long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits()); handler.addRpcRequest(requestId, callback); - channel.writeAndFlush(new RpcRequest(requestId, message)).addListener( + channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message))).addListener( new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { @@ -249,12 +251,12 @@ public void operationComplete(ChannelFuture future) throws Exception { * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to * a specified timeout for a response. */ - public byte[] sendRpcSync(byte[] message, long timeoutMs) { - final SettableFuture result = SettableFuture.create(); + public ByteBuffer sendRpcSync(ByteBuffer message, long timeoutMs) { + final SettableFuture result = SettableFuture.create(); sendRpc(message, new RpcResponseCallback() { @Override - public void onSuccess(byte[] response) { + public void onSuccess(ByteBuffer response) { result.set(response); } @@ -279,8 +281,8 @@ public void onFailure(Throwable e) { * * @param message The message to send. */ - public void send(byte[] message) { - channel.writeAndFlush(new OneWayMessage(message)); + public void send(ByteBuffer message) { + channel.writeAndFlush(new OneWayMessage(new NioManagedBuffer(message))); } /** diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 4c15045363b84..23a8dba593442 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -136,7 +136,7 @@ public void exceptionCaught(Throwable cause) { } @Override - public void handle(ResponseMessage message) { + public void handle(ResponseMessage message) throws Exception { String remoteAddress = NettyUtils.getRemoteAddress(channel); if (message instanceof ChunkFetchSuccess) { ChunkFetchSuccess resp = (ChunkFetchSuccess) message; @@ -144,11 +144,11 @@ public void handle(ResponseMessage message) { if (listener == null) { logger.warn("Ignoring response for block {} from {} since it is not outstanding", resp.streamChunkId, remoteAddress); - resp.body.release(); + resp.body().release(); } else { outstandingFetches.remove(resp.streamChunkId); - listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body); - resp.body.release(); + listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body()); + resp.body().release(); } } else if (message instanceof ChunkFetchFailure) { ChunkFetchFailure resp = (ChunkFetchFailure) message; @@ -166,10 +166,14 @@ public void handle(ResponseMessage message) { RpcResponseCallback listener = outstandingRpcs.get(resp.requestId); if (listener == null) { logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding", - resp.requestId, remoteAddress, resp.response.length); + resp.requestId, remoteAddress, resp.body().size()); } else { outstandingRpcs.remove(resp.requestId); - listener.onSuccess(resp.response); + try { + listener.onSuccess(resp.body().nioByteBuffer()); + } finally { + resp.body().release(); + } } } else if (message instanceof RpcFailure) { RpcFailure resp = (RpcFailure) message; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java new file mode 100644 index 0000000000000..2924218c2f08b --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * Abstract class for messages which optionally contain a body kept in a separate buffer. + */ +public abstract class AbstractMessage implements Message { + private final ManagedBuffer body; + private final boolean isBodyInFrame; + + protected AbstractMessage() { + this(null, false); + } + + protected AbstractMessage(ManagedBuffer body, boolean isBodyInFrame) { + this.body = body; + this.isBodyInFrame = isBodyInFrame; + } + + @Override + public ManagedBuffer body() { + return body; + } + + @Override + public boolean isBodyInFrame() { + return isBodyInFrame; + } + + protected boolean equals(AbstractMessage other) { + return isBodyInFrame == other.isBodyInFrame && Objects.equal(body, other.body); + } + +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java similarity index 63% rename from network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java rename to network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java index 67be77e39f711..c362c92fc4f52 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java @@ -17,23 +17,15 @@ package org.apache.spark.network.protocol; -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; - import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NettyManagedBuffer; /** - * Abstract class for response messages that contain a large data portion kept in a separate - * buffer. These messages are treated especially by MessageEncoder. + * Abstract class for response messages. */ -public abstract class ResponseWithBody implements ResponseMessage { - public final ManagedBuffer body; - public final boolean isBodyInFrame; +public abstract class AbstractResponseMessage extends AbstractMessage implements ResponseMessage { - protected ResponseWithBody(ManagedBuffer body, boolean isBodyInFrame) { - this.body = body; - this.isBodyInFrame = isBodyInFrame; + protected AbstractResponseMessage(ManagedBuffer body, boolean isBodyInFrame) { + super(body, isBodyInFrame); } public abstract ResponseMessage createFailureResponse(String error); diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java index f0363830b61ac..7b28a9a969486 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java @@ -23,7 +23,7 @@ /** * Response to {@link ChunkFetchRequest} when there is an error fetching the chunk. */ -public final class ChunkFetchFailure implements ResponseMessage { +public final class ChunkFetchFailure extends AbstractMessage implements ResponseMessage { public final StreamChunkId streamChunkId; public final String errorString; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java index 5a173af54f618..26d063feb5fe3 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java @@ -24,7 +24,7 @@ * Request to fetch a sequence of a single chunk of a stream. This will correspond to a single * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure). */ -public final class ChunkFetchRequest implements RequestMessage { +public final class ChunkFetchRequest extends AbstractMessage implements RequestMessage { public final StreamChunkId streamChunkId; public ChunkFetchRequest(StreamChunkId streamChunkId) { diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java index e6a7e9a8b4145..94c2ac9b20e43 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java @@ -30,7 +30,7 @@ * may be written by Netty in a more efficient manner (i.e., zero-copy write). * Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer. */ -public final class ChunkFetchSuccess extends ResponseWithBody { +public final class ChunkFetchSuccess extends AbstractResponseMessage { public final StreamChunkId streamChunkId; public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) { @@ -67,14 +67,14 @@ public static ChunkFetchSuccess decode(ByteBuf buf) { @Override public int hashCode() { - return Objects.hashCode(streamChunkId, body); + return Objects.hashCode(streamChunkId, body()); } @Override public boolean equals(Object other) { if (other instanceof ChunkFetchSuccess) { ChunkFetchSuccess o = (ChunkFetchSuccess) other; - return streamChunkId.equals(o.streamChunkId) && body.equals(o.body); + return streamChunkId.equals(o.streamChunkId) && super.equals(o); } return false; } @@ -83,7 +83,7 @@ public boolean equals(Object other) { public String toString() { return Objects.toStringHelper(this) .add("streamChunkId", streamChunkId) - .add("buffer", body) + .add("buffer", body()) .toString(); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java index 39afd03db60ee..66f5b8b3a59c8 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java @@ -19,17 +19,25 @@ import io.netty.buffer.ByteBuf; +import org.apache.spark.network.buffer.ManagedBuffer; + /** An on-the-wire transmittable message. */ public interface Message extends Encodable { /** Used to identify this request type. */ Type type(); + /** An optional body for the message. */ + ManagedBuffer body(); + + /** Whether to include the body of the message in the same frame as the message. */ + boolean isBodyInFrame(); + /** Preceding every serialized Message is its type, which allows us to deserialize it. */ public static enum Type implements Encodable { ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2), RpcRequest(3), RpcResponse(4), RpcFailure(5), StreamRequest(6), StreamResponse(7), StreamFailure(8), - OneWayMessage(9); + OneWayMessage(9), User(-1); private final byte id; @@ -57,6 +65,7 @@ public static Type decode(ByteBuf buf) { case 7: return StreamResponse; case 8: return StreamFailure; case 9: return OneWayMessage; + case -1: throw new IllegalArgumentException("User type messages cannot be decoded."); default: throw new IllegalArgumentException("Unknown message type: " + id); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java index 6cce97c807dc0..abca22347b783 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -42,25 +42,28 @@ public final class MessageEncoder extends MessageToMessageEncoder { * data to 'out', in order to enable zero-copy transfer. */ @Override - public void encode(ChannelHandlerContext ctx, Message in, List out) { + public void encode(ChannelHandlerContext ctx, Message in, List out) throws Exception { Object body = null; long bodyLength = 0; boolean isBodyInFrame = false; - // Detect ResponseWithBody messages and get the data buffer out of them. - // The body is used in order to enable zero-copy transfer for the payload. - if (in instanceof ResponseWithBody) { - ResponseWithBody resp = (ResponseWithBody) in; + // If the message has a body, take it out to enable zero-copy transfer for the payload. + if (in.body() != null) { try { - bodyLength = resp.body.size(); - body = resp.body.convertToNetty(); - isBodyInFrame = resp.isBodyInFrame; + bodyLength = in.body().size(); + body = in.body().convertToNetty(); + isBodyInFrame = in.isBodyInFrame(); } catch (Exception e) { - // Re-encode this message as a failure response. - String error = e.getMessage() != null ? e.getMessage() : "null"; - logger.error(String.format("Error processing %s for client %s", - resp, ctx.channel().remoteAddress()), e); - encode(ctx, resp.createFailureResponse(error), out); + if (in instanceof AbstractResponseMessage) { + AbstractResponseMessage resp = (AbstractResponseMessage) in; + // Re-encode this message as a failure response. + String error = e.getMessage() != null ? e.getMessage() : "null"; + logger.error(String.format("Error processing %s for client %s", + in, ctx.channel().remoteAddress()), e); + encode(ctx, resp.createFailureResponse(error), out); + } else { + throw e; + } return; } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java index 95a0270be3da9..efe0470f35875 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java @@ -17,21 +17,21 @@ package org.apache.spark.network.protocol; -import java.util.Arrays; - import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; /** * A RPC that does not expect a reply, which is handled by a remote * {@link org.apache.spark.network.server.RpcHandler}. */ -public final class OneWayMessage implements RequestMessage { - /** Serialized message to send to remote RpcHandler. */ - public final byte[] message; +public final class OneWayMessage extends AbstractMessage implements RequestMessage { - public OneWayMessage(byte[] message) { - this.message = message; + public OneWayMessage(ManagedBuffer body) { + super(body, true); } @Override @@ -39,29 +39,34 @@ public OneWayMessage(byte[] message) { @Override public int encodedLength() { - return Encoders.ByteArrays.encodedLength(message); + // The integer (a.k.a. the body size) is not really used, since that information is already + // encoded in the frame length. But this maintains backwards compatibility with versions of + // RpcRequest that use Encoders.ByteArrays. + return 4; } @Override public void encode(ByteBuf buf) { - Encoders.ByteArrays.encode(buf, message); + // See comment in encodedLength(). + buf.writeInt((int) body().size()); } public static OneWayMessage decode(ByteBuf buf) { - byte[] message = Encoders.ByteArrays.decode(buf); - return new OneWayMessage(message); + // See comment in encodedLength(). + buf.readInt(); + return new OneWayMessage(new NettyManagedBuffer(buf.retain())); } @Override public int hashCode() { - return Arrays.hashCode(message); + return Objects.hashCode(body()); } @Override public boolean equals(Object other) { if (other instanceof OneWayMessage) { OneWayMessage o = (OneWayMessage) other; - return Arrays.equals(message, o.message); + return super.equals(o); } return false; } @@ -69,7 +74,7 @@ public boolean equals(Object other) { @Override public String toString() { return Objects.toStringHelper(this) - .add("message", message) + .add("body", body()) .toString(); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java index 2dfc7876ba328..a76624ef5dc96 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java @@ -21,7 +21,7 @@ import io.netty.buffer.ByteBuf; /** Response to {@link RpcRequest} for a failed RPC. */ -public final class RpcFailure implements ResponseMessage { +public final class RpcFailure extends AbstractMessage implements ResponseMessage { public final long requestId; public final String errorString; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java index 745039db742fa..96213794a8015 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java @@ -17,26 +17,25 @@ package org.apache.spark.network.protocol; -import java.util.Arrays; - import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; /** * A generic RPC which is handled by a remote {@link org.apache.spark.network.server.RpcHandler}. * This will correspond to a single * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure). */ -public final class RpcRequest implements RequestMessage { +public final class RpcRequest extends AbstractMessage implements RequestMessage { /** Used to link an RPC request with its response. */ public final long requestId; - /** Serialized message to send to remote RpcHandler. */ - public final byte[] message; - - public RpcRequest(long requestId, byte[] message) { + public RpcRequest(long requestId, ManagedBuffer message) { + super(message, true); this.requestId = requestId; - this.message = message; } @Override @@ -44,31 +43,36 @@ public RpcRequest(long requestId, byte[] message) { @Override public int encodedLength() { - return 8 + Encoders.ByteArrays.encodedLength(message); + // The integer (a.k.a. the body size) is not really used, since that information is already + // encoded in the frame length. But this maintains backwards compatibility with versions of + // RpcRequest that use Encoders.ByteArrays. + return 8 + 4; } @Override public void encode(ByteBuf buf) { buf.writeLong(requestId); - Encoders.ByteArrays.encode(buf, message); + // See comment in encodedLength(). + buf.writeInt((int) body().size()); } public static RpcRequest decode(ByteBuf buf) { long requestId = buf.readLong(); - byte[] message = Encoders.ByteArrays.decode(buf); - return new RpcRequest(requestId, message); + // See comment in encodedLength(). + buf.readInt(); + return new RpcRequest(requestId, new NettyManagedBuffer(buf.retain())); } @Override public int hashCode() { - return Objects.hashCode(requestId, Arrays.hashCode(message)); + return Objects.hashCode(requestId, body()); } @Override public boolean equals(Object other) { if (other instanceof RpcRequest) { RpcRequest o = (RpcRequest) other; - return requestId == o.requestId && Arrays.equals(message, o.message); + return requestId == o.requestId && super.equals(o); } return false; } @@ -77,7 +81,7 @@ public boolean equals(Object other) { public String toString() { return Objects.toStringHelper(this) .add("requestId", requestId) - .add("message", message) + .add("body", body()) .toString(); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java index 1671cd444f039..bae866e14a1e1 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java @@ -17,49 +17,62 @@ package org.apache.spark.network.protocol; -import java.util.Arrays; - import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; /** Response to {@link RpcRequest} for a successful RPC. */ -public final class RpcResponse implements ResponseMessage { +public final class RpcResponse extends AbstractResponseMessage { public final long requestId; - public final byte[] response; - public RpcResponse(long requestId, byte[] response) { + public RpcResponse(long requestId, ManagedBuffer message) { + super(message, true); this.requestId = requestId; - this.response = response; } @Override public Type type() { return Type.RpcResponse; } @Override - public int encodedLength() { return 8 + Encoders.ByteArrays.encodedLength(response); } + public int encodedLength() { + // The integer (a.k.a. the body size) is not really used, since that information is already + // encoded in the frame length. But this maintains backwards compatibility with versions of + // RpcRequest that use Encoders.ByteArrays. + return 8 + 4; + } @Override public void encode(ByteBuf buf) { buf.writeLong(requestId); - Encoders.ByteArrays.encode(buf, response); + // See comment in encodedLength(). + buf.writeInt((int) body().size()); + } + + @Override + public ResponseMessage createFailureResponse(String error) { + return new RpcFailure(requestId, error); } public static RpcResponse decode(ByteBuf buf) { long requestId = buf.readLong(); - byte[] response = Encoders.ByteArrays.decode(buf); - return new RpcResponse(requestId, response); + // See comment in encodedLength(). + buf.readInt(); + return new RpcResponse(requestId, new NettyManagedBuffer(buf.retain())); } @Override public int hashCode() { - return Objects.hashCode(requestId, Arrays.hashCode(response)); + return Objects.hashCode(requestId, body()); } @Override public boolean equals(Object other) { if (other instanceof RpcResponse) { RpcResponse o = (RpcResponse) other; - return requestId == o.requestId && Arrays.equals(response, o.response); + return requestId == o.requestId && super.equals(o); } return false; } @@ -68,7 +81,7 @@ public boolean equals(Object other) { public String toString() { return Objects.toStringHelper(this) .add("requestId", requestId) - .add("response", response) + .add("body", body()) .toString(); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java index e3dade2ebf905..26747ee55b4de 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java @@ -26,7 +26,7 @@ /** * Message indicating an error when transferring a stream. */ -public final class StreamFailure implements ResponseMessage { +public final class StreamFailure extends AbstractMessage implements ResponseMessage { public final String streamId; public final String error; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java index 821e8f53884d7..35af5a84ba6bd 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java @@ -29,7 +29,7 @@ * The stream ID is an arbitrary string that needs to be negotiated between the two endpoints before * the data can be streamed. */ -public final class StreamRequest implements RequestMessage { +public final class StreamRequest extends AbstractMessage implements RequestMessage { public final String streamId; public StreamRequest(String streamId) { diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java index ac5ab9a323a11..51b899930f721 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java @@ -30,15 +30,15 @@ * sender. The receiver is expected to set a temporary channel handler that will consume the * number of bytes this message says the stream has. */ -public final class StreamResponse extends ResponseWithBody { - public final String streamId; - public final long byteCount; +public final class StreamResponse extends AbstractResponseMessage { + public final String streamId; + public final long byteCount; - public StreamResponse(String streamId, long byteCount, ManagedBuffer buffer) { - super(buffer, false); - this.streamId = streamId; - this.byteCount = byteCount; - } + public StreamResponse(String streamId, long byteCount, ManagedBuffer buffer) { + super(buffer, false); + this.streamId = streamId; + this.byteCount = byteCount; + } @Override public Type type() { return Type.StreamResponse; } @@ -68,7 +68,7 @@ public static StreamResponse decode(ByteBuf buf) { @Override public int hashCode() { - return Objects.hashCode(byteCount, streamId); + return Objects.hashCode(byteCount, streamId, body()); } @Override @@ -85,6 +85,7 @@ public String toString() { return Objects.toStringHelper(this) .add("streamId", streamId) .add("byteCount", byteCount) + .add("body", body()) .toString(); } diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java index 69923769d44b4..68381037d6891 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -17,6 +17,8 @@ package org.apache.spark.network.sasl; +import java.io.IOException; +import java.nio.ByteBuffer; import javax.security.sasl.Sasl; import javax.security.sasl.SaslException; @@ -28,6 +30,7 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.TransportConf; /** @@ -70,11 +73,12 @@ public void doBootstrap(TransportClient client, Channel channel) { while (!saslClient.isComplete()) { SaslMessage msg = new SaslMessage(appId, payload); - ByteBuf buf = Unpooled.buffer(msg.encodedLength()); + ByteBuf buf = Unpooled.buffer(msg.encodedLength() + (int) msg.body().size()); msg.encode(buf); + buf.writeBytes(msg.body().nioByteBuffer()); - byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeoutMs()); - payload = saslClient.response(response); + ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.saslRTTimeoutMs()); + payload = saslClient.response(JavaUtils.bufferToArray(response)); } client.setClientId(appId); @@ -88,6 +92,8 @@ public void doBootstrap(TransportClient client, Channel channel) { saslClient = null; logger.debug("Channel {} configured for SASL encryption.", client); } + } catch (IOException ioe) { + throw new RuntimeException(ioe); } finally { if (saslClient != null) { try { diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java index cad76ab7aa54e..e52b526f09c77 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java @@ -18,38 +18,50 @@ package org.apache.spark.network.sasl; import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; -import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.buffer.NettyManagedBuffer; import org.apache.spark.network.protocol.Encoders; +import org.apache.spark.network.protocol.AbstractMessage; /** * Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged * with the given appId. This appId allows a single SaslRpcHandler to multiplex different * applications which may be using different sets of credentials. */ -class SaslMessage implements Encodable { +class SaslMessage extends AbstractMessage { /** Serialization tag used to catch incorrect payloads. */ private static final byte TAG_BYTE = (byte) 0xEA; public final String appId; - public final byte[] payload; - public SaslMessage(String appId, byte[] payload) { + public SaslMessage(String appId, byte[] message) { + this(appId, Unpooled.wrappedBuffer(message)); + } + + public SaslMessage(String appId, ByteBuf message) { + super(new NettyManagedBuffer(message), true); this.appId = appId; - this.payload = payload; } + @Override + public Type type() { return Type.User; } + @Override public int encodedLength() { - return 1 + Encoders.Strings.encodedLength(appId) + Encoders.ByteArrays.encodedLength(payload); + // The integer (a.k.a. the body size) is not really used, since that information is already + // encoded in the frame length. But this maintains backwards compatibility with versions of + // RpcRequest that use Encoders.ByteArrays. + return 1 + Encoders.Strings.encodedLength(appId) + 4; } @Override public void encode(ByteBuf buf) { buf.writeByte(TAG_BYTE); Encoders.Strings.encode(buf, appId); - Encoders.ByteArrays.encode(buf, payload); + // See comment in encodedLength(). + buf.writeInt((int) body().size()); } public static SaslMessage decode(ByteBuf buf) { @@ -59,7 +71,8 @@ public static SaslMessage decode(ByteBuf buf) { } String appId = Encoders.Strings.decode(buf); - byte[] payload = Encoders.ByteArrays.decode(buf); - return new SaslMessage(appId, payload); + // See comment in encodedLength(). + buf.readInt(); + return new SaslMessage(appId, buf.retain()); } } diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index 830db94b890c5..c215bd9d15045 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -17,8 +17,11 @@ package org.apache.spark.network.sasl; +import java.io.IOException; +import java.nio.ByteBuffer; import javax.security.sasl.Sasl; +import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; import org.slf4j.Logger; @@ -28,6 +31,7 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.TransportConf; /** @@ -70,14 +74,20 @@ class SaslRpcHandler extends RpcHandler { } @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { if (isComplete) { // Authentication complete, delegate to base handler. delegate.receive(client, message, callback); return; } - SaslMessage saslMessage = SaslMessage.decode(Unpooled.wrappedBuffer(message)); + ByteBuf nettyBuf = Unpooled.wrappedBuffer(message); + SaslMessage saslMessage; + try { + saslMessage = SaslMessage.decode(nettyBuf); + } finally { + nettyBuf.release(); + } if (saslServer == null) { // First message in the handshake, setup the necessary state. @@ -86,8 +96,14 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback conf.saslServerAlwaysEncrypt()); } - byte[] response = saslServer.response(saslMessage.payload); - callback.onSuccess(response); + byte[] response; + try { + response = saslServer.response(JavaUtils.bufferToArray( + saslMessage.body().nioByteBuffer())); + } catch (IOException ioe) { + throw new RuntimeException(ioe); + } + callback.onSuccess(ByteBuffer.wrap(response)); // Setup encryption after the SASL response is sent, otherwise the client can't parse the // response. It's ok to change the channel pipeline here since we are processing an incoming @@ -109,7 +125,7 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback } @Override - public void receive(TransportClient client, byte[] message) { + public void receive(TransportClient client, ByteBuffer message) { delegate.receive(client, message); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java index b80c15106ecbd..3843406b27403 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java @@ -26,7 +26,7 @@ */ public abstract class MessageHandler { /** Handles the receipt of a single message. */ - public abstract void handle(T message); + public abstract void handle(T message) throws Exception; /** Invoked when an exception was caught on the Channel. */ public abstract void exceptionCaught(Throwable cause); diff --git a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java index 1502b7489e864..6ed61da5c7eff 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java @@ -1,5 +1,3 @@ -package org.apache.spark.network.server; - /* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with @@ -17,6 +15,10 @@ * limitations under the License. */ +package org.apache.spark.network.server; + +import java.nio.ByteBuffer; + import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; @@ -29,7 +31,7 @@ public NoOpRpcHandler() { } @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { throw new UnsupportedOperationException("Cannot handle messages"); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java index 65109ddfe13b9..88a92800c5ed8 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -17,6 +17,8 @@ package org.apache.spark.network.server; +import java.nio.ByteBuffer; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -44,7 +46,7 @@ public abstract class RpcHandler { */ public abstract void receive( TransportClient client, - byte[] message, + ByteBuffer message, RpcResponseCallback callback); /** @@ -62,7 +64,7 @@ public abstract void receive( * of this RPC. This will always be the exact same object for a particular channel. * @param message The serialized bytes of the RPC. */ - public void receive(TransportClient client, byte[] message) { + public void receive(TransportClient client, ByteBuffer message) { receive(client, message, ONE_WAY_CALLBACK); } @@ -79,7 +81,7 @@ private static class OneWayRpcCallback implements RpcResponseCallback { private final Logger logger = LoggerFactory.getLogger(OneWayRpcCallback.class); @Override - public void onSuccess(byte[] response) { + public void onSuccess(ByteBuffer response) { logger.warn("Response provided for one-way RPC."); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index 3164e00679035..09435bcbab35e 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -99,7 +99,7 @@ public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { } @Override - public void channelRead0(ChannelHandlerContext ctx, Message request) { + public void channelRead0(ChannelHandlerContext ctx, Message request) throws Exception { if (request instanceof RequestMessage) { requestHandler.handle((RequestMessage) request); } else { diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index db18ea77d1073..c864d7ce16bd3 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -17,6 +17,8 @@ package org.apache.spark.network.server; +import java.nio.ByteBuffer; + import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import io.netty.channel.Channel; @@ -26,6 +28,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.protocol.ChunkFetchRequest; @@ -143,10 +146,10 @@ private void processStreamRequest(final StreamRequest req) { private void processRpcRequest(final RpcRequest req) { try { - rpcHandler.receive(reverseClient, req.message, new RpcResponseCallback() { + rpcHandler.receive(reverseClient, req.body().nioByteBuffer(), new RpcResponseCallback() { @Override - public void onSuccess(byte[] response) { - respond(new RpcResponse(req.requestId, response)); + public void onSuccess(ByteBuffer response) { + respond(new RpcResponse(req.requestId, new NioManagedBuffer(response))); } @Override @@ -157,14 +160,18 @@ public void onFailure(Throwable e) { } catch (Exception e) { logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e); respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); + } finally { + req.body().release(); } } private void processOneWayMessage(OneWayMessage req) { try { - rpcHandler.receive(reverseClient, req.message); + rpcHandler.receive(reverseClient, req.body().nioByteBuffer()); } catch (Exception e) { logger.error("Error while invoking RpcHandler#receive() for one-way message.", e); + } finally { + req.body().release(); } } diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index 7d27439cfde7a..b3d8e0cd7cdcd 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -132,7 +132,7 @@ private static boolean isSymlink(File file) throws IOException { return !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile()); } - private static final ImmutableMap timeSuffixes = + private static final ImmutableMap timeSuffixes = ImmutableMap.builder() .put("us", TimeUnit.MICROSECONDS) .put("ms", TimeUnit.MILLISECONDS) @@ -164,32 +164,32 @@ private static boolean isSymlink(File file) throws IOException { */ private static long parseTimeString(String str, TimeUnit unit) { String lower = str.toLowerCase().trim(); - + try { Matcher m = Pattern.compile("(-?[0-9]+)([a-z]+)?").matcher(lower); if (!m.matches()) { throw new NumberFormatException("Failed to parse time string: " + str); } - + long val = Long.parseLong(m.group(1)); String suffix = m.group(2); - + // Check for invalid suffixes if (suffix != null && !timeSuffixes.containsKey(suffix)) { throw new NumberFormatException("Invalid suffix: \"" + suffix + "\""); } - + // If suffix is valid use that, otherwise none was provided and use the default passed return unit.convert(val, suffix != null ? timeSuffixes.get(suffix) : unit); } catch (NumberFormatException e) { String timeError = "Time must be specified as seconds (s), " + "milliseconds (ms), microseconds (us), minutes (m or min), hour (h), or day (d). " + "E.g. 50s, 100ms, or 250us."; - + throw new NumberFormatException(timeError + "\n" + e.getMessage()); } } - + /** * Convert a time parameter such as (50s, 100ms, or 250us) to milliseconds for internal use. If * no suffix is provided, the passed number is assumed to be in ms. @@ -205,10 +205,10 @@ public static long timeStringAsMs(String str) { public static long timeStringAsSec(String str) { return parseTimeString(str, TimeUnit.SECONDS); } - + /** * Convert a passed byte string (e.g. 50b, 100kb, or 250mb) to a ByteUnit for - * internal use. If no suffix is provided a direct conversion of the provided default is + * internal use. If no suffix is provided a direct conversion of the provided default is * attempted. */ private static long parseByteString(String str, ByteUnit unit) { @@ -217,7 +217,7 @@ private static long parseByteString(String str, ByteUnit unit) { try { Matcher m = Pattern.compile("([0-9]+)([a-z]+)?").matcher(lower); Matcher fractionMatcher = Pattern.compile("([0-9]+\\.[0-9]+)([a-z]+)?").matcher(lower); - + if (m.matches()) { long val = Long.parseLong(m.group(1)); String suffix = m.group(2); @@ -228,14 +228,14 @@ private static long parseByteString(String str, ByteUnit unit) { } // If suffix is valid use that, otherwise none was provided and use the default passed - return unit.convertFrom(val, suffix != null ? byteSuffixes.get(suffix) : unit); + return unit.convertFrom(val, suffix != null ? byteSuffixes.get(suffix) : unit); } else if (fractionMatcher.matches()) { - throw new NumberFormatException("Fractional values are not supported. Input was: " + throw new NumberFormatException("Fractional values are not supported. Input was: " + fractionMatcher.group(1)); } else { - throw new NumberFormatException("Failed to parse byte string: " + str); + throw new NumberFormatException("Failed to parse byte string: " + str); } - + } catch (NumberFormatException e) { String timeError = "Size must be specified as bytes (b), " + "kibibytes (k), mebibytes (m), gibibytes (g), tebibytes (t), or pebibytes(p). " + @@ -248,7 +248,7 @@ private static long parseByteString(String str, ByteUnit unit) { /** * Convert a passed byte string (e.g. 50b, 100k, or 250m) to bytes for * internal use. - * + * * If no suffix is provided, the passed number is assumed to be in bytes. */ public static long byteStringAsBytes(String str) { @@ -264,7 +264,7 @@ public static long byteStringAsBytes(String str) { public static long byteStringAsKb(String str) { return parseByteString(str, ByteUnit.KiB); } - + /** * Convert a passed byte string (e.g. 50b, 100k, or 250m) to mebibytes for * internal use. @@ -284,4 +284,20 @@ public static long byteStringAsMb(String str) { public static long byteStringAsGb(String str) { return parseByteString(str, ByteUnit.GiB); } + + /** + * Returns a byte array with the buffer's contents, trying to avoid copying the data if + * possible. + */ + public static byte[] bufferToArray(ByteBuffer buffer) { + if (buffer.hasArray() && buffer.arrayOffset() == 0 && + buffer.array().length == buffer.remaining()) { + return buffer.array(); + } else { + byte[] bytes = new byte[buffer.remaining()]; + buffer.get(bytes); + return bytes; + } + } + } diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java index 5889562dd9705..a466c729154aa 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java @@ -17,9 +17,13 @@ package org.apache.spark.network.util; +import java.util.Iterator; +import java.util.LinkedList; + import com.google.common.base.Preconditions; import io.netty.buffer.ByteBuf; import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; @@ -44,84 +48,138 @@ public class TransportFrameDecoder extends ChannelInboundHandlerAdapter { public static final String HANDLER_NAME = "frameDecoder"; private static final int LENGTH_SIZE = 8; private static final int MAX_FRAME_SIZE = Integer.MAX_VALUE; + private static final int UNKNOWN_FRAME_SIZE = -1; + + private final LinkedList buffers = new LinkedList<>(); + private final ByteBuf frameLenBuf = Unpooled.buffer(LENGTH_SIZE, LENGTH_SIZE); - private CompositeByteBuf buffer; + private long totalSize = 0; + private long nextFrameSize = UNKNOWN_FRAME_SIZE; private volatile Interceptor interceptor; @Override public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception { ByteBuf in = (ByteBuf) data; + buffers.add(in); + totalSize += in.readableBytes(); + + while (!buffers.isEmpty()) { + // First, feed the interceptor, and if it's still, active, try again. + if (interceptor != null) { + ByteBuf first = buffers.getFirst(); + int available = first.readableBytes(); + if (feedInterceptor(first)) { + assert !first.isReadable() : "Interceptor still active but buffer has data."; + } - if (buffer == null) { - buffer = in.alloc().compositeBuffer(); - } - - buffer.addComponent(in).writerIndex(buffer.writerIndex() + in.readableBytes()); - - while (buffer.isReadable()) { - discardReadBytes(); - if (!feedInterceptor()) { + int read = available - first.readableBytes(); + if (read == available) { + buffers.removeFirst().release(); + } + totalSize -= read; + } else { + // Interceptor is not active, so try to decode one frame. ByteBuf frame = decodeNext(); if (frame == null) { break; } - ctx.fireChannelRead(frame); } } - - discardReadBytes(); } - private void discardReadBytes() { - // If the buffer's been retained by downstream code, then make a copy of the remaining - // bytes into a new buffer. Otherwise, just discard stale components. - if (buffer.refCnt() > 1) { - CompositeByteBuf newBuffer = buffer.alloc().compositeBuffer(); + private long decodeFrameSize() { + if (nextFrameSize != UNKNOWN_FRAME_SIZE || totalSize < LENGTH_SIZE) { + return nextFrameSize; + } - if (buffer.readableBytes() > 0) { - ByteBuf spillBuf = buffer.alloc().buffer(buffer.readableBytes()); - spillBuf.writeBytes(buffer); - newBuffer.addComponent(spillBuf).writerIndex(spillBuf.readableBytes()); + // We know there's enough data. If the first buffer contains all the data, great. Otherwise, + // hold the bytes for the frame length in a composite buffer until we have enough data to read + // the frame size. Normally, it should be rare to need more than one buffer to read the frame + // size. + ByteBuf first = buffers.getFirst(); + if (first.readableBytes() >= LENGTH_SIZE) { + nextFrameSize = first.readLong() - LENGTH_SIZE; + totalSize -= LENGTH_SIZE; + if (!first.isReadable()) { + buffers.removeFirst().release(); } + return nextFrameSize; + } - buffer.release(); - buffer = newBuffer; - } else { - buffer.discardReadComponents(); + while (frameLenBuf.readableBytes() < LENGTH_SIZE) { + ByteBuf next = buffers.getFirst(); + int toRead = Math.min(next.readableBytes(), LENGTH_SIZE - frameLenBuf.readableBytes()); + frameLenBuf.writeBytes(next, toRead); + if (!next.isReadable()) { + buffers.removeFirst().release(); + } } + + nextFrameSize = frameLenBuf.readLong() - LENGTH_SIZE; + totalSize -= LENGTH_SIZE; + frameLenBuf.clear(); + return nextFrameSize; } private ByteBuf decodeNext() throws Exception { - if (buffer.readableBytes() < LENGTH_SIZE) { + long frameSize = decodeFrameSize(); + if (frameSize == UNKNOWN_FRAME_SIZE || totalSize < frameSize) { return null; } - int frameLen = (int) buffer.readLong() - LENGTH_SIZE; - if (buffer.readableBytes() < frameLen) { - buffer.readerIndex(buffer.readerIndex() - LENGTH_SIZE); - return null; + // Reset size for next frame. + nextFrameSize = UNKNOWN_FRAME_SIZE; + + Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE, "Too large frame: %s", frameSize); + Preconditions.checkArgument(frameSize > 0, "Frame length should be positive: %s", frameSize); + + // If the first buffer holds the entire frame, return it. + int remaining = (int) frameSize; + if (buffers.getFirst().readableBytes() >= remaining) { + return nextBufferForFrame(remaining); } - Preconditions.checkArgument(frameLen < MAX_FRAME_SIZE, "Too large frame: %s", frameLen); - Preconditions.checkArgument(frameLen > 0, "Frame length should be positive: %s", frameLen); + // Otherwise, create a composite buffer. + CompositeByteBuf frame = buffers.getFirst().alloc().compositeBuffer(); + while (remaining > 0) { + ByteBuf next = nextBufferForFrame(remaining); + remaining -= next.readableBytes(); + frame.addComponent(next).writerIndex(frame.writerIndex() + next.readableBytes()); + } + assert remaining == 0; + return frame; + } + + /** + * Takes the first buffer in the internal list, and either adjust it to fit in the frame + * (by taking a slice out of it) or remove it from the internal list. + */ + private ByteBuf nextBufferForFrame(int bytesToRead) { + ByteBuf buf = buffers.getFirst(); + ByteBuf frame; + + if (buf.readableBytes() > bytesToRead) { + frame = buf.retain().readSlice(bytesToRead); + totalSize -= bytesToRead; + } else { + frame = buf; + buffers.removeFirst(); + totalSize -= frame.readableBytes(); + } - ByteBuf frame = buffer.readSlice(frameLen); - frame.retain(); return frame; } @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { - if (buffer != null) { - if (buffer.isReadable()) { - feedInterceptor(); - } - buffer.release(); + for (ByteBuf b : buffers) { + b.release(); } if (interceptor != null) { interceptor.channelInactive(); } + frameLenBuf.release(); super.channelInactive(ctx); } @@ -141,8 +199,8 @@ public void setInterceptor(Interceptor interceptor) { /** * @return Whether the interceptor is still active after processing the data. */ - private boolean feedInterceptor() throws Exception { - if (interceptor != null && !interceptor.handle(buffer)) { + private boolean feedInterceptor(ByteBuf buf) throws Exception { + if (interceptor != null && !interceptor.handle(buf)) { interceptor = null; } return interceptor != null; diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index 50a324e293386..70c849d60e0a6 100644 --- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -107,7 +107,10 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { }; RpcHandler handler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { throw new UnsupportedOperationException(); } diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java index 1aa20900ffe74..6c8dd742f4b64 100644 --- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -82,10 +82,10 @@ private void testClientToServer(Message msg) { @Test public void requests() { testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2))); - testClientToServer(new RpcRequest(12345, new byte[0])); - testClientToServer(new RpcRequest(12345, new byte[100])); + testClientToServer(new RpcRequest(12345, new TestManagedBuffer(0))); + testClientToServer(new RpcRequest(12345, new TestManagedBuffer(10))); testClientToServer(new StreamRequest("abcde")); - testClientToServer(new OneWayMessage(new byte[100])); + testClientToServer(new OneWayMessage(new TestManagedBuffer(10))); } @Test @@ -94,8 +94,8 @@ public void responses() { testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0))); testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "this is an error")); testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "")); - testServerToClient(new RpcResponse(12345, new byte[0])); - testServerToClient(new RpcResponse(12345, new byte[1000])); + testServerToClient(new RpcResponse(12345, new TestManagedBuffer(0))); + testServerToClient(new RpcResponse(12345, new TestManagedBuffer(100))); testServerToClient(new RpcFailure(0, "this is an error")); testServerToClient(new RpcFailure(0, "")); // Note: buffer size must be "0" since StreamResponse's buffer is written differently to the diff --git a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java index 42955ef69235a..f9b5bf96d6215 100644 --- a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java @@ -31,6 +31,7 @@ import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; import org.junit.*; +import static org.junit.Assert.*; import java.io.IOException; import java.nio.ByteBuffer; @@ -84,13 +85,16 @@ public void tearDown() { @Test public void timeoutInactiveRequests() throws Exception { final Semaphore semaphore = new Semaphore(1); - final byte[] response = new byte[16]; + final int responseSize = 16; RpcHandler handler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { try { semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS); - callback.onSuccess(response); + callback.onSuccess(ByteBuffer.allocate(responseSize)); } catch (InterruptedException e) { // do nothing } @@ -110,15 +114,15 @@ public StreamManager getStreamManager() { // First completes quickly (semaphore starts at 1). TestCallback callback0 = new TestCallback(); synchronized (callback0) { - client.sendRpc(new byte[0], callback0); + client.sendRpc(ByteBuffer.allocate(0), callback0); callback0.wait(FOREVER); - assert (callback0.success.length == response.length); + assertEquals(responseSize, callback0.successLength); } // Second times out after 2 seconds, with slack. Must be IOException. TestCallback callback1 = new TestCallback(); synchronized (callback1) { - client.sendRpc(new byte[0], callback1); + client.sendRpc(ByteBuffer.allocate(0), callback1); callback1.wait(4 * 1000); assert (callback1.failure != null); assert (callback1.failure instanceof IOException); @@ -131,13 +135,16 @@ public StreamManager getStreamManager() { @Test public void timeoutCleanlyClosesClient() throws Exception { final Semaphore semaphore = new Semaphore(0); - final byte[] response = new byte[16]; + final int responseSize = 16; RpcHandler handler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { try { semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS); - callback.onSuccess(response); + callback.onSuccess(ByteBuffer.allocate(responseSize)); } catch (InterruptedException e) { // do nothing } @@ -158,7 +165,7 @@ public StreamManager getStreamManager() { clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); TestCallback callback0 = new TestCallback(); synchronized (callback0) { - client0.sendRpc(new byte[0], callback0); + client0.sendRpc(ByteBuffer.allocate(0), callback0); callback0.wait(FOREVER); assert (callback0.failure instanceof IOException); assert (!client0.isActive()); @@ -170,10 +177,10 @@ public StreamManager getStreamManager() { clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); TestCallback callback1 = new TestCallback(); synchronized (callback1) { - client1.sendRpc(new byte[0], callback1); + client1.sendRpc(ByteBuffer.allocate(0), callback1); callback1.wait(FOREVER); - assert (callback1.success.length == response.length); - assert (callback1.failure == null); + assertEquals(responseSize, callback1.successLength); + assertNull(callback1.failure); } } @@ -191,7 +198,10 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { }; RpcHandler handler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { throw new UnsupportedOperationException(); } @@ -218,9 +228,10 @@ public StreamManager getStreamManager() { synchronized (callback0) { // not complete yet, but should complete soon - assert (callback0.success == null && callback0.failure == null); + assertEquals(-1, callback0.successLength); + assertNull(callback0.failure); callback0.wait(2 * 1000); - assert (callback0.failure instanceof IOException); + assertTrue(callback0.failure instanceof IOException); } synchronized (callback1) { @@ -235,13 +246,13 @@ public StreamManager getStreamManager() { */ class TestCallback implements RpcResponseCallback, ChunkReceivedCallback { - byte[] success; + int successLength = -1; Throwable failure; @Override - public void onSuccess(byte[] response) { + public void onSuccess(ByteBuffer response) { synchronized(this) { - success = response; + successLength = response.remaining(); this.notifyAll(); } } @@ -258,7 +269,7 @@ public void onFailure(Throwable e) { public void onSuccess(int chunkIndex, ManagedBuffer buffer) { synchronized(this) { try { - success = buffer.nioByteBuffer().array(); + successLength = buffer.nioByteBuffer().remaining(); this.notifyAll(); } catch (IOException e) { // weird diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 88fa2258bb794..9e9be98c140b7 100644 --- a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.network; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; @@ -26,7 +27,6 @@ import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; -import com.google.common.base.Charsets; import com.google.common.collect.Sets; import org.junit.AfterClass; import org.junit.BeforeClass; @@ -41,6 +41,7 @@ import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; @@ -55,11 +56,14 @@ public static void setUp() throws Exception { TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); rpcHandler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { - String msg = new String(message, Charsets.UTF_8); + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { + String msg = JavaUtils.bytesToString(message); String[] parts = msg.split("/"); if (parts[0].equals("hello")) { - callback.onSuccess(("Hello, " + parts[1] + "!").getBytes(Charsets.UTF_8)); + callback.onSuccess(JavaUtils.stringToBytes("Hello, " + parts[1] + "!")); } else if (parts[0].equals("return error")) { callback.onFailure(new RuntimeException("Returned: " + parts[1])); } else if (parts[0].equals("throw error")) { @@ -68,9 +72,8 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback } @Override - public void receive(TransportClient client, byte[] message) { - String msg = new String(message, Charsets.UTF_8); - oneWayMsgs.add(msg); + public void receive(TransportClient client, ByteBuffer message) { + oneWayMsgs.add(JavaUtils.bytesToString(message)); } @Override @@ -103,8 +106,9 @@ private RpcResult sendRPC(String ... commands) throws Exception { RpcResponseCallback callback = new RpcResponseCallback() { @Override - public void onSuccess(byte[] message) { - res.successMessages.add(new String(message, Charsets.UTF_8)); + public void onSuccess(ByteBuffer message) { + String response = JavaUtils.bytesToString(message); + res.successMessages.add(response); sem.release(); } @@ -116,7 +120,7 @@ public void onFailure(Throwable e) { }; for (String command : commands) { - client.sendRpc(command.getBytes(Charsets.UTF_8), callback); + client.sendRpc(JavaUtils.stringToBytes(command), callback); } if (!sem.tryAcquire(commands.length, 5, TimeUnit.SECONDS)) { @@ -173,7 +177,7 @@ public void sendOneWayMessage() throws Exception { final String message = "no reply"; TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); try { - client.send(message.getBytes(Charsets.UTF_8)); + client.send(JavaUtils.stringToBytes(message)); assertEquals(0, client.getHandler().numOutstandingRequests()); // Make sure the message arrives. diff --git a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java index 538f3efe8d6f2..9c49556927f0b 100644 --- a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -116,7 +116,10 @@ public ManagedBuffer openStream(String streamId) { }; RpcHandler handler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { throw new UnsupportedOperationException(); } diff --git a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java index 30144f4a9fc7a..128f7cba74350 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -17,6 +17,8 @@ package org.apache.spark.network; +import java.nio.ByteBuffer; + import io.netty.channel.Channel; import io.netty.channel.local.LocalChannel; import org.junit.Test; @@ -27,6 +29,7 @@ import static org.mockito.Mockito.*; import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.StreamCallback; @@ -42,7 +45,7 @@ public class TransportResponseHandlerSuite { @Test - public void handleSuccessfulFetch() { + public void handleSuccessfulFetch() throws Exception { StreamChunkId streamChunkId = new StreamChunkId(1, 0); TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); @@ -56,7 +59,7 @@ public void handleSuccessfulFetch() { } @Test - public void handleFailedFetch() { + public void handleFailedFetch() throws Exception { StreamChunkId streamChunkId = new StreamChunkId(1, 0); TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); @@ -69,7 +72,7 @@ public void handleFailedFetch() { } @Test - public void clearAllOutstandingRequests() { + public void clearAllOutstandingRequests() throws Exception { TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); handler.addFetchRequest(new StreamChunkId(1, 0), callback); @@ -88,23 +91,24 @@ public void clearAllOutstandingRequests() { } @Test - public void handleSuccessfulRPC() { + public void handleSuccessfulRPC() throws Exception { TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); RpcResponseCallback callback = mock(RpcResponseCallback.class); handler.addRpcRequest(12345, callback); assertEquals(1, handler.numOutstandingRequests()); - handler.handle(new RpcResponse(54321, new byte[7])); // should be ignored + // This response should be ignored. + handler.handle(new RpcResponse(54321, new NioManagedBuffer(ByteBuffer.allocate(7)))); assertEquals(1, handler.numOutstandingRequests()); - byte[] arr = new byte[10]; - handler.handle(new RpcResponse(12345, arr)); - verify(callback, times(1)).onSuccess(eq(arr)); + ByteBuffer resp = ByteBuffer.allocate(10); + handler.handle(new RpcResponse(12345, new NioManagedBuffer(resp))); + verify(callback, times(1)).onSuccess(eq(ByteBuffer.allocate(10))); assertEquals(0, handler.numOutstandingRequests()); } @Test - public void handleFailedRPC() { + public void handleFailedRPC() throws Exception { TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); RpcResponseCallback callback = mock(RpcResponseCallback.class); handler.addRpcRequest(12345, callback); @@ -119,7 +123,7 @@ public void handleFailedRPC() { } @Test - public void testActiveStreams() { + public void testActiveStreams() throws Exception { Channel c = new LocalChannel(); c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder()); TransportResponseHandler handler = new TransportResponseHandler(c); diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index a6f180bc40c9a..751516b9d82a1 100644 --- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -22,7 +22,7 @@ import java.io.File; import java.lang.reflect.Method; -import java.nio.charset.StandardCharsets; +import java.nio.ByteBuffer; import java.util.Arrays; import java.util.List; import java.util.Random; @@ -57,6 +57,7 @@ import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.util.ByteArrayWritableChannel; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; @@ -123,39 +124,53 @@ public void testNonMatching() { } @Test - public void testSaslAuthentication() throws Exception { + public void testSaslAuthentication() throws Throwable { testBasicSasl(false); } @Test - public void testSaslEncryption() throws Exception { + public void testSaslEncryption() throws Throwable { testBasicSasl(true); } - private void testBasicSasl(boolean encrypt) throws Exception { + private void testBasicSasl(boolean encrypt) throws Throwable { RpcHandler rpcHandler = mock(RpcHandler.class); doAnswer(new Answer() { @Override public Void answer(InvocationOnMock invocation) { - byte[] message = (byte[]) invocation.getArguments()[1]; + ByteBuffer message = (ByteBuffer) invocation.getArguments()[1]; RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2]; - assertEquals("Ping", new String(message, StandardCharsets.UTF_8)); - cb.onSuccess("Pong".getBytes(StandardCharsets.UTF_8)); + assertEquals("Ping", JavaUtils.bytesToString(message)); + cb.onSuccess(JavaUtils.stringToBytes("Pong")); return null; } }) .when(rpcHandler) - .receive(any(TransportClient.class), any(byte[].class), any(RpcResponseCallback.class)); + .receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class)); SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false); try { - byte[] response = ctx.client.sendRpcSync("Ping".getBytes(StandardCharsets.UTF_8), - TimeUnit.SECONDS.toMillis(10)); - assertEquals("Pong", new String(response, StandardCharsets.UTF_8)); + ByteBuffer response = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), + TimeUnit.SECONDS.toMillis(10)); + assertEquals("Pong", JavaUtils.bytesToString(response)); } finally { ctx.close(); // There should be 2 terminated events; one for the client, one for the server. - verify(rpcHandler, times(2)).connectionTerminated(any(TransportClient.class)); + Throwable error = null; + long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS); + while (deadline > System.nanoTime()) { + try { + verify(rpcHandler, times(2)).connectionTerminated(any(TransportClient.class)); + error = null; + break; + } catch (Throwable t) { + error = t; + TimeUnit.MILLISECONDS.sleep(10); + } + } + if (error != null) { + throw error; + } } } @@ -325,8 +340,8 @@ public void testDataEncryptionIsActuallyEnabled() throws Exception { SaslTestCtx ctx = null; try { ctx = new SaslTestCtx(mock(RpcHandler.class), true, true); - ctx.client.sendRpcSync("Ping".getBytes(StandardCharsets.UTF_8), - TimeUnit.SECONDS.toMillis(10)); + ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), + TimeUnit.SECONDS.toMillis(10)); fail("Should have failed to send RPC to server."); } catch (Exception e) { assertFalse(e.getCause() instanceof TimeoutException); diff --git a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java index 19475c21ffce9..d4de4a941d480 100644 --- a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java @@ -118,6 +118,27 @@ public Void answer(InvocationOnMock in) { } } + @Test + public void testSplitLengthField() throws Exception { + byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)]; + ByteBuf buf = Unpooled.buffer(frame.length + 8); + buf.writeLong(frame.length + 8); + buf.writeBytes(frame); + + TransportFrameDecoder decoder = new TransportFrameDecoder(); + ChannelHandlerContext ctx = mockChannelHandlerContext(); + try { + decoder.channelRead(ctx, buf.readSlice(RND.nextInt(7)).retain()); + verify(ctx, never()).fireChannelRead(any(ByteBuf.class)); + decoder.channelRead(ctx, buf); + verify(ctx).fireChannelRead(any(ByteBuf.class)); + assertEquals(0, buf.refCnt()); + } finally { + decoder.channelInactive(ctx); + release(buf); + } + } + @Test(expected = IllegalArgumentException.class) public void testNegativeFrameSize() throws Exception { testInvalidFrame(-1); @@ -183,7 +204,7 @@ private void testInvalidFrame(long size) throws Exception { try { decoder.channelRead(ctx, frame); } finally { - frame.release(); + release(frame); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index 3ddf5c3c39189..f22187a01db02 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -19,6 +19,7 @@ import java.io.File; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.List; import com.google.common.annotations.VisibleForTesting; @@ -66,8 +67,8 @@ public ExternalShuffleBlockHandler( } @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { - BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteArray(message); + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { + BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message); handleMessage(msgObj, client, callback); } @@ -85,13 +86,13 @@ protected void handleMessage( } long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator()); logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length); - callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteArray()); + callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteBuffer()); } else if (msgObj instanceof RegisterExecutor) { RegisterExecutor msg = (RegisterExecutor) msgObj; checkAuth(client, msg.appId); blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo); - callback.onSuccess(new byte[0]); + callback.onSuccess(ByteBuffer.wrap(new byte[0])); } else { throw new UnsupportedOperationException("Unexpected message: " + msgObj); diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index ef3a9dcc8711f..58ca87d9d3b13 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -18,6 +18,7 @@ package org.apache.spark.network.shuffle; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.List; import com.google.common.base.Preconditions; @@ -139,7 +140,7 @@ public void registerWithShuffleServer( checkInit(); TransportClient client = clientFactory.createUnmanagedClient(host, port); try { - byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray(); + ByteBuffer registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteBuffer(); client.sendRpcSync(registerMessage, 5000 /* timeoutMs */); } finally { client.close(); diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index e653f5cb147ee..1b2ddbf1ed917 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -17,6 +17,7 @@ package org.apache.spark.network.shuffle; +import java.nio.ByteBuffer; import java.util.Arrays; import org.slf4j.Logger; @@ -89,11 +90,11 @@ public void start() { throw new IllegalArgumentException("Zero-sized blockIds array"); } - client.sendRpc(openMessage.toByteArray(), new RpcResponseCallback() { + client.sendRpc(openMessage.toByteBuffer(), new RpcResponseCallback() { @Override - public void onSuccess(byte[] response) { + public void onSuccess(ByteBuffer response) { try { - streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response); + streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response); logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle); // Immediately request all chunks -- we expect that the total size of the request is diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java index 7543b6be4f2a1..675820308bd4c 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java @@ -18,6 +18,7 @@ package org.apache.spark.network.shuffle.mesos; import java.io.IOException; +import java.nio.ByteBuffer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -54,11 +55,11 @@ public MesosExternalShuffleClient( public void registerDriverWithShuffleService(String host, int port) throws IOException { checkInit(); - byte[] registerDriver = new RegisterDriver(appId).toByteArray(); + ByteBuffer registerDriver = new RegisterDriver(appId).toByteBuffer(); TransportClient client = clientFactory.createClient(host, port); client.sendRpc(registerDriver, new RpcResponseCallback() { @Override - public void onSuccess(byte[] response) { + public void onSuccess(ByteBuffer response) { logger.info("Successfully registered app " + appId + " with external shuffle service."); } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java index fcb52363e632c..7fbe3384b4d4f 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -17,6 +17,8 @@ package org.apache.spark.network.shuffle.protocol; +import java.nio.ByteBuffer; + import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; @@ -53,7 +55,7 @@ private Type(int id) { // NB: Java does not support static methods in interfaces, so we must put this in a static class. public static class Decoder { /** Deserializes the 'type' byte followed by the message itself. */ - public static BlockTransferMessage fromByteArray(byte[] msg) { + public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) { ByteBuf buf = Unpooled.wrappedBuffer(msg); byte type = buf.readByte(); switch (type) { @@ -68,12 +70,12 @@ public static BlockTransferMessage fromByteArray(byte[] msg) { } /** Serializes the 'type' byte followed by the message itself. */ - public byte[] toByteArray() { + public ByteBuffer toByteBuffer() { // Allow room for encoded message, plus the type byte ByteBuf buf = Unpooled.buffer(encodedLength() + 1); buf.writeByte(type().id); encode(buf); assert buf.writableBytes() == 0 : "Writable bytes remain: " + buf.writableBytes(); - return buf.array(); + return buf.nioBuffer(); } } From 31f52e0b5014bd1d14a812b032778b0bda21f5d3 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Sat, 28 Nov 2015 18:48:10 -0800 Subject: [PATCH 2/2] Fix shuffle tests. --- .../network/sasl/SaslIntegrationSuite.java | 18 +++++++++------- .../shuffle/BlockTransferMessagesSuite.java | 2 +- .../ExternalShuffleBlockHandlerSuite.java | 21 ++++++++++--------- .../shuffle/OneForOneBlockFetcherSuite.java | 8 +++---- 4 files changed, 26 insertions(+), 23 deletions(-) diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 1c2fa4d0d462c..19c870aebb023 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.network.sasl; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.Arrays; import java.util.concurrent.atomic.AtomicReference; @@ -52,6 +53,7 @@ import org.apache.spark.network.shuffle.protocol.OpenBlocks; import org.apache.spark.network.shuffle.protocol.RegisterExecutor; import org.apache.spark.network.shuffle.protocol.StreamHandle; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; @@ -107,8 +109,8 @@ public void testGoodClient() throws IOException { TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); String msg = "Hello, World!"; - byte[] resp = client.sendRpcSync(msg.getBytes(), TIMEOUT_MS); - assertEquals(msg, new String(resp)); // our rpc handler should just return the given msg + ByteBuffer resp = client.sendRpcSync(JavaUtils.stringToBytes(msg), TIMEOUT_MS); + assertEquals(msg, JavaUtils.bytesToString(resp)); } @Test @@ -136,7 +138,7 @@ public void testNoSaslClient() throws IOException { TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); try { - client.sendRpcSync(new byte[13], TIMEOUT_MS); + client.sendRpcSync(ByteBuffer.allocate(13), TIMEOUT_MS); fail("Should have failed"); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage")); @@ -144,7 +146,7 @@ public void testNoSaslClient() throws IOException { try { // Guessing the right tag byte doesn't magically get you in... - client.sendRpcSync(new byte[] { (byte) 0xEA }, TIMEOUT_MS); + client.sendRpcSync(ByteBuffer.wrap(new byte[] { (byte) 0xEA }), TIMEOUT_MS); fail("Should have failed"); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException")); @@ -222,13 +224,13 @@ public synchronized void onBlockFetchFailure(String blockId, Throwable t) { new String[] { System.getProperty("java.io.tmpdir") }, 1, "org.apache.spark.shuffle.sort.SortShuffleManager"); RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo); - client1.sendRpcSync(regmsg.toByteArray(), TIMEOUT_MS); + client1.sendRpcSync(regmsg.toByteBuffer(), TIMEOUT_MS); // Make a successful request to fetch blocks, which creates a new stream. But do not actually // fetch any blocks, to keep the stream open. OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds); - byte[] response = client1.sendRpcSync(openMessage.toByteArray(), TIMEOUT_MS); - StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response); + ByteBuffer response = client1.sendRpcSync(openMessage.toByteBuffer(), TIMEOUT_MS); + StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response); long streamId = stream.streamId; // Create a second client, authenticated with a different app ID, and try to read from @@ -275,7 +277,7 @@ public synchronized void onFailure(int chunkIndex, Throwable t) { /** RPC handler which simply responds with the message it received. */ public static class TestRpcHandler extends RpcHandler { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { callback.onSuccess(message); } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java index d65de9ca550a3..86c8609e7070b 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java @@ -36,7 +36,7 @@ public void serializeOpenShuffleBlocks() { } private void checkSerializeDeserialize(BlockTransferMessage msg) { - BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteArray(msg.toByteArray()); + BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteBuffer(msg.toByteBuffer()); assertEquals(msg, msg2); assertEquals(msg.hashCode(), msg2.hashCode()); assertEquals(msg.toString(), msg2.toString()); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java index e61390cf57061..9379412155e88 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -60,12 +60,12 @@ public void testRegisterExecutor() { RpcResponseCallback callback = mock(RpcResponseCallback.class); ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort"); - byte[] registerMessage = new RegisterExecutor("app0", "exec1", config).toByteArray(); + ByteBuffer registerMessage = new RegisterExecutor("app0", "exec1", config).toByteBuffer(); handler.receive(client, registerMessage, callback); verify(blockResolver, times(1)).registerExecutor("app0", "exec1", config); - verify(callback, times(1)).onSuccess((byte[]) any()); - verify(callback, never()).onFailure((Throwable) any()); + verify(callback, times(1)).onSuccess(any(ByteBuffer.class)); + verify(callback, never()).onFailure(any(Throwable.class)); } @SuppressWarnings("unchecked") @@ -77,17 +77,18 @@ public void testOpenShuffleBlocks() { ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); when(blockResolver.getBlockData("app0", "exec1", "b0")).thenReturn(block0Marker); when(blockResolver.getBlockData("app0", "exec1", "b1")).thenReturn(block1Marker); - byte[] openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }).toByteArray(); + ByteBuffer openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }) + .toByteBuffer(); handler.receive(client, openBlocks, callback); verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b0"); verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b1"); - ArgumentCaptor response = ArgumentCaptor.forClass(byte[].class); + ArgumentCaptor response = ArgumentCaptor.forClass(ByteBuffer.class); verify(callback, times(1)).onSuccess(response.capture()); verify(callback, never()).onFailure((Throwable) any()); StreamHandle handle = - (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response.getValue()); + (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response.getValue()); assertEquals(2, handle.numChunks); @SuppressWarnings("unchecked") @@ -104,7 +105,7 @@ public void testOpenShuffleBlocks() { public void testBadMessages() { RpcResponseCallback callback = mock(RpcResponseCallback.class); - byte[] unserializableMsg = new byte[] { 0x12, 0x34, 0x56 }; + ByteBuffer unserializableMsg = ByteBuffer.wrap(new byte[] { 0x12, 0x34, 0x56 }); try { handler.receive(client, unserializableMsg, callback); fail("Should have thrown"); @@ -112,7 +113,7 @@ public void testBadMessages() { // pass } - byte[] unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteArray(); + ByteBuffer unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteBuffer(); try { handler.receive(client, unexpectedMsg, callback); fail("Should have thrown"); @@ -120,7 +121,7 @@ public void testBadMessages() { // pass } - verify(callback, never()).onSuccess((byte[]) any()); - verify(callback, never()).onFailure((Throwable) any()); + verify(callback, never()).onSuccess(any(ByteBuffer.class)); + verify(callback, never()).onFailure(any(Throwable.class)); } } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index b35a6d685dd02..2590b9ce4c1f1 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -134,14 +134,14 @@ private BlockFetchingListener fetchBlocks(final LinkedHashMap() { @Override public Void answer(InvocationOnMock invocationOnMock) throws Throwable { - BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteArray( - (byte[]) invocationOnMock.getArguments()[0]); + BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteBuffer( + (ByteBuffer) invocationOnMock.getArguments()[0]); RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1]; - callback.onSuccess(new StreamHandle(123, blocks.size()).toByteArray()); + callback.onSuccess(new StreamHandle(123, blocks.size()).toByteBuffer()); assertEquals(new OpenBlocks("app-id", "exec-id", blockIds), message); return null; } - }).when(client).sendRpc((byte[]) any(), (RpcResponseCallback) any()); + }).when(client).sendRpc(any(ByteBuffer.class), any(RpcResponseCallback.class)); // Respond to each chunk request with a single buffer from our blocks array. final AtomicInteger expectedChunkIndex = new AtomicInteger(0);