From 1d6b71898e2a640e3c0809695d2b83f3f84eaa38 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 15 May 2018 11:07:54 -0700 Subject: [PATCH 01/33] continuous shuffle read RDD --- .../shuffle/ContinuousShuffleReadRDD.scala | 64 +++++++++ .../shuffle/UnsafeRowReceiver.scala | 56 ++++++++ .../shuffle/ContinuousShuffleReadSuite.scala | 122 ++++++++++++++++++ 3 files changed, 242 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala new file mode 100644 index 0000000000000..110b51f8922d8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import java.util.UUID + +import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.util.NextIterator + +case class ContinuousShuffleReadPartition(index: Int) extends Partition { + // Initialized only on the executor, and only once even as we call compute() multiple times. 
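+ // A sketch of the expected interaction (illustration only, mirroring this patch's unit tests):
+ //   val part = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition]
+ //   part.endpoint.askSync[Unit](ReceiverRow(row))        // row is any UnsafeRow
+ //   part.endpoint.askSync[Unit](ReceiverEpochMarker())   // ends the current epoch
+ // The iterator returned by compute() finishes once the epoch marker is dequeued.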
+ lazy val (receiver, endpoint) = { + val env = SparkEnv.get.rpcEnv + val receiver = new UnsafeRowReceiver(env) + val endpoint = env.setupEndpoint(UUID.randomUUID().toString, receiver) + TaskContext.get().addTaskCompletionListener { ctx => + env.stop(endpoint) + } + (receiver, endpoint) + } +} + +/** + * RDD at the bottom of each continuous processing shuffle task, reading from the + */ +class ContinuousShuffleReadRDD(sc: SparkContext, numPartitions: Int) + extends RDD[UnsafeRow](sc, Nil) { + + override protected def getPartitions: Array[Partition] = { + (0 until numPartitions).map(ContinuousShuffleReadPartition).toArray + } + + override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { + val receiver = split.asInstanceOf[ContinuousShuffleReadPartition].receiver + + new NextIterator[UnsafeRow] { + override def getNext(): UnsafeRow = receiver.poll() match { + case ReceiverRow(r) => r + case ReceiverEpochMarker() => + finished = true + null + } + + override def close(): Unit = {} + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala new file mode 100644 index 0000000000000..637305e54519f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue} +import java.util.concurrent.atomic.AtomicBoolean + +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow + +/** + * Messages for the UnsafeRowReceiver endpoint. Either an incoming row or an epoch marker. + */ +private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable +private[shuffle] case class ReceiverRow(row: UnsafeRow) extends UnsafeRowReceiverMessage +private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessage + +/** + * RPC endpoint for receiving rows into a continuous processing shuffle task. 
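+ * Senders are expected to post ReceiverRow messages for data and a ReceiverEpochMarker to end an
+ * epoch; the endpoint enqueues each message and replies with Unit once it has been buffered.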
+ */ +private[shuffle] class UnsafeRowReceiver(val rpcEnv: RpcEnv) + extends ThreadSafeRpcEndpoint with Logging { + private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](1024) + var stopped = new AtomicBoolean(false) + + override def onStop(): Unit = { + stopped.set(true) + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case r: UnsafeRowReceiverMessage => + queue.put(r) + context.reply(()) + } + + /** + * Polls until a new row is available. + */ + def poll(): UnsafeRowReceiverMessage = queue.poll() +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala new file mode 100644 index 0000000000000..ee0ff8ada23f4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import org.apache.spark.{TaskContext, TaskContextImpl} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types.{DataType, IntegerType} + +class ContinuousShuffleReadSuite extends StreamTest { + + private def unsafeRow(value: Int) = { + UnsafeProjection.create(Array(IntegerType : DataType))( + new GenericInternalRow(Array(value: Any))) + } + + var ctx: TaskContextImpl = _ + + override def beforeEach(): Unit = { + super.beforeEach() + ctx = TaskContext.empty() + TaskContext.setTaskContext(ctx) + } + + override def afterEach(): Unit = { + ctx.markTaskCompleted(None) + ctx = null + super.afterEach() + } + + test("receiver stopped with row last") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + endpoint.askSync[Unit](ReceiverEpochMarker()) + endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) + + ctx.markTaskCompleted(None) + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].receiver + eventually(timeout(streamingTimeout)) { + assert(receiver.stopped.get()) + } + } + + test("receiver stopped with marker last") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) + endpoint.askSync[Unit](ReceiverEpochMarker()) + + ctx.markTaskCompleted(None) + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].receiver + eventually(timeout(streamingTimeout)) { + assert(receiver.stopped.get()) + } + } + + test("one epoch") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) + endpoint.askSync[Unit](ReceiverRow(unsafeRow(222))) + endpoint.askSync[Unit](ReceiverRow(unsafeRow(333))) + endpoint.askSync[Unit](ReceiverEpochMarker()) + + val iter = rdd.compute(rdd.partitions(0), ctx) + assert(iter.next().getInt(0) == 111) + assert(iter.next().getInt(0) == 222) + assert(iter.next().getInt(0) == 333) + assert(!iter.hasNext) + } + + test("multiple epochs") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) + endpoint.askSync[Unit](ReceiverEpochMarker()) + endpoint.askSync[Unit](ReceiverRow(unsafeRow(222))) + endpoint.askSync[Unit](ReceiverRow(unsafeRow(333))) + endpoint.askSync[Unit](ReceiverEpochMarker()) + + val firstEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(firstEpoch.next().getInt(0) == 111) + assert(!firstEpoch.hasNext) + + val secondEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(secondEpoch.next().getInt(0) == 222) + assert(secondEpoch.next().getInt(0) == 333) + assert(!secondEpoch.hasNext) + } + + test("empty epochs") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + endpoint.askSync[Unit](ReceiverEpochMarker()) + endpoint.askSync[Unit](ReceiverEpochMarker()) + endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) + 
endpoint.askSync[Unit](ReceiverEpochMarker()) + endpoint.askSync[Unit](ReceiverEpochMarker()) + + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + val thirdEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(thirdEpoch.next().getInt(0) == 111) + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + } +} From b5d100875932bdfcb645c8f6b2cdb7b815d84c80 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Wed, 16 May 2018 20:11:11 -0700 Subject: [PATCH 02/33] docs --- .../continuous/shuffle/ContinuousShuffleReadRDD.scala | 4 +++- .../streaming/continuous/shuffle/UnsafeRowReceiver.scala | 8 +++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index 110b51f8922d8..01a7172357883 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -38,7 +38,9 @@ case class ContinuousShuffleReadPartition(index: Int) extends Partition { } /** - * RDD at the bottom of each continuous processing shuffle task, reading from the + * RDD at the map side of each continuous processing shuffle task. Upstream tasks send their + * shuffle output to the wrapped receivers in partitions of this RDD; each of the RDD's tasks + * poll from their receiver until an epoch marker is sent. */ class ContinuousShuffleReadRDD(sc: SparkContext, numPartitions: Int) extends RDD[UnsafeRow](sc, Nil) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala index 637305e54519f..addfe357f63f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -32,10 +32,16 @@ private[shuffle] case class ReceiverRow(row: UnsafeRow) extends UnsafeRowReceive private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessage /** - * RPC endpoint for receiving rows into a continuous processing shuffle task. + * RPC endpoint for receiving rows into a continuous processing shuffle task. Continuous shuffle + * writers will send rows here, with continuous shuffle readers polling for new rows as needed. + * + * TODO: Support multiple source tasks. We need to output a single epoch marker once all + * source tasks have sent one. */ private[shuffle] class UnsafeRowReceiver(val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { + // Note that this queue will be drained from the main task thread and populated in the RPC + // response thread. 
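+ // If the queue is full, put() blocks the RPC thread and the sender's ask does not complete
+ // until space frees up, which effectively applies backpressure to fast writers.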
private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](1024) var stopped = new AtomicBoolean(false) From 46456dc75a6aec9659b18523c421999debd060eb Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Wed, 16 May 2018 20:22:49 -0700 Subject: [PATCH 03/33] fix ctor --- .../continuous/shuffle/ContinuousShuffleReadRDD.scala | 3 ++- .../streaming/continuous/shuffle/UnsafeRowReceiver.scala | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index 01a7172357883..f69977749de08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -22,13 +22,14 @@ import java.util.UUID import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.NextIterator case class ContinuousShuffleReadPartition(index: Int) extends Partition { // Initialized only on the executor, and only once even as we call compute() multiple times. lazy val (receiver, endpoint) = { val env = SparkEnv.get.rpcEnv - val receiver = new UnsafeRowReceiver(env) + val receiver = new UnsafeRowReceiver(SQLConf.get.continuousStreamingExecutorQueueSize, env) val endpoint = env.setupEndpoint(UUID.randomUUID().toString, receiver) TaskContext.get().addTaskCompletionListener { ctx => env.stop(endpoint) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala index addfe357f63f1..f50b89f276eaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -38,11 +38,13 @@ private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessa * TODO: Support multiple source tasks. We need to output a single epoch marker once all * source tasks have sent one. */ -private[shuffle] class UnsafeRowReceiver(val rpcEnv: RpcEnv) +private[shuffle] class UnsafeRowReceiver( + queueSize: Int, + override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { // Note that this queue will be drained from the main task thread and populated in the RPC // response thread. 
- private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](1024) + private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize) var stopped = new AtomicBoolean(false) override def onStop(): Unit = { From 2ea8a6f94216e8b184e5780ec3e6ffb2838de382 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Wed, 16 May 2018 20:43:10 -0700 Subject: [PATCH 04/33] multiple partition test --- .../shuffle/ContinuousShuffleReadSuite.scala | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index ee0ff8ada23f4..ad841626c8b10 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -119,4 +119,18 @@ class ContinuousShuffleReadSuite extends StreamTest { assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) } + + test("multiple partitions") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 5) + for (p <- rdd.partitions) { + val part = p.asInstanceOf[ContinuousShuffleReadPartition] + // Send index to ensure data doesn't somehow cross over between partitions. + part.endpoint.askSync[Unit](ReceiverRow(unsafeRow(part.index))) + part.endpoint.askSync[Unit](ReceiverEpochMarker()) + + val iter = rdd.compute(part, ctx) + assert(iter.next().getInt(0) == part.index) + assert(!iter.hasNext) + } + } } From 955ac79eb05dc389e632d1aaa6c59396835c6ed5 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 17 May 2018 06:33:51 -0700 Subject: [PATCH 05/33] unset task context after test --- .../continuous/shuffle/ContinuousShuffleReadSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index ad841626c8b10..faec753768895 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -39,6 +39,7 @@ class ContinuousShuffleReadSuite extends StreamTest { override def afterEach(): Unit = { ctx.markTaskCompleted(None) + TaskContext.unset() ctx = null super.afterEach() } From 8cefb724512b51f2aa1fdd81fa8a2d4560e60ce3 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 17 May 2018 17:00:05 -0700 Subject: [PATCH 06/33] conf from RDD --- .../shuffle/ContinuousShuffleReadRDD.scala | 31 ++++++++++++------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index f69977749de08..a85ba28a8d528 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -20,22 +20,19 @@ package 
org.apache.spark.sql.execution.streaming.continuous.shuffle import java.util.UUID import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext} + import org.apache.spark.rdd.RDD +import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.NextIterator case class ContinuousShuffleReadPartition(index: Int) extends Partition { - // Initialized only on the executor, and only once even as we call compute() multiple times. - lazy val (receiver, endpoint) = { - val env = SparkEnv.get.rpcEnv - val receiver = new UnsafeRowReceiver(SQLConf.get.continuousStreamingExecutorQueueSize, env) - val endpoint = env.setupEndpoint(UUID.randomUUID().toString, receiver) - TaskContext.get().addTaskCompletionListener { ctx => - env.stop(endpoint) - } - (receiver, endpoint) - } + // Semantically lazy vals - initialized only on the executor, and only once even as we call + // compute() multiple times. We need to initialize them inside compute() so we have access to the + // RDD's conf. + var receiver: UnsafeRowReceiver = _ + var endpoint: RpcEndpointRef = _ } /** @@ -46,14 +43,26 @@ case class ContinuousShuffleReadPartition(index: Int) extends Partition { class ContinuousShuffleReadRDD(sc: SparkContext, numPartitions: Int) extends RDD[UnsafeRow](sc, Nil) { + private val queueSize = sc.conf.get(SQLConf.CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE) + override protected def getPartitions: Array[Partition] = { (0 until numPartitions).map(ContinuousShuffleReadPartition).toArray } override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { - val receiver = split.asInstanceOf[ContinuousShuffleReadPartition].receiver + val part = split.asInstanceOf[ContinuousShuffleReadPartition] + if (part.receiver == null) { + val env = SparkEnv.get.rpcEnv + part.receiver = new UnsafeRowReceiver(queueSize, env) + part.endpoint = env.setupEndpoint(UUID.randomUUID().toString, part.receiver) + TaskContext.get().addTaskCompletionListener { _ => + env.stop(part.endpoint) + } + } new NextIterator[UnsafeRow] { + private val receiver = part.receiver + override def getNext(): UnsafeRow = receiver.poll() match { case ReceiverRow(r) => r case ReceiverEpochMarker() => From f91bfe7e3fc174202d7d5c7cde5a8fb7ce86bfd3 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 17 May 2018 17:00:44 -0700 Subject: [PATCH 07/33] endpoint name --- .../streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index a85ba28a8d528..d4128c5e274fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -54,7 +54,7 @@ class ContinuousShuffleReadRDD(sc: SparkContext, numPartitions: Int) if (part.receiver == null) { val env = SparkEnv.get.rpcEnv part.receiver = new UnsafeRowReceiver(queueSize, env) - part.endpoint = env.setupEndpoint(UUID.randomUUID().toString, part.receiver) + part.endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID()}", part.receiver) TaskContext.get().addTaskCompletionListener { _ => 
env.stop(part.endpoint) } From 259029298fc42a65e8ebb4d2effe49b7fafa96f1 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 17 May 2018 17:02:08 -0700 Subject: [PATCH 08/33] testing bool --- .../streaming/continuous/shuffle/UnsafeRowReceiver.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala index f50b89f276eaa..d6a6d978d2653 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -45,7 +45,9 @@ private[shuffle] class UnsafeRowReceiver( // Note that this queue will be drained from the main task thread and populated in the RPC // response thread. private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize) - var stopped = new AtomicBoolean(false) + + // Exposed for testing to determine if the endpoint gets stopped on task end. + private[shuffle] val stopped = new AtomicBoolean(false) override def onStop(): Unit = { stopped.set(true) From 859e6e4dd4dd90ffd70fc9cbd243c94090d72506 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 17 May 2018 17:22:10 -0700 Subject: [PATCH 09/33] tests --- .../shuffle/ContinuousShuffleReadRDD.scala | 33 ++++++++----------- .../shuffle/ContinuousShuffleReadSuite.scala | 23 ++++++------- 2 files changed, 25 insertions(+), 31 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index d4128c5e274fe..7f5aaecabf5ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -20,19 +20,22 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle import java.util.UUID import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext} - import org.apache.spark.rdd.RDD -import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.NextIterator -case class ContinuousShuffleReadPartition(index: Int) extends Partition { - // Semantically lazy vals - initialized only on the executor, and only once even as we call - // compute() multiple times. We need to initialize them inside compute() so we have access to the - // RDD's conf. - var receiver: UnsafeRowReceiver = _ - var endpoint: RpcEndpointRef = _ +case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Partition { + // Initialized only on the executor, and only once even as we call compute() multiple times. 
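+ // The queue size travels inside the serialized partition, so this executor-side initialization
+ // no longer depends on driver-side SQLConf; the RDD reads the conf once when it is constructed.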
+ lazy val (receiver, endpoint) = { + val env = SparkEnv.get.rpcEnv + val receiver = new UnsafeRowReceiver(queueSize, env) + val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID().toString}", receiver) + TaskContext.get().addTaskCompletionListener { ctx => + env.stop(endpoint) + } + (receiver, endpoint) + } } /** @@ -46,23 +49,13 @@ class ContinuousShuffleReadRDD(sc: SparkContext, numPartitions: Int) private val queueSize = sc.conf.get(SQLConf.CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE) override protected def getPartitions: Array[Partition] = { - (0 until numPartitions).map(ContinuousShuffleReadPartition).toArray + (0 until numPartitions).map(ContinuousShuffleReadPartition(_, queueSize)).toArray } override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { - val part = split.asInstanceOf[ContinuousShuffleReadPartition] - if (part.receiver == null) { - val env = SparkEnv.get.rpcEnv - part.receiver = new UnsafeRowReceiver(queueSize, env) - part.endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID()}", part.receiver) - TaskContext.get().addTaskCompletionListener { _ => - env.stop(part.endpoint) - } - } + val receiver = split.asInstanceOf[ContinuousShuffleReadPartition].receiver new NextIterator[UnsafeRow] { - private val receiver = part.receiver - override def getNext(): UnsafeRow = receiver.poll() match { case ReceiverRow(r) => r case ReceiverEpochMarker() => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index faec753768895..19558e1c49319 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -79,10 +79,7 @@ class ContinuousShuffleReadSuite extends StreamTest { endpoint.askSync[Unit](ReceiverEpochMarker()) val iter = rdd.compute(rdd.partitions(0), ctx) - assert(iter.next().getInt(0) == 111) - assert(iter.next().getInt(0) == 222) - assert(iter.next().getInt(0) == 333) - assert(!iter.hasNext) + assert(iter.toSeq.map(_.getInt(0)) == Seq(111, 222, 333)) } test("multiple epochs") { @@ -95,13 +92,10 @@ class ContinuousShuffleReadSuite extends StreamTest { endpoint.askSync[Unit](ReceiverEpochMarker()) val firstEpoch = rdd.compute(rdd.partitions(0), ctx) - assert(firstEpoch.next().getInt(0) == 111) - assert(!firstEpoch.hasNext) + assert(firstEpoch.toSeq.map(_.getInt(0)) == Seq(111)) val secondEpoch = rdd.compute(rdd.partitions(0), ctx) - assert(secondEpoch.next().getInt(0) == 222) - assert(secondEpoch.next().getInt(0) == 333) - assert(!secondEpoch.hasNext) + assert(secondEpoch.toSeq.map(_.getInt(0)) == Seq(222, 333)) } test("empty epochs") { @@ -112,23 +106,30 @@ class ContinuousShuffleReadSuite extends StreamTest { endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) endpoint.askSync[Unit](ReceiverEpochMarker()) endpoint.askSync[Unit](ReceiverEpochMarker()) + endpoint.askSync[Unit](ReceiverEpochMarker()) assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + val thirdEpoch = rdd.compute(rdd.partitions(0), ctx) - assert(thirdEpoch.next().getInt(0) == 111) + assert(thirdEpoch.toSeq.map(_.getInt(0)) == Seq(111)) + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) } test("multiple 
partitions") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 5) + // Send all data before processing to ensure there's no crossover. for (p <- rdd.partitions) { val part = p.asInstanceOf[ContinuousShuffleReadPartition] - // Send index to ensure data doesn't somehow cross over between partitions. + // Send index for identification. part.endpoint.askSync[Unit](ReceiverRow(unsafeRow(part.index))) part.endpoint.askSync[Unit](ReceiverEpochMarker()) + } + for (p <- rdd.partitions) { + val part = p.asInstanceOf[ContinuousShuffleReadPartition] val iter = rdd.compute(part, ctx) assert(iter.next().getInt(0) == part.index) assert(!iter.hasNext) From b23b7bb17abe3cbc873a3144c56d08c88bc0c963 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 17 May 2018 17:40:55 -0700 Subject: [PATCH 10/33] take instead of poll --- .../shuffle/ContinuousShuffleReadRDD.scala | 2 +- .../shuffle/UnsafeRowReceiver.scala | 4 ++-- .../shuffle/ContinuousShuffleReadSuite.scala | 21 +++++++++++++++++++ 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index 7f5aaecabf5ad..16d7f29ea31cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -56,7 +56,7 @@ class ContinuousShuffleReadRDD(sc: SparkContext, numPartitions: Int) val receiver = split.asInstanceOf[ContinuousShuffleReadPartition].receiver new NextIterator[UnsafeRow] { - override def getNext(): UnsafeRow = receiver.poll() match { + override def getNext(): UnsafeRow = receiver.take() match { case ReceiverRow(r) => r case ReceiverEpochMarker() => finished = true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala index d6a6d978d2653..37cee999b1bed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -60,7 +60,7 @@ private[shuffle] class UnsafeRowReceiver( } /** - * Polls until a new row is available. + * Take the next row, blocking until it's ready. */ - def poll(): UnsafeRowReceiverMessage = queue.poll() + def take(): UnsafeRowReceiverMessage = queue.take() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index 19558e1c49319..6f201186d572c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -29,6 +29,9 @@ class ContinuousShuffleReadSuite extends StreamTest { new GenericInternalRow(Array(value: Any))) } + // In this unit test, we emulate that we're in the task thread where + // ContinuousShuffleReadRDD.compute() will be evaluated. This requires a task context + // thread local to be set. 
var ctx: TaskContextImpl = _ override def beforeEach(): Unit = { @@ -135,4 +138,22 @@ class ContinuousShuffleReadSuite extends StreamTest { assert(!iter.hasNext) } } + + test("blocks waiting for new rows") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + + val readRow = new Thread { + override def run(): Unit = { + // set the non-inheritable thread local + TaskContext.setTaskContext(ctx) + val epoch = rdd.compute(rdd.partitions(0), ctx) + epoch.next().getInt(0) + } + } + + readRow.start() + eventually(timeout(streamingTimeout)) { + assert(readRow.getState == Thread.State.WAITING) + } + } } From 97f7e8ff865e6054d0d70914ce9bb51880b161f6 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 17 May 2018 17:58:44 -0700 Subject: [PATCH 11/33] add interface --- .../shuffle/ContinuousShuffleReadRDD.scala | 15 ++------- .../shuffle/ContinuousShuffleReader.scala | 31 +++++++++++++++++++ .../shuffle/UnsafeRowReceiver.scala | 19 +++++++++--- .../shuffle/ContinuousShuffleReadSuite.scala | 8 ++--- 4 files changed, 51 insertions(+), 22 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index 16d7f29ea31cc..74419a5b35fdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -27,7 +27,7 @@ import org.apache.spark.util.NextIterator case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Partition { // Initialized only on the executor, and only once even as we call compute() multiple times. - lazy val (receiver, endpoint) = { + lazy val (reader: ContinuousShuffleReader, endpoint) = { val env = SparkEnv.get.rpcEnv val receiver = new UnsafeRowReceiver(queueSize, env) val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID().toString}", receiver) @@ -53,17 +53,6 @@ class ContinuousShuffleReadRDD(sc: SparkContext, numPartitions: Int) } override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { - val receiver = split.asInstanceOf[ContinuousShuffleReadPartition].receiver - - new NextIterator[UnsafeRow] { - override def getNext(): UnsafeRow = receiver.take() match { - case ReceiverRow(r) => r - case ReceiverEpochMarker() => - finished = true - null - } - - override def close(): Unit = {} - } + split.asInstanceOf[ContinuousShuffleReadPartition].reader.read() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala new file mode 100644 index 0000000000000..08091bbc12893 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow + +/** + * Trait for reading from a continuous processing shuffle. + */ +trait ContinuousShuffleReader { + /** + * Returns an iterator over the incoming rows in the current epoch. Note that this iterator can + * block waiting for new rows to arrive. + */ + def read(): Iterator[UnsafeRow] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala index 37cee999b1bed..b8adbb743c6c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -23,6 +23,7 @@ import java.util.concurrent.atomic.AtomicBoolean import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.util.NextIterator /** * Messages for the UnsafeRowReceiver endpoint. Either an incoming row or an epoch marker. @@ -41,7 +42,7 @@ private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessa private[shuffle] class UnsafeRowReceiver( queueSize: Int, override val rpcEnv: RpcEnv) - extends ThreadSafeRpcEndpoint with Logging { + extends ThreadSafeRpcEndpoint with ContinuousShuffleReader with Logging { // Note that this queue will be drained from the main task thread and populated in the RPC // response thread. private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize) @@ -59,8 +60,16 @@ private[shuffle] class UnsafeRowReceiver( context.reply(()) } - /** - * Take the next row, blocking until it's ready. 
- */ - def take(): UnsafeRowReceiverMessage = queue.take() + override def read(): Iterator[UnsafeRow] = { + new NextIterator[UnsafeRow] { + override def getNext(): UnsafeRow = queue.take() match { + case ReceiverRow(r) => r + case ReceiverEpochMarker() => + finished = true + null + } + + override def close(): Unit = {} + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index 6f201186d572c..718fb0740bbe7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -54,9 +54,9 @@ class ContinuousShuffleReadSuite extends StreamTest { endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) ctx.markTaskCompleted(None) - val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].receiver + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader eventually(timeout(streamingTimeout)) { - assert(receiver.stopped.get()) + assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get()) } } @@ -67,9 +67,9 @@ class ContinuousShuffleReadSuite extends StreamTest { endpoint.askSync[Unit](ReceiverEpochMarker()) ctx.markTaskCompleted(None) - val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].receiver + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader eventually(timeout(streamingTimeout)) { - assert(receiver.stopped.get()) + assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get()) } } From de21b1c25a333d44c0521fe151b468e51f0bdc47 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 17 May 2018 18:02:37 -0700 Subject: [PATCH 12/33] clarify comment --- .../continuous/shuffle/ContinuousShuffleReader.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala index 08091bbc12893..42631c90ebc55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala @@ -24,8 +24,9 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow */ trait ContinuousShuffleReader { /** - * Returns an iterator over the incoming rows in the current epoch. Note that this iterator can - * block waiting for new rows to arrive. + * Returns an iterator over the incoming rows in an epoch. Implementations should block waiting + * for new rows to arrive, and end the iterator once they've received epoch markers from all + * shuffle writers. 
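+ *
+ * Each call to read() is expected to cover exactly one epoch; callers such as
+ * ContinuousShuffleReadRDD.compute() simply invoke it again for the following epoch.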
*/ def read(): Iterator[UnsafeRow] } From 7dcf51a13e92a0bb2998e2a12e67d351e1c1a4fc Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 18 May 2018 15:39:28 -0700 Subject: [PATCH 13/33] multiple --- .../shuffle/ContinuousShuffleReadRDD.scala | 11 +- .../shuffle/UnsafeRowReceiver.scala | 60 +++++++--- .../shuffle/ContinuousShuffleReadSuite.scala | 105 ++++++++++++++---- 3 files changed, 137 insertions(+), 39 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index 74419a5b35fdb..05fdeeb4450a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -25,11 +25,12 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.NextIterator -case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Partition { +case class ContinuousShuffleReadPartition(index: Int, queueSize: Int, numShuffleWriters: Int) + extends Partition { // Initialized only on the executor, and only once even as we call compute() multiple times. lazy val (reader: ContinuousShuffleReader, endpoint) = { val env = SparkEnv.get.rpcEnv - val receiver = new UnsafeRowReceiver(queueSize, env) + val receiver = new UnsafeRowReceiver(queueSize, numShuffleWriters, env) val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID().toString}", receiver) TaskContext.get().addTaskCompletionListener { ctx => env.stop(endpoint) @@ -43,13 +44,15 @@ case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Pa * shuffle output to the wrapped receivers in partitions of this RDD; each of the RDD's tasks * poll from their receiver until an epoch marker is sent. 
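+ * With numShuffleWriters > 1, a task's epoch ends only after a marker arrives from every writer.
+ * For example (a sketch mirroring the unit tests, not additional behavior):
+ * {{{
+ *   val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3)
+ *   // each compute() call on a partition returns the rows of exactly one epoch
+ *   val epochRows = rdd.compute(rdd.partitions(0), TaskContext.get()).toSeq
+ * }}}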
*/ -class ContinuousShuffleReadRDD(sc: SparkContext, numPartitions: Int) +class ContinuousShuffleReadRDD(sc: SparkContext, numPartitions: Int, numShuffleWriters: Int = 1) extends RDD[UnsafeRow](sc, Nil) { private val queueSize = sc.conf.get(SQLConf.CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE) override protected def getPartitions: Array[Partition] = { - (0 until numPartitions).map(ContinuousShuffleReadPartition(_, queueSize)).toArray + (0 until numPartitions).map { partIndex => + ContinuousShuffleReadPartition(partIndex, queueSize, numShuffleWriters) + }.toArray } override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala index b8adbb743c6c2..c7f3f0db41440 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -17,8 +17,10 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle -import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue} -import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent._ +import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} + +import scala.concurrent.Future import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} @@ -28,9 +30,12 @@ import org.apache.spark.util.NextIterator /** * Messages for the UnsafeRowReceiver endpoint. Either an incoming row or an epoch marker. */ -private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable -private[shuffle] case class ReceiverRow(row: UnsafeRow) extends UnsafeRowReceiverMessage -private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessage +private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable { + def writerId: Int +} +private[shuffle] case class ReceiverRow(writerId: Int, row: UnsafeRow) + extends UnsafeRowReceiverMessage +private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends UnsafeRowReceiverMessage /** * RPC endpoint for receiving rows into a continuous processing shuffle task. Continuous shuffle @@ -41,11 +46,14 @@ private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessa */ private[shuffle] class UnsafeRowReceiver( queueSize: Int, + numShuffleWriters: Int, override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with ContinuousShuffleReader with Logging { // Note that this queue will be drained from the main task thread and populated in the RPC // response thread. - private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize) + private val queues = Array.fill(numShuffleWriters) { + new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize) + } // Exposed for testing to determine if the endpoint gets stopped on task end. 
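+ // (onStop() flips this flag; the RPC env invokes onStop() after the read RDD's task completion
+ // listener calls env.stop(endpoint), which is what the "receiver stopped" tests assert.)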
private[shuffle] val stopped = new AtomicBoolean(false) @@ -56,20 +64,46 @@ private[shuffle] class UnsafeRowReceiver( override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case r: UnsafeRowReceiverMessage => - queue.put(r) + queues(r.writerId).put(r) context.reply(()) } override def read(): Iterator[UnsafeRow] = { new NextIterator[UnsafeRow] { - override def getNext(): UnsafeRow = queue.take() match { - case ReceiverRow(r) => r - case ReceiverEpochMarker() => - finished = true - null + private val numWriterEpochMarkers = new AtomicInteger(0) + + private val executor = Executors.newFixedThreadPool(numShuffleWriters) + private val completion = new ExecutorCompletionService[UnsafeRowReceiverMessage](executor) + + private def completionTask(writerId: Int) = new Callable[UnsafeRowReceiverMessage] { + override def call(): UnsafeRowReceiverMessage = queues(writerId).take() } - override def close(): Unit = {} + (0 until numShuffleWriters).foreach(writerId => completion.submit(completionTask(writerId))) + + override def getNext(): UnsafeRow = { + completion.take().get() match { + case ReceiverRow(writerId, r) => + // Start reading the next element in the queue we just took from. + completion.submit(completionTask(writerId)) + r + // TODO use writerId + case ReceiverEpochMarker(writerId) => + // Don't read any more from this queue. If all the writers have sent epoch markers, + // the epoch is over; otherwise we need rows from one of the remaining writers. + val writersCompleted = numWriterEpochMarkers.incrementAndGet() + if (writersCompleted == numShuffleWriters) { + finished = true + null + } else { + getNext() + } + } + } + + override def close(): Unit = { + executor.shutdownNow() + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index 718fb0740bbe7..156c97e27bdda 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle +import scala.concurrent.Future + import org.apache.spark.{TaskContext, TaskContextImpl} import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.sql.types.{DataType, IntegerType, StringType} +import org.apache.spark.unsafe.types.UTF8String class ContinuousShuffleReadSuite extends StreamTest { @@ -29,6 +32,11 @@ class ContinuousShuffleReadSuite extends StreamTest { new GenericInternalRow(Array(value: Any))) } + private def unsafeRow(value: String) = { + UnsafeProjection.create(Array(StringType : DataType))( + new GenericInternalRow(Array(UTF8String.fromString(value): Any))) + } + // In this unit test, we emulate that we're in the task thread where // ContinuousShuffleReadRDD.compute() will be evaluated. This requires a task context // thread local to be set. 
@@ -50,8 +58,8 @@ class ContinuousShuffleReadSuite extends StreamTest { test("receiver stopped with row last") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - endpoint.askSync[Unit](ReceiverEpochMarker()) - endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) + endpoint.askSync[Unit](ReceiverEpochMarker(0)) + endpoint.askSync[Unit](ReceiverRow(0, unsafeRow(111))) ctx.markTaskCompleted(None) val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader @@ -63,8 +71,8 @@ class ContinuousShuffleReadSuite extends StreamTest { test("receiver stopped with marker last") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) - endpoint.askSync[Unit](ReceiverEpochMarker()) + endpoint.askSync[Unit](ReceiverRow(0, unsafeRow(111))) + endpoint.askSync[Unit](ReceiverEpochMarker(0)) ctx.markTaskCompleted(None) val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader @@ -76,10 +84,10 @@ class ContinuousShuffleReadSuite extends StreamTest { test("one epoch") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) - endpoint.askSync[Unit](ReceiverRow(unsafeRow(222))) - endpoint.askSync[Unit](ReceiverRow(unsafeRow(333))) - endpoint.askSync[Unit](ReceiverEpochMarker()) + endpoint.askSync[Unit](ReceiverRow(0, unsafeRow(111))) + endpoint.askSync[Unit](ReceiverRow(0, unsafeRow(222))) + endpoint.askSync[Unit](ReceiverRow(0, unsafeRow(333))) + endpoint.askSync[Unit](ReceiverEpochMarker(0)) val iter = rdd.compute(rdd.partitions(0), ctx) assert(iter.toSeq.map(_.getInt(0)) == Seq(111, 222, 333)) @@ -88,11 +96,11 @@ class ContinuousShuffleReadSuite extends StreamTest { test("multiple epochs") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) - endpoint.askSync[Unit](ReceiverEpochMarker()) - endpoint.askSync[Unit](ReceiverRow(unsafeRow(222))) - endpoint.askSync[Unit](ReceiverRow(unsafeRow(333))) - endpoint.askSync[Unit](ReceiverEpochMarker()) + endpoint.askSync[Unit](ReceiverRow(0, unsafeRow(111))) + endpoint.askSync[Unit](ReceiverEpochMarker(0)) + endpoint.askSync[Unit](ReceiverRow(0, unsafeRow(222))) + endpoint.askSync[Unit](ReceiverRow(0, unsafeRow(333))) + endpoint.askSync[Unit](ReceiverEpochMarker(0)) val firstEpoch = rdd.compute(rdd.partitions(0), ctx) assert(firstEpoch.toSeq.map(_.getInt(0)) == Seq(111)) @@ -101,15 +109,30 @@ class ContinuousShuffleReadSuite extends StreamTest { assert(secondEpoch.toSeq.map(_.getInt(0)) == Seq(222, 333)) } + test("multiple writers") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + endpoint.askSync[Unit](ReceiverRow(0, unsafeRow("writer0-row0"))) + endpoint.askSync[Unit](ReceiverRow(1, unsafeRow("writer1-row0"))) + endpoint.askSync[Unit](ReceiverRow(2, unsafeRow("writer2-row0"))) + endpoint.askSync[Unit](ReceiverEpochMarker(0)) + endpoint.askSync[Unit](ReceiverEpochMarker(1)) + 
endpoint.askSync[Unit](ReceiverEpochMarker(2)) + + val firstEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(firstEpoch.toSeq.map(_.getUTF8String(0).toString).toSet == + Set("writer0-row0", "writer1-row0", "writer2-row0")) + } + test("empty epochs") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - endpoint.askSync[Unit](ReceiverEpochMarker()) - endpoint.askSync[Unit](ReceiverEpochMarker()) - endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) - endpoint.askSync[Unit](ReceiverEpochMarker()) - endpoint.askSync[Unit](ReceiverEpochMarker()) - endpoint.askSync[Unit](ReceiverEpochMarker()) + endpoint.askSync[Unit](ReceiverEpochMarker(0)) + endpoint.askSync[Unit](ReceiverEpochMarker(0)) + endpoint.askSync[Unit](ReceiverRow(0, unsafeRow(111))) + endpoint.askSync[Unit](ReceiverEpochMarker(0)) + endpoint.askSync[Unit](ReceiverEpochMarker(0)) + endpoint.askSync[Unit](ReceiverEpochMarker(0)) assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) @@ -127,8 +150,8 @@ class ContinuousShuffleReadSuite extends StreamTest { for (p <- rdd.partitions) { val part = p.asInstanceOf[ContinuousShuffleReadPartition] // Send index for identification. - part.endpoint.askSync[Unit](ReceiverRow(unsafeRow(part.index))) - part.endpoint.askSync[Unit](ReceiverEpochMarker()) + part.endpoint.askSync[Unit](ReceiverRow(0, unsafeRow(part.index))) + part.endpoint.askSync[Unit](ReceiverEpochMarker(0)) } for (p <- rdd.partitions) { @@ -156,4 +179,42 @@ class ContinuousShuffleReadSuite extends StreamTest { assert(readRow.getState == Thread.State.WAITING) } } + + test("epoch only ends when all writers send markers") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + endpoint.askSync[Unit](ReceiverRow(0, unsafeRow("writer0-row0"))) + endpoint.askSync[Unit](ReceiverRow(1, unsafeRow("writer1-row0"))) + endpoint.askSync[Unit](ReceiverRow(2, unsafeRow("writer2-row0"))) + endpoint.askSync[Unit](ReceiverEpochMarker(0)) + endpoint.askSync[Unit](ReceiverEpochMarker(2)) + + val epoch = rdd.compute(rdd.partitions(0), ctx) + val rows = (0 until 3).map(_ => epoch.next()).toSet + assert(rows.map(_.getUTF8String(0).toString) == + Set("writer0-row0", "writer1-row0", "writer2-row0")) + + // After checking the right rows, block until we get an epoch marker indicating there's no next. + // (Also fail the assertion if for some reason we get a row.) + val readEpochMarker = new Thread { + override def run(): Unit = { + assert(!epoch.hasNext) + } + } + + readEpochMarker.start() + + eventually(timeout(streamingTimeout)) { + assert(readEpochMarker.getState == Thread.State.WAITING) + } + + // Send the last epoch marker - now the epoch should finish. + endpoint.askSync[Unit](ReceiverEpochMarker(1)) + eventually(timeout(streamingTimeout)) { + !readEpochMarker.isAlive + } + + // Join to pick up assertion failures. 
+ readEpochMarker.join() + } } From ad0b5aab320413891f7c21ea6115b6da8d49ccf9 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 24 May 2018 17:06:15 -0700 Subject: [PATCH 14/33] writer with 1 reader partition --- .../shuffle/ContinuousShuffleWriteRDD.scala | 45 ++++ .../shuffle/ContinuousShuffleWriter.scala | 46 ++++ .../shuffle/ContinuousShuffleSuite.scala | 211 ++++++++++++++++++ 3 files changed, 302 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriteRDD.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriteRDD.scala new file mode 100644 index 0000000000000..95678e0619f11 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriteRDD.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import org.apache.spark.{Partition, Partitioner, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeRow + +class ContinuousShuffleWriteRDD( + var prev: RDD[UnsafeRow], + outputPartitioner: Partitioner, + endpoints: Seq[RpcEndpointRef]) + extends RDD[Unit](prev) { + + override def getPartitions: Array[Partition] = prev.partitions + + override def compute(split: Partition, context: TaskContext): Iterator[Unit] = { + val writer = new ContinuousShuffleWriter(split.index, outputPartitioner, endpoints) + writer.write(prev.compute(split, context)) + + Iterator() + } + + override def clearDependencies() { + super.clearDependencies() + prev = null + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala new file mode 100644 index 0000000000000..cfcf080516c51 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import org.apache.spark.Partitioner +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.sql.catalyst.expressions.UnsafeRow + +class ContinuousShuffleWriter( + writerId: Int, + outputPartitioner: Partitioner, + endpoints: Seq[RpcEndpointRef]) { + + if (outputPartitioner.numPartitions != 1) { + throw new IllegalArgumentException("multiple readers not yet supported") + } + + if (outputPartitioner.numPartitions != endpoints.size) { + throw new IllegalArgumentException(s"partitioner size ${outputPartitioner.numPartitions} did " + + s"not match endpoint count ${endpoints.size}") + } + + def write(epoch: Iterator[UnsafeRow]): Unit = { + while (epoch.hasNext) { + val row = epoch.next() + endpoints(outputPartitioner.getPartition(row)).ask[Unit](ReceiverRow(writerId, row)) + } + + endpoints.foreach(_.ask[Unit](ReceiverEpochMarker(writerId))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala new file mode 100644 index 0000000000000..90fcb7a80fe1f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming.continuous.shuffle + +import scala.collection.mutable + +import org.apache.spark.{HashPartitioner, Partition, TaskContext, TaskContextImpl} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.execution.streaming.continuous.shuffle.{ContinuousShuffleReadPartition, ContinuousShuffleReadRDD, ContinuousShuffleWriteRDD} +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types.{DataType, IntegerType} + +class ContinuousShuffleSuite extends StreamTest { + // In this unit test, we emulate that we're in the task thread where + // ContinuousShuffleReadRDD.compute() will be evaluated. This requires a task context + // thread local to be set. + var ctx: TaskContextImpl = _ + + override def beforeEach(): Unit = { + super.beforeEach() + ctx = TaskContext.empty() + TaskContext.setTaskContext(ctx) + } + + override def afterEach(): Unit = { + ctx.markTaskCompleted(None) + TaskContext.unset() + ctx = null + super.afterEach() + } + + private case class SimplePartition(index: Int) extends Partition + + /** + * An RDD that simulates multiple continuous processing epochs, with each epoch corresponding + * to one entry in the outer epochData array. The data in the inner array is round-robined across + * the specified number of partitions. + */ + private class MultipleEpochRDD(numPartitions: Int, epochData: Array[Int]*) + extends RDD[UnsafeRow](sparkContext, Nil) { + override def getPartitions: Array[Partition] = { + (0 until numPartitions).map(SimplePartition).toArray + } + + private val currentEpochForPartition = mutable.Map[Int, Int]().withDefaultValue(0) + + override def compute(split: Partition, ctx: TaskContext): Iterator[UnsafeRow] = { + val epoch = epochData(currentEpochForPartition(split.index)).zipWithIndex.collect { + case (value, idx) if idx % numPartitions == split.index => unsafeRow(value) + } + + currentEpochForPartition(split.index) += 1 + epoch.toIterator + } + } + + private def unsafeRow(value: Int) = { + UnsafeProjection.create(Array(IntegerType : DataType))( + new GenericInternalRow(Array(value: Any))) + } + + private def readRDDEndpoint(rdd: ContinuousShuffleReadRDD) = { + rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + } + + private def writeEpoch(rdd: ContinuousShuffleWriteRDD, partition: Int = 0) = { + rdd.compute(rdd.partitions(partition), ctx) + } + + private def readEpoch(rdd: ContinuousShuffleReadRDD) = { + rdd.compute(rdd.partitions(0), ctx).toSeq.map(_.getInt(0)) + } + + test("one epoch") { + val data = sparkContext.parallelize(Seq(1, 2, 3).map(unsafeRow), 1) + + val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new ContinuousShuffleWriteRDD( + data, new HashPartitioner(1), Seq(readRDDEndpoint(reader))) + + writeEpoch(writer) + + assert(readEpoch(reader) == Seq(1, 2, 3)) + } + + test("multiple epochs") { + val data = new MultipleEpochRDD(1, Array(1, 2, 3), Array(4, 5, 6)) + + val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new ContinuousShuffleWriteRDD( + data, new HashPartitioner(1), Seq(readRDDEndpoint(reader))) + + writeEpoch(writer) + writeEpoch(writer) + + assert(readEpoch(reader) == Seq(1, 2, 3)) + assert(readEpoch(reader) == Seq(4, 5, 6)) + } + + test("empty epochs") { + val data = new MultipleEpochRDD(1, Array(), Array(1, 2), Array(), Array(), Array(3, 4), Array()) + + val reader = new 
ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new ContinuousShuffleWriteRDD( + data, new HashPartitioner(1), Seq(readRDDEndpoint(reader))) + + for (_ <- 0 to 5) { + writeEpoch(writer) + } + + assert(readEpoch(reader) == Seq()) + assert(readEpoch(reader) == Seq(1, 2)) + assert(readEpoch(reader) == Seq()) + assert(readEpoch(reader) == Seq()) + assert(readEpoch(reader) == Seq(3, 4)) + assert(readEpoch(reader) == Seq()) + } + + test("blocks waiting for writer") { + val data = new MultipleEpochRDD(1, Array(1)) + + val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new ContinuousShuffleWriteRDD( + data, new HashPartitioner(1), Seq(readRDDEndpoint(reader))) + + val readerEpoch = reader.compute(reader.partitions(0), ctx) + + val readRowThread = new Thread { + override def run(): Unit = { + assert(readerEpoch.toSeq.map(_.getInt(0)) == Seq(1)) + } + } + readRowThread.start() + + eventually(timeout(streamingTimeout)) { + assert(readRowThread.getState == Thread.State.WAITING) + } + + // Once we write the epoch the thread should stop waiting and succeed. + writeEpoch(writer) + readRowThread.join() + } + + test("multiple writer partitions") { + val numWriterPartitions = 3 + val data = new MultipleEpochRDD( + numWriterPartitions, Array(1, 2, 3, 4, 5, 6, 7), Array(4, 5, 6, 7, 8, 9, 10)) + + val reader = new ContinuousShuffleReadRDD( + sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions) + val writer = new ContinuousShuffleWriteRDD( + data, new HashPartitioner(1), Seq(readRDDEndpoint(reader))) + + writeEpoch(writer, 0) + writeEpoch(writer, 1) + writeEpoch(writer, 2) + writeEpoch(writer, 0) + writeEpoch(writer, 1) + writeEpoch(writer, 2) + + // Since there are multiple asynchronous writers, the original row sequencing is not guaranteed. + // The epochs should be deterministically preserved, however. 
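// The assertions below therefore compare each epoch's contents as a set rather than as a sequence.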
+ assert(readEpoch(reader).toSet == Seq(1, 2, 3, 4, 5, 6, 7).toSet) + assert(readEpoch(reader).toSet == Seq(4, 5, 6, 7, 8, 9, 10).toSet) + } + + test("reader epoch only ends when all writer partitions write it") { + val numWriterPartitions = 3 + val data = new MultipleEpochRDD(numWriterPartitions, Array()) + + val reader = new ContinuousShuffleReadRDD( + sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions) + val writer = new ContinuousShuffleWriteRDD( + data, new HashPartitioner(1), Seq(readRDDEndpoint(reader))) + + writeEpoch(writer, 1) + writeEpoch(writer, 2) + + val readerEpoch = reader.compute(reader.partitions(0), ctx) + + val readEpochMarkerThread = new Thread { + override def run(): Unit = { + assert(!readerEpoch.hasNext) + } + } + + readEpochMarkerThread.start() + eventually(timeout(streamingTimeout)) { + assert(readEpochMarkerThread.getState == Thread.State.WAITING) + } + + writeEpoch(writer, 0) + readEpochMarkerThread.join() + } +} From c9adee5423c2e8a030911008d2e6942045d484bb Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 24 May 2018 17:15:39 -0700 Subject: [PATCH 15/33] docs and iface --- .../shuffle/ContinuousShuffleWriteRDD.scala | 11 +++- .../shuffle/ContinuousShuffleWriter.scala | 29 ++-------- .../continuous/shuffle/UnsafeRowWriter.scala | 54 +++++++++++++++++++ 3 files changed, 68 insertions(+), 26 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowWriter.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriteRDD.scala index 95678e0619f11..0678d44bb6929 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriteRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriteRDD.scala @@ -20,9 +20,15 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle import org.apache.spark.{Partition, Partitioner, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow +/** + * + * @param prev The RDD to write to the continuous shuffle. + * @param outputPartitioner The partitioner on the reader side of the shuffle. + * @param endpoints The [[UnsafeRowReceiver]] endpoints to write to. Indexed by partition ID within + * outputPartitioner. 
+ */ class ContinuousShuffleWriteRDD( var prev: RDD[UnsafeRow], outputPartitioner: Partitioner, @@ -32,7 +38,8 @@ class ContinuousShuffleWriteRDD( override def getPartitions: Array[Partition] = prev.partitions override def compute(split: Partition, context: TaskContext): Iterator[Unit] = { - val writer = new ContinuousShuffleWriter(split.index, outputPartitioner, endpoints) + val writer: ContinuousShuffleWriter = + new UnsafeRowWriter(split.index, outputPartitioner, endpoints.toArray) writer.write(prev.compute(split, context)) Iterator() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala index cfcf080516c51..47b1f78b24505 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala @@ -17,30 +17,11 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle -import org.apache.spark.Partitioner -import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.catalyst.expressions.UnsafeRow -class ContinuousShuffleWriter( - writerId: Int, - outputPartitioner: Partitioner, - endpoints: Seq[RpcEndpointRef]) { - - if (outputPartitioner.numPartitions != 1) { - throw new IllegalArgumentException("multiple readers not yet supported") - } - - if (outputPartitioner.numPartitions != endpoints.size) { - throw new IllegalArgumentException(s"partitioner size ${outputPartitioner.numPartitions} did " + - s"not match endpoint count ${endpoints.size}") - } - - def write(epoch: Iterator[UnsafeRow]): Unit = { - while (epoch.hasNext) { - val row = epoch.next() - endpoints(outputPartitioner.getPartition(row)).ask[Unit](ReceiverRow(writerId, row)) - } - - endpoints.foreach(_.ask[Unit](ReceiverEpochMarker(writerId))) - } +/** + * Trait for writing to a continuous processing shuffle. + */ +trait ContinuousShuffleWriter { + def write(epoch: Iterator[UnsafeRow]): Unit } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowWriter.scala new file mode 100644 index 0000000000000..0d17e968f9c08 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowWriter.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import org.apache.spark.Partitioner +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.sql.catalyst.expressions.UnsafeRow + +/** + * A [[ContinuousShuffleWriter]] sending data to [[UnsafeRowReceiver]] instances. + * + * @param writerId The partition ID of this writer. + * @param outputPartitioner The partitioner on the reader side of the shuffle. + * @param endpoints The [[UnsafeRowReceiver]] endpoints to write to. Indexed by partition ID within + * outputPartitioner. + */ +class UnsafeRowWriter( + writerId: Int, + outputPartitioner: Partitioner, + endpoints: Array[RpcEndpointRef]) extends ContinuousShuffleWriter { + + if (outputPartitioner.numPartitions != 1) { + throw new IllegalArgumentException("multiple readers not yet supported") + } + + if (outputPartitioner.numPartitions != endpoints.length) { + throw new IllegalArgumentException(s"partitioner size ${outputPartitioner.numPartitions} did " + + s"not match endpoint count ${endpoints.length}") + } + + def write(epoch: Iterator[UnsafeRow]): Unit = { + while (epoch.hasNext) { + val row = epoch.next() + endpoints(outputPartitioner.getPartition(row)).ask[Unit](ReceiverRow(writerId, row)) + } + + endpoints.foreach(_.ask[Unit](ReceiverEpochMarker(writerId))) + } +} From 331f437423262a1aa76754a8079d7c017e4ea28a Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 24 May 2018 17:37:14 -0700 Subject: [PATCH 16/33] increment epoch --- .../continuous/shuffle/ContinuousShuffleWriteRDD.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriteRDD.scala index 0678d44bb6929..8c40012840dd1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriteRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriteRDD.scala @@ -18,9 +18,11 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle import org.apache.spark.{Partition, Partitioner, TaskContext} + import org.apache.spark.rdd.RDD import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.continuous.EpochTracker /** * @@ -42,6 +44,8 @@ class ContinuousShuffleWriteRDD( new UnsafeRowWriter(split.index, outputPartitioner, endpoints.toArray) writer.write(prev.compute(split, context)) + EpochTracker.incrementCurrentEpoch() + Iterator() } From f3ce67529372f72370a1e6028dc71a751acf26f2 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 24 May 2018 17:40:39 -0700 Subject: [PATCH 17/33] undo oop --- .../continuous/shuffle/ContinuousShuffleWriteRDD.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriteRDD.scala index 8c40012840dd1..0678d44bb6929 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriteRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriteRDD.scala @@ -18,11 +18,9 @@ package 
org.apache.spark.sql.execution.streaming.continuous.shuffle import org.apache.spark.{Partition, Partitioner, TaskContext} - import org.apache.spark.rdd.RDD import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.streaming.continuous.EpochTracker /** * @@ -44,8 +42,6 @@ class ContinuousShuffleWriteRDD( new UnsafeRowWriter(split.index, outputPartitioner, endpoints.toArray) writer.write(prev.compute(split, context)) - EpochTracker.incrementCurrentEpoch() - Iterator() } From e0108d7bc164b9e5eeb757c13c80bc1d11671188 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 24 May 2018 17:54:01 -0700 Subject: [PATCH 18/33] make rdd loop --- .../shuffle/ContinuousShuffleWriteRDD.scala | 10 ++- .../shuffle/ContinuousShuffleSuite.scala | 74 +++++++++---------- 2 files changed, 43 insertions(+), 41 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriteRDD.scala index 0678d44bb6929..cad37b79baa80 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriteRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriteRDD.scala @@ -21,8 +21,10 @@ import org.apache.spark.{Partition, Partitioner, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochTracker} /** + * An RDD which continuously writes epochs from its child into a continuous shuffle. * * @param prev The RDD to write to the continuous shuffle. * @param outputPartitioner The partitioner on the reader side of the shuffle. 
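// A minimal wiring sketch for the classes in this patch set, assuming `sc: SparkContext` and
// `rows: RDD[UnsafeRow]` are already in scope and that, as in the test suites, a TaskContext
// is set on the current thread before the reader endpoint is touched. Everything else is
// named as in these patches.
import org.apache.spark.HashPartitioner
import org.apache.spark.sql.execution.streaming.continuous.shuffle.{
  ContinuousShuffleReadPartition, ContinuousShuffleReadRDD, ContinuousShuffleWriteRDD}

// One reader partition; its lazily created RPC endpoint is where rows get pushed.
val reader = new ContinuousShuffleReadRDD(sc, numPartitions = 1)
val endpoint = reader.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint

// The write RDD sends each row to the endpoint chosen by the reader-side partitioner,
// followed by an epoch marker to every endpoint once the epoch's iterator is exhausted.
val writer = new ContinuousShuffleWriteRDD(rows, new HashPartitioner(1), Seq(endpoint))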
@@ -38,9 +40,15 @@ class ContinuousShuffleWriteRDD( override def getPartitions: Array[Partition] = prev.partitions override def compute(split: Partition, context: TaskContext): Iterator[Unit] = { + EpochTracker.initializeCurrentEpoch( + context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong) val writer: ContinuousShuffleWriter = new UnsafeRowWriter(split.index, outputPartitioner, endpoints.toArray) - writer.write(prev.compute(split, context)) + + while (!context.isInterrupted() && !context.isCompleted()) { + writer.write(prev.compute(split, context)) + EpochTracker.incrementCurrentEpoch() + } Iterator() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala index 90fcb7a80fe1f..879ced70646ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import org.apache.spark.{HashPartitioner, Partition, TaskContext, TaskContextImpl} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} -import org.apache.spark.sql.execution.streaming.continuous.shuffle.{ContinuousShuffleReadPartition, ContinuousShuffleReadRDD, ContinuousShuffleWriteRDD} +import org.apache.spark.sql.execution.streaming.continuous.shuffle.{ContinuousShuffleReadPartition, ContinuousShuffleReadRDD, ContinuousShuffleWriteRDD, UnsafeRowWriter} import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.types.{DataType, IntegerType} @@ -70,7 +70,7 @@ class ContinuousShuffleSuite extends StreamTest { } } - private def unsafeRow(value: Int) = { + private implicit def unsafeRow(value: Int) = { UnsafeProjection.create(Array(IntegerType : DataType))( new GenericInternalRow(Array(value: Any))) } @@ -88,26 +88,20 @@ class ContinuousShuffleSuite extends StreamTest { } test("one epoch") { - val data = sparkContext.parallelize(Seq(1, 2, 3).map(unsafeRow), 1) - val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val writer = new ContinuousShuffleWriteRDD( - data, new HashPartitioner(1), Seq(readRDDEndpoint(reader))) + val writer = new UnsafeRowWriter(0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) - writeEpoch(writer) + writer.write(Iterator(1, 2, 3)) assert(readEpoch(reader) == Seq(1, 2, 3)) } test("multiple epochs") { - val data = new MultipleEpochRDD(1, Array(1, 2, 3), Array(4, 5, 6)) - val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val writer = new ContinuousShuffleWriteRDD( - data, new HashPartitioner(1), Seq(readRDDEndpoint(reader))) + val writer = new UnsafeRowWriter(0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) - writeEpoch(writer) - writeEpoch(writer) + writer.write(Iterator(1, 2, 3)) + writer.write(Iterator(4, 5, 6)) assert(readEpoch(reader) == Seq(1, 2, 3)) assert(readEpoch(reader) == Seq(4, 5, 6)) @@ -117,12 +111,14 @@ class ContinuousShuffleSuite extends StreamTest { val data = new MultipleEpochRDD(1, Array(), Array(1, 2), Array(), Array(), Array(3, 4), Array()) val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val writer = new ContinuousShuffleWriteRDD( - data, new HashPartitioner(1), Seq(readRDDEndpoint(reader))) + val writer = new 
UnsafeRowWriter(0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) - for (_ <- 0 to 5) { - writeEpoch(writer) - } + writer.write(Iterator()) + writer.write(Iterator(1, 2)) + writer.write(Iterator()) + writer.write(Iterator()) + writer.write(Iterator(3, 4)) + writer.write(Iterator()) assert(readEpoch(reader) == Seq()) assert(readEpoch(reader) == Seq(1, 2)) @@ -133,11 +129,8 @@ class ContinuousShuffleSuite extends StreamTest { } test("blocks waiting for writer") { - val data = new MultipleEpochRDD(1, Array(1)) - val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val writer = new ContinuousShuffleWriteRDD( - data, new HashPartitioner(1), Seq(readRDDEndpoint(reader))) + val writer = new UnsafeRowWriter(0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) val readerEpoch = reader.compute(reader.partitions(0), ctx) @@ -149,30 +142,30 @@ class ContinuousShuffleSuite extends StreamTest { readRowThread.start() eventually(timeout(streamingTimeout)) { - assert(readRowThread.getState == Thread.State.WAITING) + assert(readRowThread.getState == Thread.State.TIMED_WAITING) } // Once we write the epoch the thread should stop waiting and succeed. - writeEpoch(writer) + writer.write(Iterator(1)) readRowThread.join() } test("multiple writer partitions") { val numWriterPartitions = 3 - val data = new MultipleEpochRDD( - numWriterPartitions, Array(1, 2, 3, 4, 5, 6, 7), Array(4, 5, 6, 7, 8, 9, 10)) val reader = new ContinuousShuffleReadRDD( sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions) - val writer = new ContinuousShuffleWriteRDD( - data, new HashPartitioner(1), Seq(readRDDEndpoint(reader))) + val writers = (0 until 3).map { idx => + new UnsafeRowWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + } + + writers(0).write(Iterator(1, 4, 7)) + writers(1).write(Iterator(2, 5)) + writers(2).write(Iterator(3, 6)) - writeEpoch(writer, 0) - writeEpoch(writer, 1) - writeEpoch(writer, 2) - writeEpoch(writer, 0) - writeEpoch(writer, 1) - writeEpoch(writer, 2) + writers(0).write(Iterator(4, 7, 10)) + writers(1).write(Iterator(5, 8)) + writers(2).write(Iterator(6, 9)) // Since there are multiple asynchronous writers, the original row sequencing is not guaranteed. // The epochs should be deterministically preserved, however. 
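// Why the epoch contents above are still deterministic: each write() pushes its rows and then
// one ReceiverEpochMarker(writerId) per endpoint, and the reader does not end an epoch until a
// marker has arrived from every writer. So the reader's first epoch is the union of the first
// epoch of all three writers; only the interleaving across writers is unspecified. A tiny
// model of that collation in plain Scala (independent of the classes under test):
val epochZeroByWriter = Seq(Set(1, 4, 7), Set(2, 5), Set(3, 6))
assert(epochZeroByWriter.reduce(_ union _) == Set(1, 2, 3, 4, 5, 6, 7))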
@@ -186,11 +179,12 @@ class ContinuousShuffleSuite extends StreamTest { val reader = new ContinuousShuffleReadRDD( sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions) - val writer = new ContinuousShuffleWriteRDD( - data, new HashPartitioner(1), Seq(readRDDEndpoint(reader))) + val writers = (0 until 3).map { idx => + new UnsafeRowWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + } - writeEpoch(writer, 1) - writeEpoch(writer, 2) + writers(1).write(Iterator()) + writers(2).write(Iterator()) val readerEpoch = reader.compute(reader.partitions(0), ctx) @@ -202,10 +196,10 @@ class ContinuousShuffleSuite extends StreamTest { readEpochMarkerThread.start() eventually(timeout(streamingTimeout)) { - assert(readEpochMarkerThread.getState == Thread.State.WAITING) + assert(readEpochMarkerThread.getState == Thread.State.TIMED_WAITING) } - writeEpoch(writer, 0) + writers(0).write(Iterator()) readEpochMarkerThread.join() } } From f40065166f848dcc117865f15346f90a904c9153 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 25 May 2018 16:00:16 -0700 Subject: [PATCH 19/33] remote write RDD --- .../shuffle/ContinuousShuffleWriteRDD.scala | 60 ------------------- .../shuffle/ContinuousShuffleSuite.scala | 6 +- 2 files changed, 1 insertion(+), 65 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriteRDD.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriteRDD.scala deleted file mode 100644 index cad37b79baa80..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriteRDD.scala +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.continuous.shuffle - -import org.apache.spark.{Partition, Partitioner, TaskContext} -import org.apache.spark.rdd.RDD -import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochTracker} - -/** - * An RDD which continuously writes epochs from its child into a continuous shuffle. - * - * @param prev The RDD to write to the continuous shuffle. - * @param outputPartitioner The partitioner on the reader side of the shuffle. - * @param endpoints The [[UnsafeRowReceiver]] endpoints to write to. Indexed by partition ID within - * outputPartitioner. 
- */ -class ContinuousShuffleWriteRDD( - var prev: RDD[UnsafeRow], - outputPartitioner: Partitioner, - endpoints: Seq[RpcEndpointRef]) - extends RDD[Unit](prev) { - - override def getPartitions: Array[Partition] = prev.partitions - - override def compute(split: Partition, context: TaskContext): Iterator[Unit] = { - EpochTracker.initializeCurrentEpoch( - context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong) - val writer: ContinuousShuffleWriter = - new UnsafeRowWriter(split.index, outputPartitioner, endpoints.toArray) - - while (!context.isInterrupted() && !context.isCompleted()) { - writer.write(prev.compute(split, context)) - EpochTracker.incrementCurrentEpoch() - } - - Iterator() - } - - override def clearDependencies() { - super.clearDependencies() - prev = null - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala index 879ced70646ef..5ac258dfcd340 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import org.apache.spark.{HashPartitioner, Partition, TaskContext, TaskContextImpl} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} -import org.apache.spark.sql.execution.streaming.continuous.shuffle.{ContinuousShuffleReadPartition, ContinuousShuffleReadRDD, ContinuousShuffleWriteRDD, UnsafeRowWriter} +import org.apache.spark.sql.execution.streaming.continuous.shuffle.{ContinuousShuffleReadPartition, ContinuousShuffleReadRDD, UnsafeRowWriter} import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.types.{DataType, IntegerType} @@ -79,10 +79,6 @@ class ContinuousShuffleSuite extends StreamTest { rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint } - private def writeEpoch(rdd: ContinuousShuffleWriteRDD, partition: Int = 0) = { - rdd.compute(rdd.partitions(partition), ctx) - } - private def readEpoch(rdd: ContinuousShuffleReadRDD) = { rdd.compute(rdd.partitions(0), ctx).toSeq.map(_.getInt(0)) } From 1aaad8d7660d1d6cd2abbca10d67ef724b4a0dcc Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 25 May 2018 16:32:53 -0700 Subject: [PATCH 20/33] rename classes --- .../shuffle/ContinuousShuffleReadRDD.scala | 3 ++- ....scala => RPCContinuousShuffleReader.scala} | 2 +- ....scala => RPCContinuousShuffleWriter.scala} | 10 +++++----- .../shuffle/ContinuousShuffleReadSuite.scala | 4 ++-- .../shuffle/ContinuousShuffleSuite.scala | 18 +++++++++++------- 5 files changed, 21 insertions(+), 16 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/{UnsafeRowReceiver.scala => RPCContinuousShuffleReader.scala} (99%) rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/{UnsafeRowWriter.scala => RPCContinuousShuffleWriter.scala} (84%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index f7d28c948596b..e07f7bb70b396 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -34,7 +34,8 @@ case class ContinuousShuffleReadPartition( // Initialized only on the executor, and only once even as we call compute() multiple times. lazy val (reader: ContinuousShuffleReader, endpoint) = { val env = SparkEnv.get.rpcEnv - val receiver = new UnsafeRowReceiver(queueSize, numShuffleWriters, epochIntervalMs, env) + val receiver = new RPCContinuousShuffleReader( + queueSize, numShuffleWriters, epochIntervalMs, env) val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID()}", receiver) TaskContext.get().addTaskCompletionListener { ctx => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala index 4bb2c2d2ce2b7..2bb5bbeef8eaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala @@ -46,7 +46,7 @@ private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends UnsafeRow * TODO: Support multiple source tasks. We need to output a single epoch marker once all * source tasks have sent one. */ -private[shuffle] class UnsafeRowReceiver( +private[shuffle] class RPCContinuousShuffleReader( queueSize: Int, numShuffleWriters: Int, epochIntervalMs: Long, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala similarity index 84% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowWriter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala index 0d17e968f9c08..a1bef1c115feb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala @@ -22,14 +22,14 @@ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.catalyst.expressions.UnsafeRow /** - * A [[ContinuousShuffleWriter]] sending data to [[UnsafeRowReceiver]] instances. + * A [[ContinuousShuffleWriter]] sending data to [[RPCContinuousShuffleReader]] instances. * - * @param writerId The partition ID of this writer. + * @param writerId The partition ID of this writer. * @param outputPartitioner The partitioner on the reader side of the shuffle. - * @param endpoints The [[UnsafeRowReceiver]] endpoints to write to. Indexed by partition ID within - * outputPartitioner. + * @param endpoints The [[RPCContinuousShuffleReader]] endpoints to write to. Indexed by + * partition ID within outputPartitioner. 
*/ -class UnsafeRowWriter( +class RPCContinuousShuffleWriter( writerId: Int, outputPartitioner: Partitioner, endpoints: Array[RpcEndpointRef]) extends ContinuousShuffleWriter { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index 974fd5569c908..1044316b0746f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -70,7 +70,7 @@ class ContinuousShuffleReadSuite extends StreamTest { ctx.markTaskCompleted(None) val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader eventually(timeout(streamingTimeout)) { - assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get()) + assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get()) } } @@ -86,7 +86,7 @@ class ContinuousShuffleReadSuite extends StreamTest { ctx.markTaskCompleted(None) val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader eventually(timeout(streamingTimeout)) { - assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get()) + assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get()) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala index 5ac258dfcd340..7763a348ffa53 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import org.apache.spark.{HashPartitioner, Partition, TaskContext, TaskContextImpl} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} -import org.apache.spark.sql.execution.streaming.continuous.shuffle.{ContinuousShuffleReadPartition, ContinuousShuffleReadRDD, UnsafeRowWriter} +import org.apache.spark.sql.execution.streaming.continuous.shuffle.{ContinuousShuffleReadPartition, ContinuousShuffleReadRDD, RPCContinuousShuffleWriter} import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.types.{DataType, IntegerType} @@ -85,7 +85,8 @@ class ContinuousShuffleSuite extends StreamTest { test("one epoch") { val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val writer = new UnsafeRowWriter(0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) writer.write(Iterator(1, 2, 3)) @@ -94,7 +95,8 @@ class ContinuousShuffleSuite extends StreamTest { test("multiple epochs") { val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val writer = new UnsafeRowWriter(0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) writer.write(Iterator(1, 2, 3)) writer.write(Iterator(4, 5, 6)) @@ -107,7 +109,8 @@ class ContinuousShuffleSuite extends StreamTest { val data = new MultipleEpochRDD(1, Array(), Array(1, 2), Array(), Array(), Array(3, 4), 
Array()) val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val writer = new UnsafeRowWriter(0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) writer.write(Iterator()) writer.write(Iterator(1, 2)) @@ -126,7 +129,8 @@ class ContinuousShuffleSuite extends StreamTest { test("blocks waiting for writer") { val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val writer = new UnsafeRowWriter(0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) val readerEpoch = reader.compute(reader.partitions(0), ctx) @@ -152,7 +156,7 @@ class ContinuousShuffleSuite extends StreamTest { val reader = new ContinuousShuffleReadRDD( sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions) val writers = (0 until 3).map { idx => - new UnsafeRowWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + new RPCContinuousShuffleWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader))) } writers(0).write(Iterator(1, 4, 7)) @@ -176,7 +180,7 @@ class ContinuousShuffleSuite extends StreamTest { val reader = new ContinuousShuffleReadRDD( sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions) val writers = (0 until 3).map { idx => - new UnsafeRowWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + new RPCContinuousShuffleWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader))) } writers(1).write(Iterator()) From 59890d47a1f34fe9de6c16d03e8d644a40f6180b Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 25 May 2018 16:48:31 -0700 Subject: [PATCH 21/33] combine suites --- .../shuffle/ContinuousShuffleReadSuite.scala | 230 --------------- .../shuffle/ContinuousShuffleSuite.scala | 278 ++++++++++++++++-- 2 files changed, 246 insertions(+), 262 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index 1044316b0746f..3d4881a19eca1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -58,235 +58,5 @@ class ContinuousShuffleReadSuite extends StreamTest { super.afterEach() } - test("receiver stopped with row last") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - send( - endpoint, - ReceiverEpochMarker(0), - ReceiverRow(0, unsafeRow(111)) - ) - ctx.markTaskCompleted(None) - val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader - eventually(timeout(streamingTimeout)) { - assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get()) - } - } - - test("receiver stopped with marker last") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - send( - endpoint, - ReceiverRow(0, unsafeRow(111)), - ReceiverEpochMarker(0) - ) - - ctx.markTaskCompleted(None) - val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader - 
eventually(timeout(streamingTimeout)) { - assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get()) - } - } - - test("one epoch") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - send( - endpoint, - ReceiverRow(0, unsafeRow(111)), - ReceiverRow(0, unsafeRow(222)), - ReceiverRow(0, unsafeRow(333)), - ReceiverEpochMarker(0) - ) - - val iter = rdd.compute(rdd.partitions(0), ctx) - assert(iter.toSeq.map(_.getInt(0)) == Seq(111, 222, 333)) - } - - test("multiple epochs") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - send( - endpoint, - ReceiverRow(0, unsafeRow(111)), - ReceiverEpochMarker(0), - ReceiverRow(0, unsafeRow(222)), - ReceiverRow(0, unsafeRow(333)), - ReceiverEpochMarker(0) - ) - - val firstEpoch = rdd.compute(rdd.partitions(0), ctx) - assert(firstEpoch.toSeq.map(_.getInt(0)) == Seq(111)) - - val secondEpoch = rdd.compute(rdd.partitions(0), ctx) - assert(secondEpoch.toSeq.map(_.getInt(0)) == Seq(222, 333)) - } - - test("empty epochs") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - - send( - endpoint, - ReceiverEpochMarker(0), - ReceiverEpochMarker(0), - ReceiverRow(0, unsafeRow(111)), - ReceiverEpochMarker(0), - ReceiverEpochMarker(0), - ReceiverEpochMarker(0) - ) - - assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) - assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) - - val thirdEpoch = rdd.compute(rdd.partitions(0), ctx) - assert(thirdEpoch.toSeq.map(_.getInt(0)) == Seq(111)) - - assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) - assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) - } - - test("multiple partitions") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 5) - // Send all data before processing to ensure there's no crossover. - for (p <- rdd.partitions) { - val part = p.asInstanceOf[ContinuousShuffleReadPartition] - // Send index for identification. 
- send( - part.endpoint, - ReceiverRow(0, unsafeRow(part.index)), - ReceiverEpochMarker(0) - ) - } - - for (p <- rdd.partitions) { - val part = p.asInstanceOf[ContinuousShuffleReadPartition] - val iter = rdd.compute(part, ctx) - assert(iter.next().getInt(0) == part.index) - assert(!iter.hasNext) - } - } - - test("blocks waiting for new rows") { - val rdd = new ContinuousShuffleReadRDD( - sparkContext, numPartitions = 1, epochIntervalMs = Long.MaxValue) - val epoch = rdd.compute(rdd.partitions(0), ctx) - - val readRowThread = new Thread { - override def run(): Unit = { - try { - epoch.next().getInt(0) - } catch { - case _: InterruptedException => // do nothing - expected at test ending - } - } - } - - try { - readRowThread.start() - eventually(timeout(streamingTimeout)) { - assert(readRowThread.getState == Thread.State.TIMED_WAITING) - } - } finally { - readRowThread.interrupt() - readRowThread.join() - } - } - - test("multiple writers") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) - val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - send( - endpoint, - ReceiverRow(0, unsafeRow("writer0-row0")), - ReceiverRow(1, unsafeRow("writer1-row0")), - ReceiverRow(2, unsafeRow("writer2-row0")), - ReceiverEpochMarker(0), - ReceiverEpochMarker(1), - ReceiverEpochMarker(2) - ) - - val firstEpoch = rdd.compute(rdd.partitions(0), ctx) - assert(firstEpoch.toSeq.map(_.getUTF8String(0).toString).toSet == - Set("writer0-row0", "writer1-row0", "writer2-row0")) - } - - test("epoch only ends when all writers send markers") { - val rdd = new ContinuousShuffleReadRDD( - sparkContext, numPartitions = 1, numShuffleWriters = 3, epochIntervalMs = Long.MaxValue) - val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - send( - endpoint, - ReceiverRow(0, unsafeRow("writer0-row0")), - ReceiverRow(1, unsafeRow("writer1-row0")), - ReceiverRow(2, unsafeRow("writer2-row0")), - ReceiverEpochMarker(0), - ReceiverEpochMarker(2) - ) - - val epoch = rdd.compute(rdd.partitions(0), ctx) - val rows = (0 until 3).map(_ => epoch.next()).toSet - assert(rows.map(_.getUTF8String(0).toString) == - Set("writer0-row0", "writer1-row0", "writer2-row0")) - - // After checking the right rows, block until we get an epoch marker indicating there's no next. - // (Also fail the assertion if for some reason we get a row.) - - val readEpochMarkerThread = new Thread { - override def run(): Unit = { - assert(!epoch.hasNext) - } - } - - readEpochMarkerThread.start() - eventually(timeout(streamingTimeout)) { - assert(readEpochMarkerThread.getState == Thread.State.TIMED_WAITING) - } - - // Send the last epoch marker - now the epoch should finish. - send(endpoint, ReceiverEpochMarker(1)) - eventually(timeout(streamingTimeout)) { - !readEpochMarkerThread.isAlive - } - - // Join to pick up assertion failures. - readEpochMarkerThread.join() - } - - test("writer epochs non aligned") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) - val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - // We send multiple epochs for 0, then multiple for 1, then multiple for 2. The receiver should - // collate them as though the markers were aligned in the first place. 
- send( - endpoint, - ReceiverRow(0, unsafeRow("writer0-row0")), - ReceiverEpochMarker(0), - ReceiverRow(0, unsafeRow("writer0-row1")), - ReceiverEpochMarker(0), - ReceiverEpochMarker(0), - - ReceiverEpochMarker(1), - ReceiverRow(1, unsafeRow("writer1-row0")), - ReceiverEpochMarker(1), - ReceiverRow(1, unsafeRow("writer1-row1")), - ReceiverEpochMarker(1), - - ReceiverEpochMarker(2), - ReceiverEpochMarker(2), - ReceiverRow(2, unsafeRow("writer2-row0")), - ReceiverEpochMarker(2) - ) - - val firstEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet - assert(firstEpoch == Set("writer0-row0")) - - val secondEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet - assert(secondEpoch == Set("writer0-row1", "writer1-row0")) - - val thirdEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet - assert(thirdEpoch == Set("writer1-row1", "writer2-row0")) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala index 7763a348ffa53..56cbc95bf8b95 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala @@ -15,16 +15,17 @@ * limitations under the License. */ -package org.apache.spark.sql.streaming.continuous.shuffle +package org.apache.spark.sql.execution.streaming.continuous.shuffle import scala.collection.mutable import org.apache.spark.{HashPartitioner, Partition, TaskContext, TaskContextImpl} -import org.apache.spark.rdd.RDD +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} -import org.apache.spark.sql.execution.streaming.continuous.shuffle.{ContinuousShuffleReadPartition, ContinuousShuffleReadRDD, RPCContinuousShuffleWriter} +import org.apache.spark.sql.execution.streaming.continuous.shuffle._ import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.sql.types.{DataType, IntegerType, StringType} +import org.apache.spark.unsafe.types.UTF8String class ContinuousShuffleSuite extends StreamTest { // In this unit test, we emulate that we're in the task thread where @@ -45,36 +46,20 @@ class ContinuousShuffleSuite extends StreamTest { super.afterEach() } - private case class SimplePartition(index: Int) extends Partition - - /** - * An RDD that simulates multiple continuous processing epochs, with each epoch corresponding - * to one entry in the outer epochData array. The data in the inner array is round-robined across - * the specified number of partitions. 
- */ - private class MultipleEpochRDD(numPartitions: Int, epochData: Array[Int]*) - extends RDD[UnsafeRow](sparkContext, Nil) { - override def getPartitions: Array[Partition] = { - (0 until numPartitions).map(SimplePartition).toArray - } - - private val currentEpochForPartition = mutable.Map[Int, Int]().withDefaultValue(0) - - override def compute(split: Partition, ctx: TaskContext): Iterator[UnsafeRow] = { - val epoch = epochData(currentEpochForPartition(split.index)).zipWithIndex.collect { - case (value, idx) if idx % numPartitions == split.index => unsafeRow(value) - } - - currentEpochForPartition(split.index) += 1 - epoch.toIterator - } - } - private implicit def unsafeRow(value: Int) = { UnsafeProjection.create(Array(IntegerType : DataType))( new GenericInternalRow(Array(value: Any))) } + private def unsafeRow(value: String) = { + UnsafeProjection.create(Array(StringType : DataType))( + new GenericInternalRow(Array(UTF8String.fromString(value): Any))) + } + + private def send(endpoint: RpcEndpointRef, messages: UnsafeRowReceiverMessage*) = { + messages.foreach(endpoint.askSync[Unit](_)) + } + private def readRDDEndpoint(rdd: ContinuousShuffleReadRDD) = { rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint } @@ -106,8 +91,6 @@ class ContinuousShuffleSuite extends StreamTest { } test("empty epochs") { - val data = new MultipleEpochRDD(1, Array(), Array(1, 2), Array(), Array(), Array(3, 4), Array()) - val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val writer = new RPCContinuousShuffleWriter( 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) @@ -175,7 +158,6 @@ class ContinuousShuffleSuite extends StreamTest { test("reader epoch only ends when all writer partitions write it") { val numWriterPartitions = 3 - val data = new MultipleEpochRDD(numWriterPartitions, Array()) val reader = new ContinuousShuffleReadRDD( sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions) @@ -202,4 +184,236 @@ class ContinuousShuffleSuite extends StreamTest { writers(0).write(Iterator()) readEpochMarkerThread.join() } + + test("receiver stopped with row last") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverEpochMarker(0), + ReceiverRow(0, unsafeRow(111)) + ) + + ctx.markTaskCompleted(None) + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader + eventually(timeout(streamingTimeout)) { + assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get()) + } + } + + test("receiver stopped with marker last") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(0, unsafeRow(111)), + ReceiverEpochMarker(0) + ) + + ctx.markTaskCompleted(None) + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader + eventually(timeout(streamingTimeout)) { + assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get()) + } + } + + test("reader - one epoch") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(0, unsafeRow(111)), + ReceiverRow(0, unsafeRow(222)), + ReceiverRow(0, unsafeRow(333)), + ReceiverEpochMarker(0) + ) + + val iter = 
rdd.compute(rdd.partitions(0), ctx) + assert(iter.toSeq.map(_.getInt(0)) == Seq(111, 222, 333)) + } + + test("reader - multiple epochs") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(0, unsafeRow(111)), + ReceiverEpochMarker(0), + ReceiverRow(0, unsafeRow(222)), + ReceiverRow(0, unsafeRow(333)), + ReceiverEpochMarker(0) + ) + + val firstEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(firstEpoch.toSeq.map(_.getInt(0)) == Seq(111)) + + val secondEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(secondEpoch.toSeq.map(_.getInt(0)) == Seq(222, 333)) + } + + test("reader - empty epochs") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + + send( + endpoint, + ReceiverEpochMarker(0), + ReceiverEpochMarker(0), + ReceiverRow(0, unsafeRow(111)), + ReceiverEpochMarker(0), + ReceiverEpochMarker(0), + ReceiverEpochMarker(0) + ) + + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + + val thirdEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(thirdEpoch.toSeq.map(_.getInt(0)) == Seq(111)) + + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + } + + test("reader - multiple partitions") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 5) + // Send all data before processing to ensure there's no crossover. + for (p <- rdd.partitions) { + val part = p.asInstanceOf[ContinuousShuffleReadPartition] + // Send index for identification. + send( + part.endpoint, + ReceiverRow(0, unsafeRow(part.index)), + ReceiverEpochMarker(0) + ) + } + + for (p <- rdd.partitions) { + val part = p.asInstanceOf[ContinuousShuffleReadPartition] + val iter = rdd.compute(part, ctx) + assert(iter.next().getInt(0) == part.index) + assert(!iter.hasNext) + } + } + + test("reader - blocks waiting for new rows") { + val rdd = new ContinuousShuffleReadRDD( + sparkContext, numPartitions = 1, epochIntervalMs = Long.MaxValue) + val epoch = rdd.compute(rdd.partitions(0), ctx) + + val readRowThread = new Thread { + override def run(): Unit = { + try { + epoch.next().getInt(0) + } catch { + case _: InterruptedException => // do nothing - expected at test ending + } + } + } + + try { + readRowThread.start() + eventually(timeout(streamingTimeout)) { + assert(readRowThread.getState == Thread.State.TIMED_WAITING) + } + } finally { + readRowThread.interrupt() + readRowThread.join() + } + } + + test("reader - multiple writers") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(0, unsafeRow("writer0-row0")), + ReceiverRow(1, unsafeRow("writer1-row0")), + ReceiverRow(2, unsafeRow("writer2-row0")), + ReceiverEpochMarker(0), + ReceiverEpochMarker(1), + ReceiverEpochMarker(2) + ) + + val firstEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(firstEpoch.toSeq.map(_.getUTF8String(0).toString).toSet == + Set("writer0-row0", "writer1-row0", "writer2-row0")) + } + + test("reader - epoch only ends when all writers send markers") { + val rdd = new ContinuousShuffleReadRDD( + sparkContext, numPartitions = 1, numShuffleWriters = 3, epochIntervalMs = Long.MaxValue) + val endpoint = 
rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(0, unsafeRow("writer0-row0")), + ReceiverRow(1, unsafeRow("writer1-row0")), + ReceiverRow(2, unsafeRow("writer2-row0")), + ReceiverEpochMarker(0), + ReceiverEpochMarker(2) + ) + + val epoch = rdd.compute(rdd.partitions(0), ctx) + val rows = (0 until 3).map(_ => epoch.next()).toSet + assert(rows.map(_.getUTF8String(0).toString) == + Set("writer0-row0", "writer1-row0", "writer2-row0")) + + // After checking the right rows, block until we get an epoch marker indicating there's no next. + // (Also fail the assertion if for some reason we get a row.) + + val readEpochMarkerThread = new Thread { + override def run(): Unit = { + assert(!epoch.hasNext) + } + } + + readEpochMarkerThread.start() + eventually(timeout(streamingTimeout)) { + assert(readEpochMarkerThread.getState == Thread.State.TIMED_WAITING) + } + + // Send the last epoch marker - now the epoch should finish. + send(endpoint, ReceiverEpochMarker(1)) + eventually(timeout(streamingTimeout)) { + !readEpochMarkerThread.isAlive + } + + // Join to pick up assertion failures. + readEpochMarkerThread.join() + } + + test("reader - writer epochs non aligned") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + // We send multiple epochs for 0, then multiple for 1, then multiple for 2. The receiver should + // collate them as though the markers were aligned in the first place. + send( + endpoint, + ReceiverRow(0, unsafeRow("writer0-row0")), + ReceiverEpochMarker(0), + ReceiverRow(0, unsafeRow("writer0-row1")), + ReceiverEpochMarker(0), + ReceiverEpochMarker(0), + + ReceiverEpochMarker(1), + ReceiverRow(1, unsafeRow("writer1-row0")), + ReceiverEpochMarker(1), + ReceiverRow(1, unsafeRow("writer1-row1")), + ReceiverEpochMarker(1), + + ReceiverEpochMarker(2), + ReceiverEpochMarker(2), + ReceiverRow(2, unsafeRow("writer2-row0")), + ReceiverEpochMarker(2) + ) + + val firstEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet + assert(firstEpoch == Set("writer0-row0")) + + val secondEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet + assert(secondEpoch == Set("writer0-row1", "writer1-row0")) + + val thirdEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet + assert(thirdEpoch == Set("writer1-row1", "writer2-row0")) + } } From af1508cac894be8804a5e1646e9fe0a58e595f2d Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 25 May 2018 16:51:13 -0700 Subject: [PATCH 22/33] fully rm old suite --- .../shuffle/ContinuousShuffleReadSuite.scala | 62 ------------------- 1 file changed, 62 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala deleted file mode 100644 index 3d4881a19eca1..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.continuous.shuffle - -import org.apache.spark.{TaskContext, TaskContextImpl} -import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} -import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.sql.types.{DataType, IntegerType, StringType} -import org.apache.spark.unsafe.types.UTF8String - -class ContinuousShuffleReadSuite extends StreamTest { - - private def unsafeRow(value: Int) = { - UnsafeProjection.create(Array(IntegerType : DataType))( - new GenericInternalRow(Array(value: Any))) - } - - private def unsafeRow(value: String) = { - UnsafeProjection.create(Array(StringType : DataType))( - new GenericInternalRow(Array(UTF8String.fromString(value): Any))) - } - - private def send(endpoint: RpcEndpointRef, messages: UnsafeRowReceiverMessage*) = { - messages.foreach(endpoint.askSync[Unit](_)) - } - - // In this unit test, we emulate that we're in the task thread where - // ContinuousShuffleReadRDD.compute() will be evaluated. This requires a task context - // thread local to be set. 
- var ctx: TaskContextImpl = _ - - override def beforeEach(): Unit = { - super.beforeEach() - ctx = TaskContext.empty() - TaskContext.setTaskContext(ctx) - } - - override def afterEach(): Unit = { - ctx.markTaskCompleted(None) - TaskContext.unset() - ctx = null - super.afterEach() - } - - -} From 65837ac611991f2db7710d0657e56b222a2f5c74 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 29 May 2018 15:43:38 -0700 Subject: [PATCH 23/33] reorder tests --- .../shuffle/ContinuousShuffleSuite.scala | 298 +++++++++--------- 1 file changed, 149 insertions(+), 149 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala index 56cbc95bf8b95..4d5813a18a3f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala @@ -68,155 +68,6 @@ class ContinuousShuffleSuite extends StreamTest { rdd.compute(rdd.partitions(0), ctx).toSeq.map(_.getInt(0)) } - test("one epoch") { - val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val writer = new RPCContinuousShuffleWriter( - 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) - - writer.write(Iterator(1, 2, 3)) - - assert(readEpoch(reader) == Seq(1, 2, 3)) - } - - test("multiple epochs") { - val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val writer = new RPCContinuousShuffleWriter( - 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) - - writer.write(Iterator(1, 2, 3)) - writer.write(Iterator(4, 5, 6)) - - assert(readEpoch(reader) == Seq(1, 2, 3)) - assert(readEpoch(reader) == Seq(4, 5, 6)) - } - - test("empty epochs") { - val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val writer = new RPCContinuousShuffleWriter( - 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) - - writer.write(Iterator()) - writer.write(Iterator(1, 2)) - writer.write(Iterator()) - writer.write(Iterator()) - writer.write(Iterator(3, 4)) - writer.write(Iterator()) - - assert(readEpoch(reader) == Seq()) - assert(readEpoch(reader) == Seq(1, 2)) - assert(readEpoch(reader) == Seq()) - assert(readEpoch(reader) == Seq()) - assert(readEpoch(reader) == Seq(3, 4)) - assert(readEpoch(reader) == Seq()) - } - - test("blocks waiting for writer") { - val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val writer = new RPCContinuousShuffleWriter( - 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) - - val readerEpoch = reader.compute(reader.partitions(0), ctx) - - val readRowThread = new Thread { - override def run(): Unit = { - assert(readerEpoch.toSeq.map(_.getInt(0)) == Seq(1)) - } - } - readRowThread.start() - - eventually(timeout(streamingTimeout)) { - assert(readRowThread.getState == Thread.State.TIMED_WAITING) - } - - // Once we write the epoch the thread should stop waiting and succeed. 
- writer.write(Iterator(1)) - readRowThread.join() - } - - test("multiple writer partitions") { - val numWriterPartitions = 3 - - val reader = new ContinuousShuffleReadRDD( - sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions) - val writers = (0 until 3).map { idx => - new RPCContinuousShuffleWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader))) - } - - writers(0).write(Iterator(1, 4, 7)) - writers(1).write(Iterator(2, 5)) - writers(2).write(Iterator(3, 6)) - - writers(0).write(Iterator(4, 7, 10)) - writers(1).write(Iterator(5, 8)) - writers(2).write(Iterator(6, 9)) - - // Since there are multiple asynchronous writers, the original row sequencing is not guaranteed. - // The epochs should be deterministically preserved, however. - assert(readEpoch(reader).toSet == Seq(1, 2, 3, 4, 5, 6, 7).toSet) - assert(readEpoch(reader).toSet == Seq(4, 5, 6, 7, 8, 9, 10).toSet) - } - - test("reader epoch only ends when all writer partitions write it") { - val numWriterPartitions = 3 - - val reader = new ContinuousShuffleReadRDD( - sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions) - val writers = (0 until 3).map { idx => - new RPCContinuousShuffleWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader))) - } - - writers(1).write(Iterator()) - writers(2).write(Iterator()) - - val readerEpoch = reader.compute(reader.partitions(0), ctx) - - val readEpochMarkerThread = new Thread { - override def run(): Unit = { - assert(!readerEpoch.hasNext) - } - } - - readEpochMarkerThread.start() - eventually(timeout(streamingTimeout)) { - assert(readEpochMarkerThread.getState == Thread.State.TIMED_WAITING) - } - - writers(0).write(Iterator()) - readEpochMarkerThread.join() - } - - test("receiver stopped with row last") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - send( - endpoint, - ReceiverEpochMarker(0), - ReceiverRow(0, unsafeRow(111)) - ) - - ctx.markTaskCompleted(None) - val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader - eventually(timeout(streamingTimeout)) { - assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get()) - } - } - - test("receiver stopped with marker last") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - send( - endpoint, - ReceiverRow(0, unsafeRow(111)), - ReceiverEpochMarker(0) - ) - - ctx.markTaskCompleted(None) - val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader - eventually(timeout(streamingTimeout)) { - assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get()) - } - } - test("reader - one epoch") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint @@ -416,4 +267,153 @@ class ContinuousShuffleSuite extends StreamTest { val thirdEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet assert(thirdEpoch == Set("writer1-row1", "writer2-row0")) } + + test("one epoch") { + val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + + writer.write(Iterator(1, 2, 3)) + + assert(readEpoch(reader) == Seq(1, 2, 3)) + } + + test("multiple epochs") { + 
val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + + writer.write(Iterator(1, 2, 3)) + writer.write(Iterator(4, 5, 6)) + + assert(readEpoch(reader) == Seq(1, 2, 3)) + assert(readEpoch(reader) == Seq(4, 5, 6)) + } + + test("empty epochs") { + val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + + writer.write(Iterator()) + writer.write(Iterator(1, 2)) + writer.write(Iterator()) + writer.write(Iterator()) + writer.write(Iterator(3, 4)) + writer.write(Iterator()) + + assert(readEpoch(reader) == Seq()) + assert(readEpoch(reader) == Seq(1, 2)) + assert(readEpoch(reader) == Seq()) + assert(readEpoch(reader) == Seq()) + assert(readEpoch(reader) == Seq(3, 4)) + assert(readEpoch(reader) == Seq()) + } + + test("blocks waiting for writer") { + val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + + val readerEpoch = reader.compute(reader.partitions(0), ctx) + + val readRowThread = new Thread { + override def run(): Unit = { + assert(readerEpoch.toSeq.map(_.getInt(0)) == Seq(1)) + } + } + readRowThread.start() + + eventually(timeout(streamingTimeout)) { + assert(readRowThread.getState == Thread.State.TIMED_WAITING) + } + + // Once we write the epoch the thread should stop waiting and succeed. + writer.write(Iterator(1)) + readRowThread.join() + } + + test("multiple writer partitions") { + val numWriterPartitions = 3 + + val reader = new ContinuousShuffleReadRDD( + sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions) + val writers = (0 until 3).map { idx => + new RPCContinuousShuffleWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + } + + writers(0).write(Iterator(1, 4, 7)) + writers(1).write(Iterator(2, 5)) + writers(2).write(Iterator(3, 6)) + + writers(0).write(Iterator(4, 7, 10)) + writers(1).write(Iterator(5, 8)) + writers(2).write(Iterator(6, 9)) + + // Since there are multiple asynchronous writers, the original row sequencing is not guaranteed. + // The epochs should be deterministically preserved, however. 
+ assert(readEpoch(reader).toSet == Seq(1, 2, 3, 4, 5, 6, 7).toSet) + assert(readEpoch(reader).toSet == Seq(4, 5, 6, 7, 8, 9, 10).toSet) + } + + test("reader epoch only ends when all writer partitions write it") { + val numWriterPartitions = 3 + + val reader = new ContinuousShuffleReadRDD( + sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions) + val writers = (0 until 3).map { idx => + new RPCContinuousShuffleWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + } + + writers(1).write(Iterator()) + writers(2).write(Iterator()) + + val readerEpoch = reader.compute(reader.partitions(0), ctx) + + val readEpochMarkerThread = new Thread { + override def run(): Unit = { + assert(!readerEpoch.hasNext) + } + } + + readEpochMarkerThread.start() + eventually(timeout(streamingTimeout)) { + assert(readEpochMarkerThread.getState == Thread.State.TIMED_WAITING) + } + + writers(0).write(Iterator()) + readEpochMarkerThread.join() + } + + test("receiver stopped with row last") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverEpochMarker(0), + ReceiverRow(0, unsafeRow(111)) + ) + + ctx.markTaskCompleted(None) + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader + eventually(timeout(streamingTimeout)) { + assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get()) + } + } + + test("receiver stopped with marker last") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(0, unsafeRow(111)), + ReceiverEpochMarker(0) + ) + + ctx.markTaskCompleted(None) + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader + eventually(timeout(streamingTimeout)) { + assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get()) + } + } } From a68fae2e92dfc04f89c3de6cf11557f112d8825a Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 31 May 2018 15:17:58 -0700 Subject: [PATCH 24/33] return future --- .../shuffle/RPCContinuousShuffleWriter.scala | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala index a1bef1c115feb..424fed066d7c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala @@ -17,7 +17,11 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle -import org.apache.spark.Partitioner +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.Future + +import org.apache.spark.{Partitioner, TaskContext} + import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.catalyst.expressions.UnsafeRow @@ -43,12 +47,19 @@ class RPCContinuousShuffleWriter( s"not match endpoint count ${endpoints.length}") } - def write(epoch: Iterator[UnsafeRow]): Unit = { + def write(epoch: Iterator[UnsafeRow]): Future[Unit] = { + val futures = new ArrayBuffer[Future[Unit]]() while (epoch.hasNext) { val row = epoch.next() - 
endpoints(outputPartitioner.getPartition(row)).ask[Unit](ReceiverRow(writerId, row)) + futures += + endpoints(outputPartitioner.getPartition(row)).ask[Unit](ReceiverRow(writerId, row)) + } - endpoints.foreach(_.ask[Unit](ReceiverEpochMarker(writerId))) + futures.appendAll(endpoints.map { + e => e.ask[Unit](ReceiverEpochMarker(writerId)) + }) + + Future.reduce(futures)(_, _ => ()) } } From 98d55e4190d42ca651f7e620aa72c35c3f4b335c Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 31 May 2018 15:27:40 -0700 Subject: [PATCH 25/33] finish getting rid of old name --- .../shuffle/ContinuousShuffleReadRDD.scala | 2 +- .../shuffle/RPCContinuousShuffleReader.scala | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index e07f7bb70b396..cf6572d3de1f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -36,7 +36,7 @@ case class ContinuousShuffleReadPartition( val env = SparkEnv.get.rpcEnv val receiver = new RPCContinuousShuffleReader( queueSize, numShuffleWriters, epochIntervalMs, env) - val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID()}", receiver) + val endpoint = env.setupEndpoint(s"RPCContinuousShuffleReader-${UUID.randomUUID()}", receiver) TaskContext.get().addTaskCompletionListener { ctx => env.stop(endpoint) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala index 2bb5bbeef8eaa..9d47e062c7fc7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala @@ -26,18 +26,18 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.util.NextIterator /** - * Messages for the UnsafeRowReceiver endpoint. Either an incoming row or an epoch marker. + * Messages for the RPCContinuousShuffleReader endpoint. Either an incoming row or an epoch marker. * * Each message comes tagged with writerId, identifying which writer the message is coming * from. The receiver will only begin the next epoch once all writers have sent an epoch * marker ending the current epoch. */ -private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable { +private[shuffle] sealed trait RPCContinuousShuffleMessage extends Serializable { def writerId: Int } private[shuffle] case class ReceiverRow(writerId: Int, row: UnsafeRow) - extends UnsafeRowReceiverMessage -private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends UnsafeRowReceiverMessage + extends RPCContinuousShuffleMessage +private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends RPCContinuousShuffleMessage /** * RPC endpoint for receiving rows into a continuous processing shuffle task. 
Continuous shuffle @@ -55,7 +55,7 @@ private[shuffle] class RPCContinuousShuffleReader( // Note that this queue will be drained from the main task thread and populated in the RPC // response thread. private val queues = Array.fill(numShuffleWriters) { - new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize) + new ArrayBlockingQueue[RPCContinuousShuffleMessage](queueSize) } // Exposed for testing to determine if the endpoint gets stopped on task end. @@ -66,7 +66,7 @@ private[shuffle] class RPCContinuousShuffleReader( } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case r: UnsafeRowReceiverMessage => + case r: RPCContinuousShuffleMessage => queues(r.writerId).put(r) context.reply(()) } @@ -77,10 +77,10 @@ private[shuffle] class RPCContinuousShuffleReader( private val writerEpochMarkersReceived = Array.fill(numShuffleWriters)(false) private val executor = Executors.newFixedThreadPool(numShuffleWriters) - private val completion = new ExecutorCompletionService[UnsafeRowReceiverMessage](executor) + private val completion = new ExecutorCompletionService[RPCContinuousShuffleMessage](executor) - private def completionTask(writerId: Int) = new Callable[UnsafeRowReceiverMessage] { - override def call(): UnsafeRowReceiverMessage = queues(writerId).take() + private def completionTask(writerId: Int) = new Callable[RPCContinuousShuffleMessage] { + override def call(): RPCContinuousShuffleMessage = queues(writerId).take() } // Initialize by submitting tasks to read the first row from each writer. From e6b9118fb02bff4ef83e2d954fda041b1ddfc870 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 31 May 2018 15:29:55 -0700 Subject: [PATCH 26/33] synchronous --- .../shuffle/RPCContinuousShuffleWriter.scala | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala index 424fed066d7c4..e345faafb3aaf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala @@ -17,11 +17,7 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle -import scala.collection.mutable.ArrayBuffer -import scala.concurrent.Future - -import org.apache.spark.{Partitioner, TaskContext} - +import org.apache.spark.Partitioner import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.catalyst.expressions.UnsafeRow @@ -47,19 +43,12 @@ class RPCContinuousShuffleWriter( s"not match endpoint count ${endpoints.length}") } - def write(epoch: Iterator[UnsafeRow]): Future[Unit] = { - val futures = new ArrayBuffer[Future[Unit]]() + def write(epoch: Iterator[UnsafeRow]): Unit = { while (epoch.hasNext) { val row = epoch.next() - futures += - endpoints(outputPartitioner.getPartition(row)).ask[Unit](ReceiverRow(writerId, row)) - + endpoints(outputPartitioner.getPartition(row)).askSync[Unit](ReceiverRow(writerId, row)) } - futures.appendAll(endpoints.map { - e => e.ask[Unit](ReceiverEpochMarker(writerId)) - }) - - Future.reduce(futures)(_, _ => ()) + endpoints.foreach(_.askSync[Unit](ReceiverEpochMarker(writerId))) } } From 629455b53c610f546cda27ca09de94b2c35e5f9c Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 31 May 2018 
15:32:28 -0700 Subject: [PATCH 27/33] finish rename --- .../continuous/shuffle/ContinuousShuffleSuite.scala | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala index 4d5813a18a3f3..3765df30b72a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala @@ -17,12 +17,9 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle -import scala.collection.mutable - import org.apache.spark.{HashPartitioner, Partition, TaskContext, TaskContextImpl} import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} -import org.apache.spark.sql.execution.streaming.continuous.shuffle._ import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.types.{DataType, IntegerType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -56,7 +53,7 @@ class ContinuousShuffleSuite extends StreamTest { new GenericInternalRow(Array(UTF8String.fromString(value): Any))) } - private def send(endpoint: RpcEndpointRef, messages: UnsafeRowReceiverMessage*) = { + private def send(endpoint: RpcEndpointRef, messages: RPCContinuousShuffleMessage*) = { messages.foreach(endpoint.askSync[Unit](_)) } From cb6d42b0de2f911e16e47a86e247b2610a94db6d Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 12 Jun 2018 21:02:57 -0700 Subject: [PATCH 28/33] add timeouts --- .../continuous/shuffle/ContinuousShuffleSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala index 3765df30b72a0..a8e3611b585cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala @@ -227,7 +227,7 @@ class ContinuousShuffleSuite extends StreamTest { } // Join to pick up assertion failures. - readEpochMarkerThread.join() + readEpochMarkerThread.join(streamingTimeout.toMillis) } test("reader - writer epochs non aligned") { @@ -327,7 +327,7 @@ class ContinuousShuffleSuite extends StreamTest { // Once we write the epoch the thread should stop waiting and succeed. 
writer.write(Iterator(1)) - readRowThread.join() + readRowThread.join(streamingTimeout.toMillis) } test("multiple writer partitions") { @@ -379,7 +379,7 @@ class ContinuousShuffleSuite extends StreamTest { } writers(0).write(Iterator()) - readEpochMarkerThread.join() + readEpochMarkerThread.join(streamingTimeout.toMillis) } test("receiver stopped with row last") { From 59d6ff79de54dc8d99390a056a0133f35667e817 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 12 Jun 2018 21:04:34 -0700 Subject: [PATCH 29/33] unalign --- .../continuous/shuffle/RPCContinuousShuffleWriter.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala index e345faafb3aaf..7d26bce15e938 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala @@ -24,10 +24,10 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow /** * A [[ContinuousShuffleWriter]] sending data to [[RPCContinuousShuffleReader]] instances. * - * @param writerId The partition ID of this writer. + * @param writerId The partition ID of this writer. * @param outputPartitioner The partitioner on the reader side of the shuffle. - * @param endpoints The [[RPCContinuousShuffleReader]] endpoints to write to. Indexed by - * partition ID within outputPartitioner. + * @param endpoints The [[RPCContinuousShuffleReader]] endpoints to write to. Indexed by + * partition ID within outputPartitioner. */ class RPCContinuousShuffleWriter( writerId: Int, From f90388c36e100fd1c6a9cf2ac96c5247c0a8672f Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 12 Jun 2018 22:44:56 -0700 Subject: [PATCH 30/33] add note --- .../continuous/shuffle/RPCContinuousShuffleReader.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala index 9d47e062c7fc7..834e84675c7d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala @@ -67,6 +67,8 @@ private[shuffle] class RPCContinuousShuffleReader( override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case r: RPCContinuousShuffleMessage => + // Note that this will block a thread the shared RPC handler pool! + // The TCP based shuffle handler (SPARK-24541) will avoid this problem. 
queues(r.writerId).put(r) context.reply(()) } From 4bbdeae3b9d7d1955593dab360422dd59800d54f Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 12 Jun 2018 22:55:01 -0700 Subject: [PATCH 31/33] parallel --- .../continuous/shuffle/RPCContinuousShuffleWriter.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala index 7d26bce15e938..be4b3d60595d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala @@ -17,9 +17,14 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle +import scala.concurrent.Future +import scala.concurrent.duration.Duration + import org.apache.spark.Partitioner + import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.util.ThreadUtils /** * A [[ContinuousShuffleWriter]] sending data to [[RPCContinuousShuffleReader]] instances. @@ -49,6 +54,8 @@ class RPCContinuousShuffleWriter( endpoints(outputPartitioner.getPartition(row)).askSync[Unit](ReceiverRow(writerId, row)) } - endpoints.foreach(_.askSync[Unit](ReceiverEpochMarker(writerId))) + val futures = endpoints.map(_.ask[Unit](ReceiverEpochMarker(writerId))) + implicit val ec = ThreadUtils.sameThread + ThreadUtils.awaitResult(Future.sequence(futures), Duration.Inf) } } From e57531d42d0fc29622a3c30e374e21560336def2 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 12 Jun 2018 23:27:51 -0700 Subject: [PATCH 32/33] fix compile --- .../continuous/shuffle/RPCContinuousShuffleWriter.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala index be4b3d60595d0..fde1630e8877a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala @@ -21,7 +21,6 @@ import scala.concurrent.Future import scala.concurrent.duration.Duration import org.apache.spark.Partitioner - import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.util.ThreadUtils From cff37c45f084d50a0844fbe8481565f6a9985302 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Wed, 13 Jun 2018 08:39:48 -0700 Subject: [PATCH 33/33] fix compile --- .../continuous/shuffle/RPCContinuousShuffleWriter.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala index fde1630e8877a..1c6f3ddb395e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala @@ -53,7 +53,7 @@ class RPCContinuousShuffleWriter( endpoints(outputPartitioner.getPartition(row)).askSync[Unit](ReceiverRow(writerId, row)) } - val futures = endpoints.map(_.ask[Unit](ReceiverEpochMarker(writerId))) + val futures = endpoints.map(_.ask[Unit](ReceiverEpochMarker(writerId))).toSeq implicit val ec = ThreadUtils.sameThread ThreadUtils.awaitResult(Future.sequence(futures), Duration.Inf) }
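
As an end-of-series reading aid, here is a minimal round-trip sketch in the spirit of ContinuousShuffleSuite, showing the contract the patches converge on: RPCContinuousShuffleWriter.write() pushes one epoch of rows plus an epoch marker to the reader endpoints, and each ContinuousShuffleReadRDD.compute() call drains exactly one epoch. The object name ContinuousShuffleRoundTripSketch and the roundTripOneEpoch helper are illustrative additions, not part of the patches; the sketch assumes it is compiled alongside these classes and driven from a thread prepared the way the suite's beforeEach prepares it (active SparkContext, TaskContext installed on the calling thread).

package org.apache.spark.sql.execution.streaming.continuous.shuffle

import org.apache.spark.{HashPartitioner, SparkContext, TaskContext}
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types.{DataType, IntegerType}

object ContinuousShuffleRoundTripSketch {
  // Same projection helper the suite uses: wrap an Int in a single-column UnsafeRow.
  private def unsafeRow(value: Int): UnsafeRow =
    UnsafeProjection.create(Array(IntegerType: DataType))(
      new GenericInternalRow(Array(value: Any)))

  // Assumes the suite's environment: an active SparkContext and a TaskContext already
  // installed on the calling thread (TaskContext.setTaskContext(TaskContext.empty())),
  // because the read partition registers a task-completion listener when its RPC
  // endpoint is lazily created.
  def roundTripOneEpoch(sc: SparkContext): Seq[Int] = {
    val ctx = TaskContext.get()

    val reader = new ContinuousShuffleReadRDD(sc, numPartitions = 1)
    val endpoint = reader.partitions(0)
      .asInstanceOf[ContinuousShuffleReadPartition].endpoint

    // Writer 0 hash-partitions every row onto the single reader partition.
    val writer = new RPCContinuousShuffleWriter(0, new HashPartitioner(1), Array(endpoint))

    // write() sends each row with askSync, then sends an epoch marker to every endpoint
    // and waits for all of them to be acknowledged before returning.
    writer.write(Iterator(1, 2, 3).map(unsafeRow))

    // Each compute() call drains exactly one epoch; this returns Seq(1, 2, 3).
    reader.compute(reader.partitions(0), ctx).map(_.getInt(0)).toSeq
  }
}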