From e2b8c4a2595c3196821ffed582ceea487d0d65d4 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Fri, 18 Jul 2014 18:01:35 -0700 Subject: [PATCH 01/14] Modify to propagete error using ConnectionManager --- .../apache/spark/network/BufferMessage.scala | 6 ++-- .../spark/network/ConnectionManager.scala | 1 - .../org/apache/spark/network/Message.scala | 2 ++ .../spark/network/MessageChunkHeader.scala | 19 ++++++++++-- .../spark/storage/BlockFetcherIterator.scala | 29 ++++++++++++------- .../spark/storage/BlockManagerWorker.scala | 8 +++-- 6 files changed, 46 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala index 04df2f3b0d696..09541e721fe81 100644 --- a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala @@ -48,7 +48,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: val security = if (isSecurityNeg) 1 else 0 if (size == 0 && !gotChunkForSendingOnce) { val newChunk = new MessageChunk( - new MessageChunkHeader(typ, id, 0, 0, ackId, security, senderAddress), null) + new MessageChunkHeader(typ, id, 0, 0, ackId, hasError, security, senderAddress), null) gotChunkForSendingOnce = true return Some(newChunk) } @@ -66,7 +66,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: } buffer.position(buffer.position + newBuffer.remaining) val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer) + typ, id, size, newBuffer.remaining, ackId, hasError, security, senderAddress), newBuffer) gotChunkForSendingOnce = true return Some(newChunk) } @@ -88,7 +88,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer] buffer.position(buffer.position + newBuffer.remaining) val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer) + typ, id, size, newBuffer.remaining, ackId, hasError, security, senderAddress), newBuffer) return Some(newChunk) } None diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index 8a1cdb812962e..9c5e5816356c5 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -674,7 +674,6 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id } } - sendMessage(connectionManagerId, ackMessage.getOrElse { Message.createBufferMessage(bufferMessage.id) }) diff --git a/core/src/main/scala/org/apache/spark/network/Message.scala b/core/src/main/scala/org/apache/spark/network/Message.scala index 7caccfdbb44f9..04ea50f62918c 100644 --- a/core/src/main/scala/org/apache/spark/network/Message.scala +++ b/core/src/main/scala/org/apache/spark/network/Message.scala @@ -28,6 +28,7 @@ private[spark] abstract class Message(val typ: Long, val id: Int) { var startTime = -1L var finishTime = -1L var isSecurityNeg = false + var hasError = false def size: Int @@ -87,6 +88,7 @@ private[spark] object Message { case BUFFER_MESSAGE => new BufferMessage(header.id, ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other) } + newMessage.hasError = header.hasError newMessage.senderAddress = header.address newMessage } diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala index ead663ede7a1c..a9d0be082464f 100644 --- a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala +++ b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala @@ -27,6 +27,7 @@ private[spark] class MessageChunkHeader( val totalSize: Int, val chunkSize: Int, val other: Int, + val hasError: Boolean, val securityNeg: Int, val address: InetSocketAddress) { lazy val buffer = { @@ -41,6 +42,13 @@ private[spark] class MessageChunkHeader( putInt(totalSize). putInt(chunkSize). putInt(other). + put{ + if (hasError) { + 1.asInstanceOf[Byte] + } else { + 0.asInstanceOf[Byte] + } + }. putInt(securityNeg). putInt(ip.size). put(ip). @@ -56,7 +64,7 @@ private[spark] class MessageChunkHeader( private[spark] object MessageChunkHeader { - val HEADER_SIZE = 44 + val HEADER_SIZE = 45 def create(buffer: ByteBuffer): MessageChunkHeader = { if (buffer.remaining != HEADER_SIZE) { @@ -67,13 +75,20 @@ private[spark] object MessageChunkHeader { val totalSize = buffer.getInt() val chunkSize = buffer.getInt() val other = buffer.getInt() + val hasError = { + if (buffer.get() == 0) { + false + } else { + true + } + } val securityNeg = buffer.getInt() val ipSize = buffer.getInt() val ipBytes = new Array[Byte](ipSize) buffer.get(ipBytes) val ip = InetAddress.getByAddress(ipBytes) val port = buffer.getInt() - new MessageChunkHeader(typ, id, totalSize, chunkSize, other, securityNeg, + new MessageChunkHeader(typ, id, totalSize, chunkSize, other, hasError, securityNeg, new InetSocketAddress(ip, port)) } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index 408a797088059..a928ca2c07caf 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -122,18 +122,25 @@ object BlockFetcherIterator { future.onSuccess { case Some(message) => { val bufferMessage = message.asInstanceOf[BufferMessage] - val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) - for (blockMessage <- blockMessageArray) { - if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { - throw new SparkException( - "Unexpected message " + blockMessage.getType + " received from " + cmId) + if (bufferMessage.hasError) { + logError("Could not get block(s) from " + cmId) + for ((blockId, size) <- req.blocks) { + results.put(new FetchResult(blockId, -1, null)) + } + } else { + val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) + for (blockMessage <- blockMessageArray) { + if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { + throw new SparkException( + "Unexpected message " + blockMessage.getType + " received from " + cmId) + } + val blockId = blockMessage.getId + val networkSize = blockMessage.getData.limit() + results.put(new FetchResult(blockId, sizeMap(blockId), + () => dataDeserialize(blockId, blockMessage.getData, serializer))) + _remoteBytesRead += networkSize + logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } - val blockId = blockMessage.getId - val networkSize = blockMessage.getData.limit() - results.put(new FetchResult(blockId, sizeMap(blockId), - () => dataDeserialize(blockId, blockMessage.getData, serializer))) - _remoteBytesRead += networkSize - logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } } case None => { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala index c7766a3a65671..419995eb9c479 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala @@ -44,8 +44,12 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get) Some(new BlockMessageArray(responseMessages).toBufferMessage) } catch { - case e: Exception => logError("Exception handling buffer message", e) - None + case e: Exception => { + logError("Exception handling buffer message", e) + val errorMessage = Message.createBufferMessage(msg.id) + errorMessage.hasError = true + Some(errorMessage) + } } } case otherMessage: Any => { From 4117b8fb0da568da1c6c41a06bab8004ef4dc422 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Sun, 20 Jul 2014 02:00:22 -0700 Subject: [PATCH 02/14] Modified ConnectionManager to be alble to handle error during processing message --- .../spark/network/ConnectionManager.scala | 45 ++++++++++++------- .../spark/network/MessageChunkHeader.scala | 16 +------ .../spark/storage/BlockManagerWorker.scala | 4 +- 3 files changed, 33 insertions(+), 32 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index 9c5e5816356c5..ba78d470ec16e 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -657,26 +657,37 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, sentMessageStatus.markDone() } } else { - val ackMessage = if (onReceiveCallback != null) { - logDebug("Calling back") - onReceiveCallback(bufferMessage, connectionManagerId) - } else { - logDebug("Not calling back as callback is null") - None - } + var ackMessage : Option[Message] = None + try { + ackMessage = if (onReceiveCallback != null) { + logDebug("Calling back") + onReceiveCallback(bufferMessage, connectionManagerId) + } else { + logDebug("Not calling back as callback is null") + None + } - if (ackMessage.isDefined) { - if (!ackMessage.get.isInstanceOf[BufferMessage]) { - logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " - + ackMessage.get.getClass) - } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) { - logDebug("Response to " + bufferMessage + " does not have ack id set") - ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id + if (ackMessage.isDefined) { + if (!ackMessage.get.isInstanceOf[BufferMessage]) { + logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " + + ackMessage.get.getClass) + } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) { + logDebug("Response to " + bufferMessage + " does not have ack id set") + ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id + } } + } catch { + case e: Exception => { + logError(s"Exception was thrown during processing message", e) + val m = Message.createBufferMessage(bufferMessage.id) + m.hasError = true + ackMessage = Some(m) + } + } finally { + sendMessage(connectionManagerId, ackMessage.getOrElse { + Message.createBufferMessage(bufferMessage.id) + }) } - sendMessage(connectionManagerId, ackMessage.getOrElse { - Message.createBufferMessage(bufferMessage.id) - }) } } case _ => throw new Exception("Unknown type message received") diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala index a9d0be082464f..f3ecca5f992e0 100644 --- a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala +++ b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala @@ -42,13 +42,7 @@ private[spark] class MessageChunkHeader( putInt(totalSize). putInt(chunkSize). putInt(other). - put{ - if (hasError) { - 1.asInstanceOf[Byte] - } else { - 0.asInstanceOf[Byte] - } - }. + put(if (hasError) 1.asInstanceOf[Byte] else 0.asInstanceOf[Byte]). putInt(securityNeg). putInt(ip.size). put(ip). @@ -75,13 +69,7 @@ private[spark] object MessageChunkHeader { val totalSize = buffer.getInt() val chunkSize = buffer.getInt() val other = buffer.getInt() - val hasError = { - if (buffer.get() == 0) { - false - } else { - true - } - } + val hasError = buffer.get() != 0 val securityNeg = buffer.getInt() val ipSize = buffer.getInt() val ipBytes = new Array[Byte](ipSize) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala index 419995eb9c479..9c459cd3a3437 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala @@ -54,7 +54,9 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends } case otherMessage: Any => { logError("Unknown type message received: " + otherMessage) - None + val errorMessage = Message.createBufferMessage(msg.id) + errorMessage.hasError = true + Some(errorMessage) } } } From 12d3de84a4d9d6d47df92059bc41965723bf02f7 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Sun, 20 Jul 2014 11:02:24 -0700 Subject: [PATCH 03/14] Added BlockFetcherIteratorSuite.scala --- .../storage/BlockFetcherIteratorSuite.scala | 74 +++++++++++++++++++ 1 file changed, 74 insertions(+) create mode 100644 core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala diff --git a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala new file mode 100644 index 0000000000000..4cc81572693a0 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import org.scalatest.{FunSuite, Matchers} + +import org.mockito.Mockito.{mock, when} +import org.mockito.Matchers.any + +import java.nio.ByteBuffer + +import scala.concurrent.future +import scala.concurrent.ExecutionContext.Implicits.global +import org.apache.spark._ +import org.apache.spark.storage.BlockFetcherIterator._ +import org.apache.spark.network.{ConnectionManager, ConnectionManagerId, + Message} + +class BlockFetcherIteratorSuite extends FunSuite with Matchers { + + test("block fetch from remote fails using BasicBlockFetcherIterator") { + val conf = new SparkConf + val blockManager = mock(classOf[BlockManager]) + val connManager = mock(classOf[ConnectionManager]) + val message = Message.createBufferMessage(0) + message.hasError = true + val someMessage = Some(message) + + val f = future { + someMessage + } + when(blockManager.connectionManager).thenReturn(connManager) + when(connManager.sendMessageReliably(any(), + any())).thenReturn(f) + when(blockManager.futureExecContext).thenReturn(global) + when(blockManager.blockManagerId).thenReturn( + BlockManagerId("test-client", "test-client", 1, 0)) + when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024) + + val dummyBlId1 = ShuffleBlockId(0,0,0) + val dummyBlId2 = ShuffleBlockId(0,1,0) + val bmId = BlockManagerId("test-server", "test-server",1 , 0) + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (bmId, Seq((dummyBlId1, 1))), + (bmId, Seq((dummyBlId2, 1))) + ) + + val iterator = new BasicBlockFetcherIterator(blockManager, + blocksByAddress, null) + + iterator.initialize() + iterator.foreach{ + case (_, r) => { + (!r.isDefined) should be(true) + } + } + } + +} From 281589c5018b0cadbebc6f8cfac8a831dd40edeb Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Tue, 22 Jul 2014 22:55:35 -0700 Subject: [PATCH 04/14] Add a test case to BlockFetcherIteratorSuite.scala for fetching block from remote from successfully --- .../storage/BlockFetcherIteratorSuite.scala | 73 ++++++++++++++++--- 1 file changed, 64 insertions(+), 9 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala index 4cc81572693a0..6adb9f51436d0 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala @@ -24,8 +24,10 @@ import org.mockito.Matchers.any import java.nio.ByteBuffer +import scala.collection.mutable.ArrayBuffer import scala.concurrent.future import scala.concurrent.ExecutionContext.Implicits.global + import org.apache.spark._ import org.apache.spark.storage.BlockFetcherIterator._ import org.apache.spark.network.{ConnectionManager, ConnectionManagerId, @@ -34,30 +36,29 @@ import org.apache.spark.network.{ConnectionManager, ConnectionManagerId, class BlockFetcherIteratorSuite extends FunSuite with Matchers { test("block fetch from remote fails using BasicBlockFetcherIterator") { - val conf = new SparkConf val blockManager = mock(classOf[BlockManager]) val connManager = mock(classOf[ConnectionManager]) - val message = Message.createBufferMessage(0) - message.hasError = true - val someMessage = Some(message) + when(blockManager.connectionManager).thenReturn(connManager) val f = future { + val message = Message.createBufferMessage(0) + message.hasError = true + val someMessage = Some(message) someMessage } - when(blockManager.connectionManager).thenReturn(connManager) when(connManager.sendMessageReliably(any(), any())).thenReturn(f) when(blockManager.futureExecContext).thenReturn(global) + when(blockManager.blockManagerId).thenReturn( BlockManagerId("test-client", "test-client", 1, 0)) when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024) - val dummyBlId1 = ShuffleBlockId(0,0,0) - val dummyBlId2 = ShuffleBlockId(0,1,0) + val blId1 = ShuffleBlockId(0,0,0) + val blId2 = ShuffleBlockId(0,1,0) val bmId = BlockManagerId("test-server", "test-server",1 , 0) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (bmId, Seq((dummyBlId1, 1))), - (bmId, Seq((dummyBlId2, 1))) + (bmId, Seq((blId1, 1), (blId2, 1))) ) val iterator = new BasicBlockFetcherIterator(blockManager, @@ -71,4 +72,58 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers { } } + test("block fetch from remote succeed using BasicBlockFetcherIterator") { + val blockManager = mock(classOf[BlockManager]) + val connManager = mock(classOf[ConnectionManager]) + when(blockManager.connectionManager).thenReturn(connManager) + + val blId1 = ShuffleBlockId(0,0,0) + val blId2 = ShuffleBlockId(0,1,0) + val buf1 = ByteBuffer.allocate(4) + val buf2 = ByteBuffer.allocate(4) + buf1.putInt(1) + buf1.flip() + buf2.putInt(1) + buf2.flip() + val blockMessage1 = BlockMessage.fromGotBlock(GotBlock(blId1, buf1)) + val blockMessage2 = BlockMessage.fromGotBlock(GotBlock(blId2, buf2)) + val blockMessageArray = new BlockMessageArray( + Seq(blockMessage1, blockMessage2)) + + val bufferMessage = blockMessageArray.toBufferMessage + val buffer = ByteBuffer.allocate(bufferMessage.size) + val arrayBuffer = new ArrayBuffer[ByteBuffer] + bufferMessage.buffers.foreach{ b => + buffer.put(b) + } + buffer.flip + arrayBuffer += buffer + + val someMessage = Some(Message.createBufferMessage(arrayBuffer)) + + val f = future { + someMessage + } + when(connManager.sendMessageReliably(any(), + any())).thenReturn(f) + when(blockManager.futureExecContext).thenReturn(global) + + when(blockManager.blockManagerId).thenReturn( + BlockManagerId("test-client", "test-client", 1, 0)) + when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024) + + val bmId = BlockManagerId("test-server", "test-server",1 , 0) + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (bmId, Seq((blId1, 1), (blId2, 1))) + ) + + val iterator = new BasicBlockFetcherIterator(blockManager, + blocksByAddress, null) + iterator.initialize() + iterator.foreach{ + case (_, r) => { + (r.isDefined) should be(true) + } + } + } } From 22d7ebde6d706e4336154af5bec87f36d1501b20 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 23 Jul 2014 17:45:01 -0700 Subject: [PATCH 05/14] Add test cases to BlockManagerSuite for SPARK-2583 --- .../storage/BlockFetcherIteratorSuite.scala | 2 +- .../spark/storage/BlockManagerSuite.scala | 112 +++++++++++++++++- 2 files changed, 112 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala index 6adb9f51436d0..db37efbf70fe6 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala @@ -96,7 +96,7 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers { bufferMessage.buffers.foreach{ b => buffer.put(b) } - buffer.flip + buffer.flip() arrayBuffer += buffer val someMessage = Some(Message.createBufferMessage(arrayBuffer)) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 23cb6905bfdeb..f7f1ded1d4a67 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -24,7 +24,10 @@ import akka.actor._ import org.apache.spark.SparkConf import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils} -import org.mockito.Mockito.{mock, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.Mockito.{doAnswer, mock, spy, when} +import org.mockito.stubbing.Answer +import org.mockito.Matchers.any import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ @@ -33,8 +36,11 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf, SparkContext} import org.apache.spark.executor.DataReadMethod +import org.apache.spark.network.{BufferMessage, ConnectionManagerId, Message} import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} +import org.apache.spark.storage.{BlockMessage, BlockMessageArray} +import org.apache.spark.storage.BlockMessage._ import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils} import scala.collection.mutable.ArrayBuffer @@ -1015,4 +1021,108 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter assert(store.memoryStore.contains(rdd(0, 1)), "rdd_0_1 was not in store") assert(!store.memoryStore.contains(rdd(1, 0)), "rdd_1_0 was in store") } + + test("return error message when error occurred in BlockManagerWorker#onBlockMessageReceive") { + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) + + val worker = spy(new BlockManagerWorker(store)) + val connManagerId = mock(classOf[ConnectionManagerId]) + + // setup request block messages + val reqBlId1 = ShuffleBlockId(0,0,0) + val reqBlId2 = ShuffleBlockId(0,1,0) + val reqBlockMessage1 = BlockMessage.fromGetBlock(GetBlock(reqBlId1)) + val reqBlockMessage2 = BlockMessage.fromGetBlock(GetBlock(reqBlId2)) + val reqBlockMessages = new BlockMessageArray( + Seq(reqBlockMessage1, reqBlockMessage2)) + val reqBufferMessage = reqBlockMessages.toBufferMessage + + val answer = new Answer[Option[BlockMessage]] { + override def answer(invocation: InvocationOnMock) + :Option[BlockMessage]= { + throw new Exception + } + } + + doAnswer(answer).when(worker).processBlockMessage(any()) + + // Test when exception was thrown during processing block messages + var ackMessage = worker.onBlockMessageReceive(reqBufferMessage, connManagerId) + + assert(ackMessage.isDefined, "When Exception was thrown in " + + "BlockManagerWorker#processBlockMessage, " + + "ackMessage should be defined") + assert(ackMessage.get.hasError, "When Exception was thown in " + + "BlockManagerWorker#processBlockMessage, " + + "ackMessage should have error") + + val notBufferMessage = mock(classOf[Message]) + + // Test when not BufferMessage was received + ackMessage = worker.onBlockMessageReceive(notBufferMessage, connManagerId) + assert(ackMessage.isDefined, "When not BufferMessage was passed to " + + "BlockManagerWorker#onBlockMessageReceive, " + + "ackMessage should be defined") + assert(ackMessage.get.hasError, "When not BufferMessage was passed to " + + "BlockManagerWorker#onBlockMessageReceive, " + + "ackMessage should have error") + } + + test("return ack message when no error occurred in BlocManagerWorker#onBlockMessageReceive") { + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) + + val worker = spy(new BlockManagerWorker(store)) + val connManagerId = mock(classOf[ConnectionManagerId]) + + // setup request block messages + val reqBlId1 = ShuffleBlockId(0,0,0) + val reqBlId2 = ShuffleBlockId(0,1,0) + val reqBlockMessage1 = BlockMessage.fromGetBlock(GetBlock(reqBlId1)) + val reqBlockMessage2 = BlockMessage.fromGetBlock(GetBlock(reqBlId2)) + val reqBlockMessages = new BlockMessageArray( + Seq(reqBlockMessage1, reqBlockMessage2)) + + val tmpBufferMessage = reqBlockMessages.toBufferMessage + val buffer = ByteBuffer.allocate(tmpBufferMessage.size) + val arrayBuffer = new ArrayBuffer[ByteBuffer] + tmpBufferMessage.buffers.foreach{ b => + buffer.put(b) + } + buffer.flip() + arrayBuffer += buffer + val reqBufferMessage = Message.createBufferMessage(arrayBuffer) + + // setup ack block messages + val buf1 = ByteBuffer.allocate(4) + val buf2 = ByteBuffer.allocate(4) + buf1.putInt(1) + buf1.flip() + buf2.putInt(1) + buf2.flip() + val ackBlockMessage1 = BlockMessage.fromGotBlock(GotBlock(reqBlId1, buf1)) + val ackBlockMessage2 = BlockMessage.fromGotBlock(GotBlock(reqBlId2, buf2)) + + val answer = new Answer[Option[BlockMessage]] { + override def answer(invocation: InvocationOnMock) + :Option[BlockMessage]= { + if (invocation.getArguments()(0).asInstanceOf[BlockMessage].eq( + reqBlockMessage1)) { + return Some(ackBlockMessage1) + } else { + return Some(ackBlockMessage2) + } + } + } + + doAnswer(answer).when(worker).processBlockMessage(any()) + + val ackMessage = worker.onBlockMessageReceive(reqBufferMessage, connManagerId) + assert(ackMessage.isDefined, "When BlockManagerWorker#onBlockMessageReceive " + + "was executed successfully, ackMessage should be defined") + assert(!ackMessage.get.hasError, "When BlockManagerWorker#onBlockMessageReceive " + + "was executed successfully, ackMessage should not have error") + } + } From 326a17f05e078f81486a9c64361b40eb991d5e01 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Thu, 24 Jul 2014 00:21:42 -0700 Subject: [PATCH 06/14] Add test cases to ConnectionManagerSuite.scala for SPARK-2583 --- .../network/ConnectionManagerSuite.scala | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala index 415ad8c432c12..162d0a016c136 100644 --- a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala @@ -223,7 +223,31 @@ class ConnectionManagerSuite extends FunSuite { managerServer.stop() } + test("Ack error message") { + val conf = new SparkConf + conf.set("spark.authenticate", "false") + val securityManager = new SecurityManager(conf) + val manager = new ConnectionManager(0, conf, securityManager) + val managerServer = new ConnectionManager(0, conf, securityManager) + managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + throw new Exception + }) + val size = 10 * 1024 * 1024 + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + val bufferMessage = Message.createBufferMessage(buffer) + + val future = manager.sendMessageReliably(managerServer.id, bufferMessage) + + val message = Await.result(future, 1 second) + assert(message.isDefined) + assert(message.get.hasError) + + manager.stop() + managerServer.stop() + + } } From ee91bb793666b260c9018e532dcf32c7115b3194 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 30 Jul 2014 04:57:42 +0900 Subject: [PATCH 07/14] Modified BufferMessage.scala to keep the spark code style --- .../main/scala/org/apache/spark/network/BufferMessage.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala index 09541e721fe81..af35f1fc3e459 100644 --- a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala @@ -66,7 +66,8 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: } buffer.position(buffer.position + newBuffer.remaining) val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, hasError, security, senderAddress), newBuffer) + typ, id, size, newBuffer.remaining, ackId, + hasError, security, senderAddress), newBuffer) gotChunkForSendingOnce = true return Some(newChunk) } From f1cd1bb103607adc9b8d6bbcaac9eea61a1696ec Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 3 Aug 2014 17:58:26 -0700 Subject: [PATCH 08/14] Clean up @sarutak's PR #1490 for [SPARK-2583]: ConnectionManager error reporting Use Futures to signal failures, rather than exposing empty messages to user code. --- .../spark/network/ConnectionManager.scala | 94 +++++++++++-------- .../org/apache/spark/network/SenderTest.scala | 3 +- .../spark/storage/BlockFetcherIterator.scala | 36 +++---- .../spark/storage/BlockManagerWorker.scala | 16 ++-- .../network/ConnectionManagerSuite.scala | 7 +- .../storage/BlockFetcherIteratorSuite.scala | 7 +- 6 files changed, 87 insertions(+), 76 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index ac9321665c251..ddf6b6e5404c1 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -17,6 +17,7 @@ package org.apache.spark.network +import java.io.IOException import java.nio._ import java.nio.channels._ import java.nio.channels.spi._ @@ -41,16 +42,26 @@ import org.apache.spark.util.{SystemClock, Utils} private[spark] class ConnectionManager(port: Int, conf: SparkConf, securityManager: SecurityManager) extends Logging { + /** + * Used by sendMessageReliably to track messages being sent. + * @param message the message that was sent + * @param connectionManagerId the connection manager that sent this message + * @param completionHandler callback that's invoked when the send has completed or failed + */ class MessageStatus( val message: Message, val connectionManagerId: ConnectionManagerId, completionHandler: MessageStatus => Unit) { + /** This is non-None if message has been ack'd */ var ackMessage: Option[Message] = None - var attempted = false - var acked = false - def markDone() { completionHandler(this) } + def markDone(ackMessage: Option[Message]) { + this.synchronized { + this.ackMessage = ackMessage + completionHandler(this) + } + } } private val selector = SelectorProvider.provider.openSelector() @@ -434,11 +445,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId) .foreach(status => { logInfo("Notifying " + status) - status.synchronized { - status.attempted = true - status.acked = false - status.markDone() - } + status.markDone(None) }) messageStatuses.retain((i, status) => { @@ -467,11 +474,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) { logInfo("Notifying " + s) - s.synchronized { - s.attempted = true - s.acked = false - s.markDone() - } + s.markDone(None) } messageStatuses.retain((i, status) => { @@ -539,13 +542,13 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, val securityMsgResp = SecurityMessage.fromResponse(replyToken, securityMsg.getConnectionId.toString) val message = securityMsgResp.toBufferMessage - if (message == null) throw new Exception("Error creating security message") + if (message == null) throw new IOException("Error creating security message") sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message) } catch { case e: Exception => { logError("Error handling sasl client authentication", e) waitingConn.close() - throw new Exception("Error evaluating sasl response: " + e) + throw new IOException("Error evaluating sasl response: " + e) } } } @@ -653,12 +656,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, } } } - sentMessageStatus.synchronized { - sentMessageStatus.ackMessage = Some(message) - sentMessageStatus.attempted = true - sentMessageStatus.acked = true - sentMessageStatus.markDone() - } + sentMessageStatus.markDone(Some(message)) } else { var ackMessage : Option[Message] = None try { @@ -681,7 +679,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, } } catch { case e: Exception => { - logError(s"Exception was thrown during processing message", e) + logError(s"Exception was thrown while processing message", e) val m = Message.createBufferMessage(bufferMessage.id) m.hasError = true ackMessage = Some(m) @@ -802,11 +800,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, case Some(msgStatus) => { messageStatuses -= message.id logInfo("Notifying " + msgStatus.connectionManagerId) - msgStatus.synchronized { - msgStatus.attempted = true - msgStatus.acked = false - msgStatus.markDone() - } + msgStatus.markDone(None) } case None => { logError("no messageStatus for failed message id: " + message.id) @@ -825,11 +819,28 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, selector.wakeup() } + /** + * Send a message and block until an acknowldgment is received or an error occurs. + * @param connectionManagerId the message's destination + * @param message the message being sent + * @return a Future that either returns the acknowledgment message or captures an exception. + */ def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message) - : Future[Option[Message]] = { - val promise = Promise[Option[Message]] - val status = new MessageStatus( - message, connectionManagerId, s => promise.success(s.ackMessage)) + : Future[Message] = { + val promise = Promise[Message]() + val status = new MessageStatus(message, connectionManagerId, s => { + s.ackMessage match { + case None => // Indicates a failure where we either never sent or never got ACK'd + promise.failure(new IOException("sendMessageReliably failed without being ACK'd")) + case Some(ackMessage) => + if (ackMessage.hasError) { + promise.failure( + new IOException("sendMessageReliably failed with ACK that signalled a remote error")) + } else { + promise.success(ackMessage) + } + } + }) messageStatuses.synchronized { messageStatuses += ((message.id, status)) } @@ -838,7 +849,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, } def sendMessageReliablySync(connectionManagerId: ConnectionManagerId, - message: Message): Option[Message] = { + message: Message): Message = { Await.result(sendMessageReliably(connectionManagerId, message), Duration.Inf) } @@ -864,6 +875,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, private[spark] object ConnectionManager { + import ExecutionContext.Implicits.global def main(args: Array[String]) { val conf = new SparkConf @@ -919,8 +931,10 @@ private[spark] object ConnectionManager { val bufferMessage = Message.createBufferMessage(buffer.duplicate) manager.sendMessageReliably(manager.id, bufferMessage) }).foreach(f => { - val g = Await.result(f, 1 second) - if (!g.isDefined) println("Failed") + f.onFailure { + case e => println("Failed due to " + e) + } + Await.ready(f, 1 second) }) val finishTime = System.currentTimeMillis @@ -954,8 +968,10 @@ private[spark] object ConnectionManager { val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate) manager.sendMessageReliably(manager.id, bufferMessage) }).foreach(f => { - val g = Await.result(f, 1 second) - if (!g.isDefined) println("Failed") + f.onFailure { + case e => println("Failed due to " + e) + } + Await.ready(f, 1 second) }) val finishTime = System.currentTimeMillis @@ -984,8 +1000,10 @@ private[spark] object ConnectionManager { val bufferMessage = Message.createBufferMessage(buffer.duplicate) manager.sendMessageReliably(manager.id, bufferMessage) }).foreach(f => { - val g = Await.result(f, 1 second) - if (!g.isDefined) println("Failed") + f.onFailure { + case e => println("Failed due to " + e) + } + Await.ready(f, 1 second) }) val finishTime = System.currentTimeMillis Thread.sleep(1000) diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala index b8ea7c2cff9a2..f15eaef80d3e1 100644 --- a/core/src/main/scala/org/apache/spark/network/SenderTest.scala +++ b/core/src/main/scala/org/apache/spark/network/SenderTest.scala @@ -18,6 +18,7 @@ package org.apache.spark.network import java.nio.ByteBuffer +import scala.util.Try import org.apache.spark.{SecurityManager, SparkConf} private[spark] object SenderTest { @@ -51,7 +52,7 @@ private[spark] object SenderTest { val dataMessage = Message.createBufferMessage(buffer.duplicate) val startTime = System.currentTimeMillis /* println("Started timer at " + startTime) */ - val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage) + val responseStr = Try(manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage)) .map { response => val buffer = response.asInstanceOf[BufferMessage].buffers(0) new String(buffer.array, "utf-8") diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index 59c97e1dece4a..f05fadf5c26da 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -22,6 +22,7 @@ import java.util.concurrent.LinkedBlockingQueue import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashSet import scala.collection.mutable.Queue +import scala.util.{Failure, Success} import io.netty.buffer.ByteBuf @@ -118,31 +119,24 @@ object BlockFetcherIterator { bytesInFlight += req.size val sizeMap = req.blocks.toMap // so we can look up the size of each blockID val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage) - future.onSuccess { - case Some(message) => { + future.onComplete { + case Success(message) => { val bufferMessage = message.asInstanceOf[BufferMessage] - if (bufferMessage.hasError) { - logError("Could not get block(s) from " + cmId) - for ((blockId, size) <- req.blocks) { - results.put(new FetchResult(blockId, -1, null)) - } - } else { - val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) - for (blockMessage <- blockMessageArray) { - if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { - throw new SparkException( - "Unexpected message " + blockMessage.getType + " received from " + cmId) - } - val blockId = blockMessage.getId - val networkSize = blockMessage.getData.limit() - results.put(new FetchResult(blockId, sizeMap(blockId), - () => dataDeserialize(blockId, blockMessage.getData, serializer))) - _remoteBytesRead += networkSize - logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) + val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) + for (blockMessage <- blockMessageArray) { + if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { + throw new SparkException( + "Unexpected message " + blockMessage.getType + " received from " + cmId) } + val blockId = blockMessage.getId + val networkSize = blockMessage.getData.limit() + results.put(new FetchResult(blockId, sizeMap(blockId), + () => dataDeserialize(blockId, blockMessage.getData, serializer))) + _remoteBytesRead += networkSize + logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } } - case None => { + case Failure(exception) => { logError("Could not get block(s) from " + cmId) for ((blockId, size) <- req.blocks) { results.put(new FetchResult(blockId, -1, null)) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala index 9c459cd3a3437..518cc5e419102 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala @@ -23,6 +23,8 @@ import org.apache.spark.Logging import org.apache.spark.network._ import org.apache.spark.util.Utils +import scala.util.{Failure, Success, Try} + /** * A network interface for BlockManager. Each slave should have one * BlockManagerWorker. @@ -115,9 +117,9 @@ private[spark] object BlockManagerWorker extends Logging { val connectionManager = blockManager.connectionManager val blockMessage = BlockMessage.fromPutBlock(msg) val blockMessageArray = new BlockMessageArray(blockMessage) - val resultMessage = connectionManager.sendMessageReliablySync( - toConnManagerId, blockMessageArray.toBufferMessage) - resultMessage.isDefined + val resultMessage = Try(connectionManager.sendMessageReliablySync( + toConnManagerId, blockMessageArray.toBufferMessage)) + resultMessage.isSuccess } def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = { @@ -125,10 +127,10 @@ private[spark] object BlockManagerWorker extends Logging { val connectionManager = blockManager.connectionManager val blockMessage = BlockMessage.fromGetBlock(msg) val blockMessageArray = new BlockMessageArray(blockMessage) - val responseMessage = connectionManager.sendMessageReliablySync( - toConnManagerId, blockMessageArray.toBufferMessage) + val responseMessage = Try(connectionManager.sendMessageReliablySync( + toConnManagerId, blockMessageArray.toBufferMessage)) responseMessage match { - case Some(message) => { + case Success(message) => { val bufferMessage = message.asInstanceOf[BufferMessage] logDebug("Response message received " + bufferMessage) BlockMessageArray.fromBufferMessage(bufferMessage).foreach(blockMessage => { @@ -136,7 +138,7 @@ private[spark] object BlockManagerWorker extends Logging { return blockMessage.getData }) } - case None => logDebug("No response message received") + case Failure(exception) => logDebug("No response message received") } null } diff --git a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala index 162d0a016c136..03f36fe50d9c4 100644 --- a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala @@ -25,6 +25,7 @@ import org.scalatest.FunSuite import scala.concurrent.{Await, TimeoutException} import scala.concurrent.duration._ import scala.language.postfixOps +import scala.util.Try /** * Test the ConnectionManager with various security settings. @@ -209,7 +210,6 @@ class ConnectionManagerSuite extends FunSuite { }).foreach(f => { try { val g = Await.result(f, 1 second) - if (!g.isDefined) assert(false) else assert(true) } catch { case e: Exception => { assert(false) @@ -240,9 +240,8 @@ class ConnectionManagerSuite extends FunSuite { val future = manager.sendMessageReliably(managerServer.id, bufferMessage) - val message = Await.result(future, 1 second) - assert(message.isDefined) - assert(message.get.hasError) + val message = Try(Await.result(future, 1 second)) + assert(message.isFailure) manager.stop() managerServer.stop() diff --git a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala index db867f21e0cb9..0bcf42c981cea 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala @@ -148,8 +148,7 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers { val f = future { val message = Message.createBufferMessage(0) message.hasError = true - val someMessage = Some(message) - someMessage + message } when(connManager.sendMessageReliably(any(), any())).thenReturn(f) @@ -204,10 +203,8 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers { buffer.flip() arrayBuffer += buffer - val someMessage = Some(Message.createBufferMessage(arrayBuffer)) - val f = future { - someMessage + Message.createBufferMessage(arrayBuffer) } when(connManager.sendMessageReliably(any(), any())).thenReturn(f) From c01c4501446bde1c5c045e2ac57c1be10c7d1b18 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 4 Aug 2014 14:07:10 -0700 Subject: [PATCH 09/14] Return Try[Message] from sendMessageReliablySync. This forces callers to consider error handling. --- .../org/apache/spark/network/ConnectionManager.scala | 5 +++-- .../scala/org/apache/spark/network/SenderTest.scala | 3 +-- .../org/apache/spark/storage/BlockManagerWorker.scala | 10 +++++----- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index ddf6b6e5404c1..9440bcbe009ba 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -35,6 +35,7 @@ import scala.collection.mutable.SynchronizedQueue import scala.concurrent.{Await, ExecutionContext, Future, Promise} import scala.concurrent.duration._ import scala.language.postfixOps +import scala.util.Try import org.apache.spark._ import org.apache.spark.util.{SystemClock, Utils} @@ -849,8 +850,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, } def sendMessageReliablySync(connectionManagerId: ConnectionManagerId, - message: Message): Message = { - Await.result(sendMessageReliably(connectionManagerId, message), Duration.Inf) + message: Message): Try[Message] = { + Try(Await.result(sendMessageReliably(connectionManagerId, message), Duration.Inf)) } def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) { diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala index f15eaef80d3e1..b8ea7c2cff9a2 100644 --- a/core/src/main/scala/org/apache/spark/network/SenderTest.scala +++ b/core/src/main/scala/org/apache/spark/network/SenderTest.scala @@ -18,7 +18,6 @@ package org.apache.spark.network import java.nio.ByteBuffer -import scala.util.Try import org.apache.spark.{SecurityManager, SparkConf} private[spark] object SenderTest { @@ -52,7 +51,7 @@ private[spark] object SenderTest { val dataMessage = Message.createBufferMessage(buffer.duplicate) val startTime = System.currentTimeMillis /* println("Started timer at " + startTime) */ - val responseStr = Try(manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage)) + val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage) .map { response => val buffer = response.asInstanceOf[BufferMessage].buffers(0) new String(buffer.array, "utf-8") diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala index 518cc5e419102..35f99d7569fe8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala @@ -23,7 +23,7 @@ import org.apache.spark.Logging import org.apache.spark.network._ import org.apache.spark.util.Utils -import scala.util.{Failure, Success, Try} +import scala.util.{Failure, Success} /** * A network interface for BlockManager. Each slave should have one @@ -117,8 +117,8 @@ private[spark] object BlockManagerWorker extends Logging { val connectionManager = blockManager.connectionManager val blockMessage = BlockMessage.fromPutBlock(msg) val blockMessageArray = new BlockMessageArray(blockMessage) - val resultMessage = Try(connectionManager.sendMessageReliablySync( - toConnManagerId, blockMessageArray.toBufferMessage)) + val resultMessage = connectionManager.sendMessageReliablySync( + toConnManagerId, blockMessageArray.toBufferMessage) resultMessage.isSuccess } @@ -127,8 +127,8 @@ private[spark] object BlockManagerWorker extends Logging { val connectionManager = blockManager.connectionManager val blockMessage = BlockMessage.fromGetBlock(msg) val blockMessageArray = new BlockMessageArray(blockMessage) - val responseMessage = Try(connectionManager.sendMessageReliablySync( - toConnManagerId, blockMessageArray.toBufferMessage)) + val responseMessage = connectionManager.sendMessageReliablySync( + toConnManagerId, blockMessageArray.toBufferMessage) responseMessage match { case Success(message) => { val bufferMessage = message.asInstanceOf[BufferMessage] From a2f745ca7a3d80b190197786bee4cec44b7d5f9d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 5 Aug 2014 13:52:55 -0700 Subject: [PATCH 10/14] Remove sendMessageReliablySync; callers can wait themselves. This makes the waiting more explicit. --- .../org/apache/spark/network/ConnectionManager.scala | 8 +------- .../scala/org/apache/spark/network/SenderTest.scala | 7 ++++++- .../apache/spark/storage/BlockManagerWorker.scala | 12 +++++++----- .../spark/network/ConnectionManagerSuite.scala | 6 +++--- 4 files changed, 17 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index 9440bcbe009ba..e47e316599301 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -35,7 +35,6 @@ import scala.collection.mutable.SynchronizedQueue import scala.concurrent.{Await, ExecutionContext, Future, Promise} import scala.concurrent.duration._ import scala.language.postfixOps -import scala.util.Try import org.apache.spark._ import org.apache.spark.util.{SystemClock, Utils} @@ -849,11 +848,6 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, promise.future } - def sendMessageReliablySync(connectionManagerId: ConnectionManagerId, - message: Message): Try[Message] = { - Try(Await.result(sendMessageReliably(connectionManagerId, message), Duration.Inf)) - } - def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) { onReceiveCallback = callback } @@ -911,7 +905,7 @@ private[spark] object ConnectionManager { (0 until count).map(i => { val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliablySync(manager.id, bufferMessage) + Await.result(manager.sendMessageReliably(manager.id, bufferMessage), Duration.Inf) }) println("--------------------------") println() diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala index b8ea7c2cff9a2..ea2ad104ecae1 100644 --- a/core/src/main/scala/org/apache/spark/network/SenderTest.scala +++ b/core/src/main/scala/org/apache/spark/network/SenderTest.scala @@ -20,6 +20,10 @@ package org.apache.spark.network import java.nio.ByteBuffer import org.apache.spark.{SecurityManager, SparkConf} +import scala.concurrent.Await +import scala.concurrent.duration.Duration +import scala.util.Try + private[spark] object SenderTest { def main(args: Array[String]) { @@ -51,7 +55,8 @@ private[spark] object SenderTest { val dataMessage = Message.createBufferMessage(buffer.duplicate) val startTime = System.currentTimeMillis /* println("Started timer at " + startTime) */ - val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage) + val promise = manager.sendMessageReliably(targetConnectionManagerId, dataMessage) + val responseStr: String = Try(Await.result(promise, Duration.Inf)) .map { response => val buffer = response.asInstanceOf[BufferMessage].buffers(0) new String(buffer.array, "utf-8") diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala index 35f99d7569fe8..bf002a42d5dc5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala @@ -23,7 +23,9 @@ import org.apache.spark.Logging import org.apache.spark.network._ import org.apache.spark.util.Utils -import scala.util.{Failure, Success} +import scala.concurrent.Await +import scala.concurrent.duration.Duration +import scala.util.{Try, Failure, Success} /** * A network interface for BlockManager. Each slave should have one @@ -117,8 +119,8 @@ private[spark] object BlockManagerWorker extends Logging { val connectionManager = blockManager.connectionManager val blockMessage = BlockMessage.fromPutBlock(msg) val blockMessageArray = new BlockMessageArray(blockMessage) - val resultMessage = connectionManager.sendMessageReliablySync( - toConnManagerId, blockMessageArray.toBufferMessage) + val resultMessage = Try(Await.result(connectionManager.sendMessageReliably( + toConnManagerId, blockMessageArray.toBufferMessage), Duration.Inf)) resultMessage.isSuccess } @@ -127,8 +129,8 @@ private[spark] object BlockManagerWorker extends Logging { val connectionManager = blockManager.connectionManager val blockMessage = BlockMessage.fromGetBlock(msg) val blockMessageArray = new BlockMessageArray(blockMessage) - val responseMessage = connectionManager.sendMessageReliablySync( - toConnManagerId, blockMessageArray.toBufferMessage) + val responseMessage = Try(Await.result(connectionManager.sendMessageReliably( + toConnManagerId, blockMessageArray.toBufferMessage), Duration.Inf)) responseMessage match { case Success(message) => { val bufferMessage = message.asInstanceOf[BufferMessage] diff --git a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala index 03f36fe50d9c4..32d7fadb95b00 100644 --- a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala @@ -47,7 +47,7 @@ class ConnectionManagerSuite extends FunSuite { buffer.flip val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliablySync(manager.id, bufferMessage) + Await.ready(manager.sendMessageReliably(manager.id, bufferMessage), 10 seconds) assert(receivedMessage == true) @@ -80,7 +80,7 @@ class ConnectionManagerSuite extends FunSuite { (0 until count).map(i => { val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliablySync(managerServer.id, bufferMessage) + Await.ready(manager.sendMessageReliably(manager.id, bufferMessage), 10 seconds) }) assert(numReceivedServerMessages == 10) @@ -119,7 +119,7 @@ class ConnectionManagerSuite extends FunSuite { val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) buffer.flip val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliablySync(managerServer.id, bufferMessage) + Await.result(manager.sendMessageReliably(manager.id, bufferMessage), 10 seconds) assert(numReceivedServerMessages == 0) assert(numReceivedMessages == 0) From 659521fc4e52ffd4a10eee7dba213018a7a9cb7e Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 6 Aug 2014 11:18:20 -0700 Subject: [PATCH 11/14] Include previous exception when throwing new one --- .../main/scala/org/apache/spark/network/ConnectionManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index e47e316599301..2492968d5ca7c 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -548,7 +548,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, case e: Exception => { logError("Error handling sasl client authentication", e) waitingConn.close() - throw new IOException("Error evaluating sasl response: " + e) + throw new IOException("Error evaluating sasl response: ", e) } } } From b8bb4d487b0044c7f3505188aa3933cee05d3d5e Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 6 Aug 2014 14:01:33 -0700 Subject: [PATCH 12/14] Fix manager.id vs managerServer.id typo that broke security tests. Intercept IOException that's now thrown when a sending connection is closed while there are unacknowledged messages. --- .../spark/network/ConnectionManagerSuite.scala | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala index 32d7fadb95b00..e983c0c3e7815 100644 --- a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.network +import java.io.IOException import java.nio._ import org.apache.spark.{SecurityManager, SparkConf} @@ -47,7 +48,7 @@ class ConnectionManagerSuite extends FunSuite { buffer.flip val bufferMessage = Message.createBufferMessage(buffer.duplicate) - Await.ready(manager.sendMessageReliably(manager.id, bufferMessage), 10 seconds) + Await.result(manager.sendMessageReliably(manager.id, bufferMessage), 10 seconds) assert(receivedMessage == true) @@ -80,7 +81,7 @@ class ConnectionManagerSuite extends FunSuite { (0 until count).map(i => { val bufferMessage = Message.createBufferMessage(buffer.duplicate) - Await.ready(manager.sendMessageReliably(manager.id, bufferMessage), 10 seconds) + Await.result(manager.sendMessageReliably(managerServer.id, bufferMessage), 10 seconds) }) assert(numReceivedServerMessages == 10) @@ -119,7 +120,10 @@ class ConnectionManagerSuite extends FunSuite { val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) buffer.flip val bufferMessage = Message.createBufferMessage(buffer.duplicate) - Await.result(manager.sendMessageReliably(manager.id, bufferMessage), 10 seconds) + // Expect managerServer to close connection, which we'll report as an error: + intercept[IOException] { + Await.result(manager.sendMessageReliably(managerServer.id, bufferMessage), 10 seconds) + } assert(numReceivedServerMessages == 0) assert(numReceivedMessages == 0) @@ -164,6 +168,8 @@ class ConnectionManagerSuite extends FunSuite { val g = Await.result(f, 1 second) assert(false) } catch { + case i: IOException => + assert(true) case e: TimeoutException => { // we should timeout here since the client can't do the negotiation assert(true) From 83673de3a429492f4520015fd0e7bd9cffdfaabf Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 6 Aug 2014 14:17:30 -0700 Subject: [PATCH 13/14] Error ACKs should trigger IOExceptions, so catch only those exceptions in the test. --- .../org/apache/spark/network/ConnectionManagerSuite.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala index e983c0c3e7815..846537df003df 100644 --- a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala @@ -246,8 +246,9 @@ class ConnectionManagerSuite extends FunSuite { val future = manager.sendMessageReliably(managerServer.id, bufferMessage) - val message = Try(Await.result(future, 1 second)) - assert(message.isFailure) + intercept[IOException] { + Await.result(future, 1 second) + } manager.stop() managerServer.stop() From 68620cbf73f91475483f49ebece358bf9c499f1c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 6 Aug 2014 16:18:51 -0700 Subject: [PATCH 14/14] Fix test in BlockFetcherIteratorSuite: We now signal failures via Futures that throw exception, not Futures that contain messages with hasError set. --- .../org/apache/spark/storage/BlockFetcherIterator.scala | 2 +- .../org/apache/spark/storage/BlockFetcherIteratorSuite.scala | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index f05fadf5c26da..938af6f5b923a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -137,7 +137,7 @@ object BlockFetcherIterator { } } case Failure(exception) => { - logError("Could not get block(s) from " + cmId) + logError("Could not get block(s) from " + cmId, exception) for ((blockId, size) <- req.blocks) { results.put(new FetchResult(blockId, -1, null)) } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala index 0bcf42c981cea..1538995a6b404 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.storage +import java.io.IOException import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer @@ -146,9 +147,7 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers { when(blockManager.connectionManager).thenReturn(connManager) val f = future { - val message = Message.createBufferMessage(0) - message.hasError = true - message + throw new IOException("Send failed or we received an error ACK") } when(connManager.sendMessageReliably(any(), any())).thenReturn(f)