From 1d6b71898e2a640e3c0809695d2b83f3f84eaa38 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 15 May 2018 11:07:54 -0700 Subject: [PATCH 01/28] continuous shuffle read RDD --- .../shuffle/ContinuousShuffleReadRDD.scala | 64 +++++++++ .../shuffle/UnsafeRowReceiver.scala | 56 ++++++++ .../shuffle/ContinuousShuffleReadSuite.scala | 122 ++++++++++++++++++ 3 files changed, 242 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala new file mode 100644 index 0000000000000..110b51f8922d8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import java.util.UUID + +import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.util.NextIterator + +case class ContinuousShuffleReadPartition(index: Int) extends Partition { + // Initialized only on the executor, and only once even as we call compute() multiple times. 
+ lazy val (receiver, endpoint) = { + val env = SparkEnv.get.rpcEnv + val receiver = new UnsafeRowReceiver(env) + val endpoint = env.setupEndpoint(UUID.randomUUID().toString, receiver) + TaskContext.get().addTaskCompletionListener { ctx => + env.stop(endpoint) + } + (receiver, endpoint) + } +} + +/** + * RDD at the bottom of each continuous processing shuffle task, reading from the + */ +class ContinuousShuffleReadRDD(sc: SparkContext, numPartitions: Int) + extends RDD[UnsafeRow](sc, Nil) { + + override protected def getPartitions: Array[Partition] = { + (0 until numPartitions).map(ContinuousShuffleReadPartition).toArray + } + + override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { + val receiver = split.asInstanceOf[ContinuousShuffleReadPartition].receiver + + new NextIterator[UnsafeRow] { + override def getNext(): UnsafeRow = receiver.poll() match { + case ReceiverRow(r) => r + case ReceiverEpochMarker() => + finished = true + null + } + + override def close(): Unit = {} + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala new file mode 100644 index 0000000000000..637305e54519f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue} +import java.util.concurrent.atomic.AtomicBoolean + +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow + +/** + * Messages for the UnsafeRowReceiver endpoint. Either an incoming row or an epoch marker. + */ +private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable +private[shuffle] case class ReceiverRow(row: UnsafeRow) extends UnsafeRowReceiverMessage +private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessage + +/** + * RPC endpoint for receiving rows into a continuous processing shuffle task. 
+ */ +private[shuffle] class UnsafeRowReceiver(val rpcEnv: RpcEnv) + extends ThreadSafeRpcEndpoint with Logging { + private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](1024) + var stopped = new AtomicBoolean(false) + + override def onStop(): Unit = { + stopped.set(true) + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case r: UnsafeRowReceiverMessage => + queue.put(r) + context.reply(()) + } + + /** + * Polls until a new row is available. + */ + def poll(): UnsafeRowReceiverMessage = queue.poll() +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala new file mode 100644 index 0000000000000..ee0ff8ada23f4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import org.apache.spark.{TaskContext, TaskContextImpl} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types.{DataType, IntegerType} + +class ContinuousShuffleReadSuite extends StreamTest { + + private def unsafeRow(value: Int) = { + UnsafeProjection.create(Array(IntegerType : DataType))( + new GenericInternalRow(Array(value: Any))) + } + + var ctx: TaskContextImpl = _ + + override def beforeEach(): Unit = { + super.beforeEach() + ctx = TaskContext.empty() + TaskContext.setTaskContext(ctx) + } + + override def afterEach(): Unit = { + ctx.markTaskCompleted(None) + ctx = null + super.afterEach() + } + + test("receiver stopped with row last") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + endpoint.askSync[Unit](ReceiverEpochMarker()) + endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) + + ctx.markTaskCompleted(None) + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].receiver + eventually(timeout(streamingTimeout)) { + assert(receiver.stopped.get()) + } + } + + test("receiver stopped with marker last") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) + endpoint.askSync[Unit](ReceiverEpochMarker()) + + ctx.markTaskCompleted(None) + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].receiver + eventually(timeout(streamingTimeout)) { + assert(receiver.stopped.get()) + } + } + + test("one epoch") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) + endpoint.askSync[Unit](ReceiverRow(unsafeRow(222))) + endpoint.askSync[Unit](ReceiverRow(unsafeRow(333))) + endpoint.askSync[Unit](ReceiverEpochMarker()) + + val iter = rdd.compute(rdd.partitions(0), ctx) + assert(iter.next().getInt(0) == 111) + assert(iter.next().getInt(0) == 222) + assert(iter.next().getInt(0) == 333) + assert(!iter.hasNext) + } + + test("multiple epochs") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) + endpoint.askSync[Unit](ReceiverEpochMarker()) + endpoint.askSync[Unit](ReceiverRow(unsafeRow(222))) + endpoint.askSync[Unit](ReceiverRow(unsafeRow(333))) + endpoint.askSync[Unit](ReceiverEpochMarker()) + + val firstEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(firstEpoch.next().getInt(0) == 111) + assert(!firstEpoch.hasNext) + + val secondEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(secondEpoch.next().getInt(0) == 222) + assert(secondEpoch.next().getInt(0) == 333) + assert(!secondEpoch.hasNext) + } + + test("empty epochs") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + endpoint.askSync[Unit](ReceiverEpochMarker()) + endpoint.askSync[Unit](ReceiverEpochMarker()) + endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) + 
endpoint.askSync[Unit](ReceiverEpochMarker()) + endpoint.askSync[Unit](ReceiverEpochMarker()) + + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + val thirdEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(thirdEpoch.next().getInt(0) == 111) + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + } +} From b5d100875932bdfcb645c8f6b2cdb7b815d84c80 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Wed, 16 May 2018 20:11:11 -0700 Subject: [PATCH 02/28] docs --- .../continuous/shuffle/ContinuousShuffleReadRDD.scala | 4 +++- .../streaming/continuous/shuffle/UnsafeRowReceiver.scala | 8 +++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index 110b51f8922d8..01a7172357883 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -38,7 +38,9 @@ case class ContinuousShuffleReadPartition(index: Int) extends Partition { } /** - * RDD at the bottom of each continuous processing shuffle task, reading from the + * RDD at the map side of each continuous processing shuffle task. Upstream tasks send their + * shuffle output to the wrapped receivers in partitions of this RDD; each of the RDD's tasks + * poll from their receiver until an epoch marker is sent. */ class ContinuousShuffleReadRDD(sc: SparkContext, numPartitions: Int) extends RDD[UnsafeRow](sc, Nil) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala index 637305e54519f..addfe357f63f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -32,10 +32,16 @@ private[shuffle] case class ReceiverRow(row: UnsafeRow) extends UnsafeRowReceive private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessage /** - * RPC endpoint for receiving rows into a continuous processing shuffle task. + * RPC endpoint for receiving rows into a continuous processing shuffle task. Continuous shuffle + * writers will send rows here, with continuous shuffle readers polling for new rows as needed. + * + * TODO: Support multiple source tasks. We need to output a single epoch marker once all + * source tasks have sent one. */ private[shuffle] class UnsafeRowReceiver(val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { + // Note that this queue will be drained from the main task thread and populated in the RPC + // response thread. 
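To make the writer/reader contract in the scaladoc above concrete, here is a minimal sketch (not part of the patch) of how a writer-side task holding a reference to this endpoint might push one epoch. It mirrors what the test suite does with askSync and uses the message shapes as they stand at this point in the series; how the writer obtains the endpoint reference and its rows is assumed.

// Sketch only: send one epoch's rows, then close the epoch with a marker.
def writeEpoch(endpoint: org.apache.spark.rpc.RpcEndpointRef, rows: Iterator[UnsafeRow]): Unit = {
  rows.foreach(row => endpoint.askSync[Unit](ReceiverRow(row)))
  endpoint.askSync[Unit](ReceiverEpochMarker())
}
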
private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](1024) var stopped = new AtomicBoolean(false) From 46456dc75a6aec9659b18523c421999debd060eb Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Wed, 16 May 2018 20:22:49 -0700 Subject: [PATCH 03/28] fix ctor --- .../continuous/shuffle/ContinuousShuffleReadRDD.scala | 3 ++- .../streaming/continuous/shuffle/UnsafeRowReceiver.scala | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index 01a7172357883..f69977749de08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -22,13 +22,14 @@ import java.util.UUID import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.NextIterator case class ContinuousShuffleReadPartition(index: Int) extends Partition { // Initialized only on the executor, and only once even as we call compute() multiple times. lazy val (receiver, endpoint) = { val env = SparkEnv.get.rpcEnv - val receiver = new UnsafeRowReceiver(env) + val receiver = new UnsafeRowReceiver(SQLConf.get.continuousStreamingExecutorQueueSize, env) val endpoint = env.setupEndpoint(UUID.randomUUID().toString, receiver) TaskContext.get().addTaskCompletionListener { ctx => env.stop(endpoint) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala index addfe357f63f1..f50b89f276eaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -38,11 +38,13 @@ private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessa * TODO: Support multiple source tasks. We need to output a single epoch marker once all * source tasks have sent one. */ -private[shuffle] class UnsafeRowReceiver(val rpcEnv: RpcEnv) +private[shuffle] class UnsafeRowReceiver( + queueSize: Int, + override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { // Note that this queue will be drained from the main task thread and populated in the RPC // response thread. 
- private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](1024) + private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize) var stopped = new AtomicBoolean(false) override def onStop(): Unit = { From 2ea8a6f94216e8b184e5780ec3e6ffb2838de382 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Wed, 16 May 2018 20:43:10 -0700 Subject: [PATCH 04/28] multiple partition test --- .../shuffle/ContinuousShuffleReadSuite.scala | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index ee0ff8ada23f4..ad841626c8b10 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -119,4 +119,18 @@ class ContinuousShuffleReadSuite extends StreamTest { assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) } + + test("multiple partitions") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 5) + for (p <- rdd.partitions) { + val part = p.asInstanceOf[ContinuousShuffleReadPartition] + // Send index to ensure data doesn't somehow cross over between partitions. + part.endpoint.askSync[Unit](ReceiverRow(unsafeRow(part.index))) + part.endpoint.askSync[Unit](ReceiverEpochMarker()) + + val iter = rdd.compute(part, ctx) + assert(iter.next().getInt(0) == part.index) + assert(!iter.hasNext) + } + } } From 955ac79eb05dc389e632d1aaa6c59396835c6ed5 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 17 May 2018 06:33:51 -0700 Subject: [PATCH 05/28] unset task context after test --- .../continuous/shuffle/ContinuousShuffleReadSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index ad841626c8b10..faec753768895 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -39,6 +39,7 @@ class ContinuousShuffleReadSuite extends StreamTest { override def afterEach(): Unit = { ctx.markTaskCompleted(None) + TaskContext.unset() ctx = null super.afterEach() } From 8cefb724512b51f2aa1fdd81fa8a2d4560e60ce3 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 17 May 2018 17:00:05 -0700 Subject: [PATCH 06/28] conf from RDD --- .../shuffle/ContinuousShuffleReadRDD.scala | 31 ++++++++++++------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index f69977749de08..a85ba28a8d528 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -20,22 +20,19 @@ package 
org.apache.spark.sql.execution.streaming.continuous.shuffle import java.util.UUID import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext} + import org.apache.spark.rdd.RDD +import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.NextIterator case class ContinuousShuffleReadPartition(index: Int) extends Partition { - // Initialized only on the executor, and only once even as we call compute() multiple times. - lazy val (receiver, endpoint) = { - val env = SparkEnv.get.rpcEnv - val receiver = new UnsafeRowReceiver(SQLConf.get.continuousStreamingExecutorQueueSize, env) - val endpoint = env.setupEndpoint(UUID.randomUUID().toString, receiver) - TaskContext.get().addTaskCompletionListener { ctx => - env.stop(endpoint) - } - (receiver, endpoint) - } + // Semantically lazy vals - initialized only on the executor, and only once even as we call + // compute() multiple times. We need to initialize them inside compute() so we have access to the + // RDD's conf. + var receiver: UnsafeRowReceiver = _ + var endpoint: RpcEndpointRef = _ } /** @@ -46,14 +43,26 @@ case class ContinuousShuffleReadPartition(index: Int) extends Partition { class ContinuousShuffleReadRDD(sc: SparkContext, numPartitions: Int) extends RDD[UnsafeRow](sc, Nil) { + private val queueSize = sc.conf.get(SQLConf.CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE) + override protected def getPartitions: Array[Partition] = { (0 until numPartitions).map(ContinuousShuffleReadPartition).toArray } override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { - val receiver = split.asInstanceOf[ContinuousShuffleReadPartition].receiver + val part = split.asInstanceOf[ContinuousShuffleReadPartition] + if (part.receiver == null) { + val env = SparkEnv.get.rpcEnv + part.receiver = new UnsafeRowReceiver(queueSize, env) + part.endpoint = env.setupEndpoint(UUID.randomUUID().toString, part.receiver) + TaskContext.get().addTaskCompletionListener { _ => + env.stop(part.endpoint) + } + } new NextIterator[UnsafeRow] { + private val receiver = part.receiver + override def getNext(): UnsafeRow = receiver.poll() match { case ReceiverRow(r) => r case ReceiverEpochMarker() => From f91bfe7e3fc174202d7d5c7cde5a8fb7ce86bfd3 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 17 May 2018 17:00:44 -0700 Subject: [PATCH 07/28] endpoint name --- .../streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index a85ba28a8d528..d4128c5e274fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -54,7 +54,7 @@ class ContinuousShuffleReadRDD(sc: SparkContext, numPartitions: Int) if (part.receiver == null) { val env = SparkEnv.get.rpcEnv part.receiver = new UnsafeRowReceiver(queueSize, env) - part.endpoint = env.setupEndpoint(UUID.randomUUID().toString, part.receiver) + part.endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID()}", part.receiver) TaskContext.get().addTaskCompletionListener { _ => 
env.stop(part.endpoint) } From 259029298fc42a65e8ebb4d2effe49b7fafa96f1 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 17 May 2018 17:02:08 -0700 Subject: [PATCH 08/28] testing bool --- .../streaming/continuous/shuffle/UnsafeRowReceiver.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala index f50b89f276eaa..d6a6d978d2653 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -45,7 +45,9 @@ private[shuffle] class UnsafeRowReceiver( // Note that this queue will be drained from the main task thread and populated in the RPC // response thread. private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize) - var stopped = new AtomicBoolean(false) + + // Exposed for testing to determine if the endpoint gets stopped on task end. + private[shuffle] val stopped = new AtomicBoolean(false) override def onStop(): Unit = { stopped.set(true) From 859e6e4dd4dd90ffd70fc9cbd243c94090d72506 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 17 May 2018 17:22:10 -0700 Subject: [PATCH 09/28] tests --- .../shuffle/ContinuousShuffleReadRDD.scala | 33 ++++++++----------- .../shuffle/ContinuousShuffleReadSuite.scala | 23 ++++++------- 2 files changed, 25 insertions(+), 31 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index d4128c5e274fe..7f5aaecabf5ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -20,19 +20,22 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle import java.util.UUID import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext} - import org.apache.spark.rdd.RDD -import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.NextIterator -case class ContinuousShuffleReadPartition(index: Int) extends Partition { - // Semantically lazy vals - initialized only on the executor, and only once even as we call - // compute() multiple times. We need to initialize them inside compute() so we have access to the - // RDD's conf. - var receiver: UnsafeRowReceiver = _ - var endpoint: RpcEndpointRef = _ +case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Partition { + // Initialized only on the executor, and only once even as we call compute() multiple times. 
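Between patch 06 and patch 09 the configuration plumbing settles into a pattern: the queue size is resolved once in the RDD constructor on the driver, carried to executors inside the serializable Partition, and the receiver itself is still created lazily on first compute() (the lazy val just below). A minimal sketch of that shape with illustrative names; SketchPartition is not in the patch:

// Sketch only: driver-side config value stored in the Partition; executor-only
// state built once through a lazy val when the partition is first used.
case class SketchPartition(index: Int, queueSize: Int) extends Partition {
  lazy val queue = new java.util.concurrent.ArrayBlockingQueue[UnsafeRow](queueSize)
}
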
+ lazy val (receiver, endpoint) = { + val env = SparkEnv.get.rpcEnv + val receiver = new UnsafeRowReceiver(queueSize, env) + val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID().toString}", receiver) + TaskContext.get().addTaskCompletionListener { ctx => + env.stop(endpoint) + } + (receiver, endpoint) + } } /** @@ -46,23 +49,13 @@ class ContinuousShuffleReadRDD(sc: SparkContext, numPartitions: Int) private val queueSize = sc.conf.get(SQLConf.CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE) override protected def getPartitions: Array[Partition] = { - (0 until numPartitions).map(ContinuousShuffleReadPartition).toArray + (0 until numPartitions).map(ContinuousShuffleReadPartition(_, queueSize)).toArray } override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { - val part = split.asInstanceOf[ContinuousShuffleReadPartition] - if (part.receiver == null) { - val env = SparkEnv.get.rpcEnv - part.receiver = new UnsafeRowReceiver(queueSize, env) - part.endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID()}", part.receiver) - TaskContext.get().addTaskCompletionListener { _ => - env.stop(part.endpoint) - } - } + val receiver = split.asInstanceOf[ContinuousShuffleReadPartition].receiver new NextIterator[UnsafeRow] { - private val receiver = part.receiver - override def getNext(): UnsafeRow = receiver.poll() match { case ReceiverRow(r) => r case ReceiverEpochMarker() => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index faec753768895..19558e1c49319 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -79,10 +79,7 @@ class ContinuousShuffleReadSuite extends StreamTest { endpoint.askSync[Unit](ReceiverEpochMarker()) val iter = rdd.compute(rdd.partitions(0), ctx) - assert(iter.next().getInt(0) == 111) - assert(iter.next().getInt(0) == 222) - assert(iter.next().getInt(0) == 333) - assert(!iter.hasNext) + assert(iter.toSeq.map(_.getInt(0)) == Seq(111, 222, 333)) } test("multiple epochs") { @@ -95,13 +92,10 @@ class ContinuousShuffleReadSuite extends StreamTest { endpoint.askSync[Unit](ReceiverEpochMarker()) val firstEpoch = rdd.compute(rdd.partitions(0), ctx) - assert(firstEpoch.next().getInt(0) == 111) - assert(!firstEpoch.hasNext) + assert(firstEpoch.toSeq.map(_.getInt(0)) == Seq(111)) val secondEpoch = rdd.compute(rdd.partitions(0), ctx) - assert(secondEpoch.next().getInt(0) == 222) - assert(secondEpoch.next().getInt(0) == 333) - assert(!secondEpoch.hasNext) + assert(secondEpoch.toSeq.map(_.getInt(0)) == Seq(222, 333)) } test("empty epochs") { @@ -112,23 +106,30 @@ class ContinuousShuffleReadSuite extends StreamTest { endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) endpoint.askSync[Unit](ReceiverEpochMarker()) endpoint.askSync[Unit](ReceiverEpochMarker()) + endpoint.askSync[Unit](ReceiverEpochMarker()) assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + val thirdEpoch = rdd.compute(rdd.partitions(0), ctx) - assert(thirdEpoch.next().getInt(0) == 111) + assert(thirdEpoch.toSeq.map(_.getInt(0)) == Seq(111)) + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) } test("multiple 
partitions") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 5) + // Send all data before processing to ensure there's no crossover. for (p <- rdd.partitions) { val part = p.asInstanceOf[ContinuousShuffleReadPartition] - // Send index to ensure data doesn't somehow cross over between partitions. + // Send index for identification. part.endpoint.askSync[Unit](ReceiverRow(unsafeRow(part.index))) part.endpoint.askSync[Unit](ReceiverEpochMarker()) + } + for (p <- rdd.partitions) { + val part = p.asInstanceOf[ContinuousShuffleReadPartition] val iter = rdd.compute(part, ctx) assert(iter.next().getInt(0) == part.index) assert(!iter.hasNext) From b23b7bb17abe3cbc873a3144c56d08c88bc0c963 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 17 May 2018 17:40:55 -0700 Subject: [PATCH 10/28] take instead of poll --- .../shuffle/ContinuousShuffleReadRDD.scala | 2 +- .../shuffle/UnsafeRowReceiver.scala | 4 ++-- .../shuffle/ContinuousShuffleReadSuite.scala | 21 +++++++++++++++++++ 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index 7f5aaecabf5ad..16d7f29ea31cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -56,7 +56,7 @@ class ContinuousShuffleReadRDD(sc: SparkContext, numPartitions: Int) val receiver = split.asInstanceOf[ContinuousShuffleReadPartition].receiver new NextIterator[UnsafeRow] { - override def getNext(): UnsafeRow = receiver.poll() match { + override def getNext(): UnsafeRow = receiver.take() match { case ReceiverRow(r) => r case ReceiverEpochMarker() => finished = true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala index d6a6d978d2653..37cee999b1bed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -60,7 +60,7 @@ private[shuffle] class UnsafeRowReceiver( } /** - * Polls until a new row is available. + * Take the next row, blocking until it's ready. */ - def poll(): UnsafeRowReceiverMessage = queue.poll() + def take(): UnsafeRowReceiverMessage = queue.take() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index 19558e1c49319..6f201186d572c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -29,6 +29,9 @@ class ContinuousShuffleReadSuite extends StreamTest { new GenericInternalRow(Array(value: Any))) } + // In this unit test, we emulate that we're in the task thread where + // ContinuousShuffleReadRDD.compute() will be evaluated. This requires a task context + // thread local to be set. 
var ctx: TaskContextImpl = _ override def beforeEach(): Unit = { @@ -135,4 +138,22 @@ class ContinuousShuffleReadSuite extends StreamTest { assert(!iter.hasNext) } } + + test("blocks waiting for new rows") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + + val readRow = new Thread { + override def run(): Unit = { + // set the non-inheritable thread local + TaskContext.setTaskContext(ctx) + val epoch = rdd.compute(rdd.partitions(0), ctx) + epoch.next().getInt(0) + } + } + + readRow.start() + eventually(timeout(streamingTimeout)) { + assert(readRow.getState == Thread.State.WAITING) + } + } } From 97f7e8ff865e6054d0d70914ce9bb51880b161f6 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 17 May 2018 17:58:44 -0700 Subject: [PATCH 11/28] add interface --- .../shuffle/ContinuousShuffleReadRDD.scala | 15 ++------- .../shuffle/ContinuousShuffleReader.scala | 31 +++++++++++++++++++ .../shuffle/UnsafeRowReceiver.scala | 19 +++++++++--- .../shuffle/ContinuousShuffleReadSuite.scala | 8 ++--- 4 files changed, 51 insertions(+), 22 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index 16d7f29ea31cc..74419a5b35fdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -27,7 +27,7 @@ import org.apache.spark.util.NextIterator case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Partition { // Initialized only on the executor, and only once even as we call compute() multiple times. - lazy val (receiver, endpoint) = { + lazy val (reader: ContinuousShuffleReader, endpoint) = { val env = SparkEnv.get.rpcEnv val receiver = new UnsafeRowReceiver(queueSize, env) val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID().toString}", receiver) @@ -53,17 +53,6 @@ class ContinuousShuffleReadRDD(sc: SparkContext, numPartitions: Int) } override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { - val receiver = split.asInstanceOf[ContinuousShuffleReadPartition].receiver - - new NextIterator[UnsafeRow] { - override def getNext(): UnsafeRow = receiver.take() match { - case ReceiverRow(r) => r - case ReceiverEpochMarker() => - finished = true - null - } - - override def close(): Unit = {} - } + split.asInstanceOf[ContinuousShuffleReadPartition].reader.read() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala new file mode 100644 index 0000000000000..08091bbc12893 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow + +/** + * Trait for reading from a continuous processing shuffle. + */ +trait ContinuousShuffleReader { + /** + * Returns an iterator over the incoming rows in the current epoch. Note that this iterator can + * block waiting for new rows to arrive. + */ + def read(): Iterator[UnsafeRow] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala index 37cee999b1bed..b8adbb743c6c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -23,6 +23,7 @@ import java.util.concurrent.atomic.AtomicBoolean import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.util.NextIterator /** * Messages for the UnsafeRowReceiver endpoint. Either an incoming row or an epoch marker. @@ -41,7 +42,7 @@ private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessa private[shuffle] class UnsafeRowReceiver( queueSize: Int, override val rpcEnv: RpcEnv) - extends ThreadSafeRpcEndpoint with Logging { + extends ThreadSafeRpcEndpoint with ContinuousShuffleReader with Logging { // Note that this queue will be drained from the main task thread and populated in the RPC // response thread. private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize) @@ -59,8 +60,16 @@ private[shuffle] class UnsafeRowReceiver( context.reply(()) } - /** - * Take the next row, blocking until it's ready. 
- */ - def take(): UnsafeRowReceiverMessage = queue.take() + override def read(): Iterator[UnsafeRow] = { + new NextIterator[UnsafeRow] { + override def getNext(): UnsafeRow = queue.take() match { + case ReceiverRow(r) => r + case ReceiverEpochMarker() => + finished = true + null + } + + override def close(): Unit = {} + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index 6f201186d572c..718fb0740bbe7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -54,9 +54,9 @@ class ContinuousShuffleReadSuite extends StreamTest { endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) ctx.markTaskCompleted(None) - val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].receiver + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader eventually(timeout(streamingTimeout)) { - assert(receiver.stopped.get()) + assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get()) } } @@ -67,9 +67,9 @@ class ContinuousShuffleReadSuite extends StreamTest { endpoint.askSync[Unit](ReceiverEpochMarker()) ctx.markTaskCompleted(None) - val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].receiver + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader eventually(timeout(streamingTimeout)) { - assert(receiver.stopped.get()) + assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get()) } } From de21b1c25a333d44c0521fe151b468e51f0bdc47 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 17 May 2018 18:02:37 -0700 Subject: [PATCH 12/28] clarify comment --- .../continuous/shuffle/ContinuousShuffleReader.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala index 08091bbc12893..42631c90ebc55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala @@ -24,8 +24,9 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow */ trait ContinuousShuffleReader { /** - * Returns an iterator over the incoming rows in the current epoch. Note that this iterator can - * block waiting for new rows to arrive. + * Returns an iterator over the incoming rows in an epoch. Implementations should block waiting + * for new rows to arrive, and end the iterator once they've received epoch markers from all + * shuffle writers. 
*/ def read(): Iterator[UnsafeRow] } From 7dcf51a13e92a0bb2998e2a12e67d351e1c1a4fc Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 18 May 2018 15:39:28 -0700 Subject: [PATCH 13/28] multiple --- .../shuffle/ContinuousShuffleReadRDD.scala | 11 +- .../shuffle/UnsafeRowReceiver.scala | 60 +++++++--- .../shuffle/ContinuousShuffleReadSuite.scala | 105 ++++++++++++++---- 3 files changed, 137 insertions(+), 39 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index 74419a5b35fdb..05fdeeb4450a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -25,11 +25,12 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.NextIterator -case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Partition { +case class ContinuousShuffleReadPartition(index: Int, queueSize: Int, numShuffleWriters: Int) + extends Partition { // Initialized only on the executor, and only once even as we call compute() multiple times. lazy val (reader: ContinuousShuffleReader, endpoint) = { val env = SparkEnv.get.rpcEnv - val receiver = new UnsafeRowReceiver(queueSize, env) + val receiver = new UnsafeRowReceiver(queueSize, numShuffleWriters, env) val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID().toString}", receiver) TaskContext.get().addTaskCompletionListener { ctx => env.stop(endpoint) @@ -43,13 +44,15 @@ case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Pa * shuffle output to the wrapped receivers in partitions of this RDD; each of the RDD's tasks * poll from their receiver until an epoch marker is sent. 
*/ -class ContinuousShuffleReadRDD(sc: SparkContext, numPartitions: Int) +class ContinuousShuffleReadRDD(sc: SparkContext, numPartitions: Int, numShuffleWriters: Int = 1) extends RDD[UnsafeRow](sc, Nil) { private val queueSize = sc.conf.get(SQLConf.CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE) override protected def getPartitions: Array[Partition] = { - (0 until numPartitions).map(ContinuousShuffleReadPartition(_, queueSize)).toArray + (0 until numPartitions).map { partIndex => + ContinuousShuffleReadPartition(partIndex, queueSize, numShuffleWriters) + }.toArray } override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala index b8adbb743c6c2..c7f3f0db41440 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -17,8 +17,10 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle -import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue} -import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent._ +import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} + +import scala.concurrent.Future import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} @@ -28,9 +30,12 @@ import org.apache.spark.util.NextIterator /** * Messages for the UnsafeRowReceiver endpoint. Either an incoming row or an epoch marker. */ -private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable -private[shuffle] case class ReceiverRow(row: UnsafeRow) extends UnsafeRowReceiverMessage -private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessage +private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable { + def writerId: Int +} +private[shuffle] case class ReceiverRow(writerId: Int, row: UnsafeRow) + extends UnsafeRowReceiverMessage +private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends UnsafeRowReceiverMessage /** * RPC endpoint for receiving rows into a continuous processing shuffle task. Continuous shuffle @@ -41,11 +46,14 @@ private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessa */ private[shuffle] class UnsafeRowReceiver( queueSize: Int, + numShuffleWriters: Int, override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with ContinuousShuffleReader with Logging { // Note that this queue will be drained from the main task thread and populated in the RPC // response thread. - private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize) + private val queues = Array.fill(numShuffleWriters) { + new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize) + } // Exposed for testing to determine if the endpoint gets stopped on task end. 
private[shuffle] val stopped = new AtomicBoolean(false) @@ -56,20 +64,46 @@ private[shuffle] class UnsafeRowReceiver( override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case r: UnsafeRowReceiverMessage => - queue.put(r) + queues(r.writerId).put(r) context.reply(()) } override def read(): Iterator[UnsafeRow] = { new NextIterator[UnsafeRow] { - override def getNext(): UnsafeRow = queue.take() match { - case ReceiverRow(r) => r - case ReceiverEpochMarker() => - finished = true - null + private val numWriterEpochMarkers = new AtomicInteger(0) + + private val executor = Executors.newFixedThreadPool(numShuffleWriters) + private val completion = new ExecutorCompletionService[UnsafeRowReceiverMessage](executor) + + private def completionTask(writerId: Int) = new Callable[UnsafeRowReceiverMessage] { + override def call(): UnsafeRowReceiverMessage = queues(writerId).take() } - override def close(): Unit = {} + (0 until numShuffleWriters).foreach(writerId => completion.submit(completionTask(writerId))) + + override def getNext(): UnsafeRow = { + completion.take().get() match { + case ReceiverRow(writerId, r) => + // Start reading the next element in the queue we just took from. + completion.submit(completionTask(writerId)) + r + // TODO use writerId + case ReceiverEpochMarker(writerId) => + // Don't read any more from this queue. If all the writers have sent epoch markers, + // the epoch is over; otherwise we need rows from one of the remaining writers. + val writersCompleted = numWriterEpochMarkers.incrementAndGet() + if (writersCompleted == numShuffleWriters) { + finished = true + null + } else { + getNext() + } + } + } + + override def close(): Unit = { + executor.shutdownNow() + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index 718fb0740bbe7..156c97e27bdda 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle +import scala.concurrent.Future + import org.apache.spark.{TaskContext, TaskContextImpl} import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.sql.types.{DataType, IntegerType, StringType} +import org.apache.spark.unsafe.types.UTF8String class ContinuousShuffleReadSuite extends StreamTest { @@ -29,6 +32,11 @@ class ContinuousShuffleReadSuite extends StreamTest { new GenericInternalRow(Array(value: Any))) } + private def unsafeRow(value: String) = { + UnsafeProjection.create(Array(StringType : DataType))( + new GenericInternalRow(Array(UTF8String.fromString(value): Any))) + } + // In this unit test, we emulate that we're in the task thread where // ContinuousShuffleReadRDD.compute() will be evaluated. This requires a task context // thread local to be set. 
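The multi-writer read() introduced above is easier to follow without the NextIterator plumbing. A rough standalone sketch of the same idea, not part of the patch but reusing the message classes defined in this file: drain one blocking queue per writer through an ExecutorCompletionService, and end the epoch only once every writer has delivered a marker.

// Sketch only; the patch itself returns a NextIterator rather than a Seq.
import java.util.concurrent.{BlockingQueue, Callable, ExecutorCompletionService, Executors}

def drainEpoch(queues: Array[BlockingQueue[UnsafeRowReceiverMessage]]): Seq[UnsafeRow] = {
  val executor = Executors.newFixedThreadPool(queues.length)
  val completion = new ExecutorCompletionService[UnsafeRowReceiverMessage](executor)
  // Keep exactly one pending take() per writer; resubmit a queue only after it yields a row.
  def submit(writerId: Int): Unit = completion.submit(new Callable[UnsafeRowReceiverMessage] {
    override def call(): UnsafeRowReceiverMessage = queues(writerId).take()
  })
  queues.indices.foreach(submit)
  try {
    val rows = Seq.newBuilder[UnsafeRow]
    var markersSeen = 0
    while (markersSeen < queues.length) {
      completion.take().get() match {
        case ReceiverRow(writerId, row) =>
          rows += row
          submit(writerId)
        case ReceiverEpochMarker(_) =>
          // Stop pulling from this writer; the epoch closes when all writers have sent markers.
          markersSeen += 1
      }
    }
    rows.result()
  } finally {
    executor.shutdownNow()
  }
}

The completion service is what lets a single reader thread block on whichever of the per-writer queues produces next, rather than polling each queue in turn.
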
@@ -50,8 +58,8 @@ class ContinuousShuffleReadSuite extends StreamTest { test("receiver stopped with row last") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - endpoint.askSync[Unit](ReceiverEpochMarker()) - endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) + endpoint.askSync[Unit](ReceiverEpochMarker(0)) + endpoint.askSync[Unit](ReceiverRow(0, unsafeRow(111))) ctx.markTaskCompleted(None) val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader @@ -63,8 +71,8 @@ class ContinuousShuffleReadSuite extends StreamTest { test("receiver stopped with marker last") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) - endpoint.askSync[Unit](ReceiverEpochMarker()) + endpoint.askSync[Unit](ReceiverRow(0, unsafeRow(111))) + endpoint.askSync[Unit](ReceiverEpochMarker(0)) ctx.markTaskCompleted(None) val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader @@ -76,10 +84,10 @@ class ContinuousShuffleReadSuite extends StreamTest { test("one epoch") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) - endpoint.askSync[Unit](ReceiverRow(unsafeRow(222))) - endpoint.askSync[Unit](ReceiverRow(unsafeRow(333))) - endpoint.askSync[Unit](ReceiverEpochMarker()) + endpoint.askSync[Unit](ReceiverRow(0, unsafeRow(111))) + endpoint.askSync[Unit](ReceiverRow(0, unsafeRow(222))) + endpoint.askSync[Unit](ReceiverRow(0, unsafeRow(333))) + endpoint.askSync[Unit](ReceiverEpochMarker(0)) val iter = rdd.compute(rdd.partitions(0), ctx) assert(iter.toSeq.map(_.getInt(0)) == Seq(111, 222, 333)) @@ -88,11 +96,11 @@ class ContinuousShuffleReadSuite extends StreamTest { test("multiple epochs") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) - endpoint.askSync[Unit](ReceiverEpochMarker()) - endpoint.askSync[Unit](ReceiverRow(unsafeRow(222))) - endpoint.askSync[Unit](ReceiverRow(unsafeRow(333))) - endpoint.askSync[Unit](ReceiverEpochMarker()) + endpoint.askSync[Unit](ReceiverRow(0, unsafeRow(111))) + endpoint.askSync[Unit](ReceiverEpochMarker(0)) + endpoint.askSync[Unit](ReceiverRow(0, unsafeRow(222))) + endpoint.askSync[Unit](ReceiverRow(0, unsafeRow(333))) + endpoint.askSync[Unit](ReceiverEpochMarker(0)) val firstEpoch = rdd.compute(rdd.partitions(0), ctx) assert(firstEpoch.toSeq.map(_.getInt(0)) == Seq(111)) @@ -101,15 +109,30 @@ class ContinuousShuffleReadSuite extends StreamTest { assert(secondEpoch.toSeq.map(_.getInt(0)) == Seq(222, 333)) } + test("multiple writers") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + endpoint.askSync[Unit](ReceiverRow(0, unsafeRow("writer0-row0"))) + endpoint.askSync[Unit](ReceiverRow(1, unsafeRow("writer1-row0"))) + endpoint.askSync[Unit](ReceiverRow(2, unsafeRow("writer2-row0"))) + endpoint.askSync[Unit](ReceiverEpochMarker(0)) + endpoint.askSync[Unit](ReceiverEpochMarker(1)) + 
endpoint.askSync[Unit](ReceiverEpochMarker(2)) + + val firstEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(firstEpoch.toSeq.map(_.getUTF8String(0).toString).toSet == + Set("writer0-row0", "writer1-row0", "writer2-row0")) + } + test("empty epochs") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - endpoint.askSync[Unit](ReceiverEpochMarker()) - endpoint.askSync[Unit](ReceiverEpochMarker()) - endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) - endpoint.askSync[Unit](ReceiverEpochMarker()) - endpoint.askSync[Unit](ReceiverEpochMarker()) - endpoint.askSync[Unit](ReceiverEpochMarker()) + endpoint.askSync[Unit](ReceiverEpochMarker(0)) + endpoint.askSync[Unit](ReceiverEpochMarker(0)) + endpoint.askSync[Unit](ReceiverRow(0, unsafeRow(111))) + endpoint.askSync[Unit](ReceiverEpochMarker(0)) + endpoint.askSync[Unit](ReceiverEpochMarker(0)) + endpoint.askSync[Unit](ReceiverEpochMarker(0)) assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) @@ -127,8 +150,8 @@ class ContinuousShuffleReadSuite extends StreamTest { for (p <- rdd.partitions) { val part = p.asInstanceOf[ContinuousShuffleReadPartition] // Send index for identification. - part.endpoint.askSync[Unit](ReceiverRow(unsafeRow(part.index))) - part.endpoint.askSync[Unit](ReceiverEpochMarker()) + part.endpoint.askSync[Unit](ReceiverRow(0, unsafeRow(part.index))) + part.endpoint.askSync[Unit](ReceiverEpochMarker(0)) } for (p <- rdd.partitions) { @@ -156,4 +179,42 @@ class ContinuousShuffleReadSuite extends StreamTest { assert(readRow.getState == Thread.State.WAITING) } } + + test("epoch only ends when all writers send markers") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + endpoint.askSync[Unit](ReceiverRow(0, unsafeRow("writer0-row0"))) + endpoint.askSync[Unit](ReceiverRow(1, unsafeRow("writer1-row0"))) + endpoint.askSync[Unit](ReceiverRow(2, unsafeRow("writer2-row0"))) + endpoint.askSync[Unit](ReceiverEpochMarker(0)) + endpoint.askSync[Unit](ReceiverEpochMarker(2)) + + val epoch = rdd.compute(rdd.partitions(0), ctx) + val rows = (0 until 3).map(_ => epoch.next()).toSet + assert(rows.map(_.getUTF8String(0).toString) == + Set("writer0-row0", "writer1-row0", "writer2-row0")) + + // After checking the right rows, block until we get an epoch marker indicating there's no next. + // (Also fail the assertion if for some reason we get a row.) + val readEpochMarker = new Thread { + override def run(): Unit = { + assert(!epoch.hasNext) + } + } + + readEpochMarker.start() + + eventually(timeout(streamingTimeout)) { + assert(readEpochMarker.getState == Thread.State.WAITING) + } + + // Send the last epoch marker - now the epoch should finish. + endpoint.askSync[Unit](ReceiverEpochMarker(1)) + eventually(timeout(streamingTimeout)) { + !readEpochMarker.isAlive + } + + // Join to pick up assertion failures. 
+ readEpochMarker.join() + } } From 154843d799683c5cdfc035033475f223f85f0d66 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 18 May 2018 15:41:18 -0700 Subject: [PATCH 14/28] don't use spark conf for the sql conf --- .../shuffle/ContinuousShuffleReadRDD.scala | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index 74419a5b35fdb..b0a09180ccb60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -43,13 +43,16 @@ case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Pa * shuffle output to the wrapped receivers in partitions of this RDD; each of the RDD's tasks * poll from their receiver until an epoch marker is sent. */ -class ContinuousShuffleReadRDD(sc: SparkContext, numPartitions: Int) - extends RDD[UnsafeRow](sc, Nil) { - - private val queueSize = sc.conf.get(SQLConf.CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE) +class ContinuousShuffleReadRDD( + sc: SparkContext, + numPartitions: Int, + queueSize: Int = 1024) + extends RDD[UnsafeRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { - (0 until numPartitions).map(ContinuousShuffleReadPartition(_, queueSize)).toArray + (0 until numPartitions).map { partIndex => + ContinuousShuffleReadPartition(partIndex, queueSize) + }.toArray } override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { From f0262d0a9d3539bcf8fbdbb248968fd704d1e690 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 18 May 2018 15:54:27 -0700 Subject: [PATCH 15/28] end thread --- .../shuffle/ContinuousShuffleReadSuite.scala | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index 718fb0740bbe7..7706a37834d88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -151,9 +151,14 @@ class ContinuousShuffleReadSuite extends StreamTest { } } - readRow.start() - eventually(timeout(streamingTimeout)) { - assert(readRow.getState == Thread.State.WAITING) + try { + readRow.start() + eventually(timeout(streamingTimeout)) { + assert(readRow.getState == Thread.State.WAITING) + } + } finally { + readRow.interrupt() + readRow.join() } } } From 3e7a6f9d31967d9efc618c4d319a9dabd22ae4e5 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 18 May 2018 15:54:54 -0700 Subject: [PATCH 16/28] name thread --- .../shuffle/ContinuousShuffleReadSuite.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index 7706a37834d88..d6e6eeb127a60 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -142,7 +142,7 @@ class ContinuousShuffleReadSuite extends StreamTest { test("blocks waiting for new rows") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val readRow = new Thread { + val readRowThread = new Thread { override def run(): Unit = { // set the non-inheritable thread local TaskContext.setTaskContext(ctx) @@ -152,13 +152,13 @@ class ContinuousShuffleReadSuite extends StreamTest { } try { - readRow.start() + readRowThread.start() eventually(timeout(streamingTimeout)) { - assert(readRow.getState == Thread.State.WAITING) + assert(readRowThread.getState == Thread.State.WAITING) } } finally { - readRow.interrupt() - readRow.join() + readRowThread.interrupt() + readRowThread.join() } } } From 0a38ced23b7e1a6dfe9588ef0ebf7c071a08055d Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 18 May 2018 15:55:14 -0700 Subject: [PATCH 17/28] no toString --- .../streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index b0a09180ccb60..270b1a5c28dee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -30,7 +30,7 @@ case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Pa lazy val (reader: ContinuousShuffleReader, endpoint) = { val env = SparkEnv.get.rpcEnv val receiver = new UnsafeRowReceiver(queueSize, env) - val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID().toString}", receiver) + val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID()}", receiver) TaskContext.get().addTaskCompletionListener { ctx => env.stop(endpoint) } From ef34e6e9817274df9378341bfb52105c591a5507 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 18 May 2018 16:00:01 -0700 Subject: [PATCH 18/28] send method --- .../shuffle/ContinuousShuffleReadSuite.scala | 57 ++++++++++++------- 1 file changed, 38 insertions(+), 19 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index d6e6eeb127a60..e71274365a759 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle import org.apache.spark.{TaskContext, TaskContextImpl} +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.types.{DataType, IntegerType} @@ -29,6 +30,10 @@ class ContinuousShuffleReadSuite extends StreamTest { new GenericInternalRow(Array(value: Any))) } + 
private def send(endpoint: RpcEndpointRef, messages: UnsafeRowReceiverMessage*) = { + messages.foreach(endpoint.askSync[Unit](_)) + } + // In this unit test, we emulate that we're in the task thread where // ContinuousShuffleReadRDD.compute() will be evaluated. This requires a task context // thread local to be set. @@ -50,8 +55,11 @@ class ContinuousShuffleReadSuite extends StreamTest { test("receiver stopped with row last") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - endpoint.askSync[Unit](ReceiverEpochMarker()) - endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) + send( + endpoint, + ReceiverEpochMarker(), + ReceiverRow(unsafeRow(111)) + ) ctx.markTaskCompleted(None) val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader @@ -76,10 +84,13 @@ class ContinuousShuffleReadSuite extends StreamTest { test("one epoch") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) - endpoint.askSync[Unit](ReceiverRow(unsafeRow(222))) - endpoint.askSync[Unit](ReceiverRow(unsafeRow(333))) - endpoint.askSync[Unit](ReceiverEpochMarker()) + send( + endpoint, + ReceiverRow(unsafeRow(111)), + ReceiverRow(unsafeRow(222)), + ReceiverRow(unsafeRow(333)), + ReceiverEpochMarker() + ) val iter = rdd.compute(rdd.partitions(0), ctx) assert(iter.toSeq.map(_.getInt(0)) == Seq(111, 222, 333)) @@ -88,11 +99,13 @@ class ContinuousShuffleReadSuite extends StreamTest { test("multiple epochs") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) - endpoint.askSync[Unit](ReceiverEpochMarker()) - endpoint.askSync[Unit](ReceiverRow(unsafeRow(222))) - endpoint.askSync[Unit](ReceiverRow(unsafeRow(333))) - endpoint.askSync[Unit](ReceiverEpochMarker()) + send( + endpoint, + ReceiverRow(unsafeRow(111)), + ReceiverRow(unsafeRow(222)), + ReceiverRow(unsafeRow(333)), + ReceiverEpochMarker() + ) val firstEpoch = rdd.compute(rdd.partitions(0), ctx) assert(firstEpoch.toSeq.map(_.getInt(0)) == Seq(111)) @@ -104,12 +117,15 @@ class ContinuousShuffleReadSuite extends StreamTest { test("empty epochs") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - endpoint.askSync[Unit](ReceiverEpochMarker()) - endpoint.askSync[Unit](ReceiverEpochMarker()) - endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) - endpoint.askSync[Unit](ReceiverEpochMarker()) - endpoint.askSync[Unit](ReceiverEpochMarker()) - endpoint.askSync[Unit](ReceiverEpochMarker()) + send( + endpoint, + ReceiverEpochMarker(), + ReceiverEpochMarker(), + ReceiverRow(unsafeRow(111)), + ReceiverEpochMarker(), + ReceiverEpochMarker(), + ReceiverEpochMarker() + ) assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) @@ -127,8 +143,11 @@ class ContinuousShuffleReadSuite extends StreamTest { for (p <- rdd.partitions) { val part = p.asInstanceOf[ContinuousShuffleReadPartition] // Send index for identification. 
- part.endpoint.askSync[Unit](ReceiverRow(unsafeRow(part.index))) - part.endpoint.askSync[Unit](ReceiverEpochMarker()) + send( + part.endpoint, + ReceiverRow(unsafeRow(part.index)), + ReceiverEpochMarker() + ) } for (p <- rdd.partitions) { From 00f910ea39b76a24e1e21acdf3d6a20fd7784fa9 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 18 May 2018 16:02:10 -0700 Subject: [PATCH 19/28] fix --- .../continuous/shuffle/ContinuousShuffleReadSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index e71274365a759..b25e75b3b37a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -102,6 +102,7 @@ class ContinuousShuffleReadSuite extends StreamTest { send( endpoint, ReceiverRow(unsafeRow(111)), + ReceiverEpochMarker(), ReceiverRow(unsafeRow(222)), ReceiverRow(unsafeRow(333)), ReceiverEpochMarker() From 504bf7426acf16cce21c549e494b8149dbaa3774 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Mon, 21 May 2018 10:13:33 -0700 Subject: [PATCH 20/28] add test --- .../shuffle/ContinuousShuffleReadSuite.scala | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index e79587e2c5b04..f62d27e9c34e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -254,4 +254,39 @@ class ContinuousShuffleReadSuite extends StreamTest { // Join to pick up assertion failures. readEpochMarkerThread.join() } + + test("writer epochs non aligned") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + // We send multiple epochs for 0, then multiple for 1, then multiple for 2. The receiver should + // collate them as though the markers were aligned in the first place. 
+ send( + endpoint, + ReceiverRow(0, unsafeRow("writer0-row0")), + ReceiverEpochMarker(0), + ReceiverRow(0, unsafeRow("writer0-row1")), + ReceiverEpochMarker(0), + ReceiverEpochMarker(0), + + ReceiverEpochMarker(1), + ReceiverRow(1, unsafeRow("writer1-row0")), + ReceiverEpochMarker(1), + ReceiverRow(1, unsafeRow("writer1-row1")), + ReceiverEpochMarker(1), + + ReceiverEpochMarker(2), + ReceiverEpochMarker(2), + ReceiverRow(2, unsafeRow("writer2-row0")), + ReceiverEpochMarker(2) + ) + + val firstEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet + assert(firstEpoch == Set("writer0-row0")) + + val secondEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet + assert(secondEpoch == Set("writer0-row1", "writer1-row0")) + + val thirdEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet + assert(thirdEpoch == Set("writer1-row1", "writer2-row0")) + } } From 5e66b9be9977883327af9d91128fbb7f1965db36 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 22 May 2018 15:35:18 -0700 Subject: [PATCH 21/28] docs --- .../continuous/shuffle/UnsafeRowReceiver.scala | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala index 575d8e02b4e92..feb4bd5750f9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -27,6 +27,10 @@ import org.apache.spark.util.NextIterator /** * Messages for the UnsafeRowReceiver endpoint. Either an incoming row or an epoch marker. + * + * Each message comes tagged with writerId, identifying which writer the message is coming + * from. The receiver will only begin the next epoch once all writers have sent an epoch + * marker ending the current epoch. */ private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable { def writerId: Int @@ -77,8 +81,16 @@ private[shuffle] class UnsafeRowReceiver( override def call(): UnsafeRowReceiverMessage = queues(writerId).take() } + // Initialize by submitting tasks to read the first row from each writer. (0 until numShuffleWriters).foreach(writerId => completion.submit(completionTask(writerId))) + /** + * In each call to getNext(), we pull the next row available in the completion queue, and then + * submit another task to read the next row from the writer which returned it. + * + * When a writer sends an epoch marker, we note that it's finished and don't submit another + * task for it in this epoch. The iterator is over once all writers have sent an epoch marker. 
+ */ override def getNext(): UnsafeRow = { completion.take().get() match { case ReceiverRow(writerId, r) => From a8bffde1656f08a3b1b0619051b8cbb5b6251f74 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 22 May 2018 16:08:09 -0700 Subject: [PATCH 22/28] use while loop --- .../shuffle/ContinuousShuffleReadRDD.scala | 19 ++++++-- .../shuffle/UnsafeRowReceiver.scala | 46 +++++++++++++------ .../shuffle/ContinuousShuffleReadSuite.scala | 10 ++-- 3 files changed, 53 insertions(+), 22 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index 357256c18e909..94f691bb29ef5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -25,12 +25,16 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.NextIterator -case class ContinuousShuffleReadPartition(index: Int, queueSize: Int, numShuffleWriters: Int) +case class ContinuousShuffleReadPartition( + index: Int, + queueSize: Int, + numShuffleWriters: Int, + checkpointIntervalMs: Long) extends Partition { // Initialized only on the executor, and only once even as we call compute() multiple times. lazy val (reader: ContinuousShuffleReader, endpoint) = { val env = SparkEnv.get.rpcEnv - val receiver = new UnsafeRowReceiver(queueSize, numShuffleWriters, env) + val receiver = new UnsafeRowReceiver(queueSize, numShuffleWriters, checkpointIntervalMs, env) val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID()}", receiver) TaskContext.get().addTaskCompletionListener { ctx => env.stop(endpoint) @@ -43,17 +47,24 @@ case class ContinuousShuffleReadPartition(index: Int, queueSize: Int, numShuffle * RDD at the map side of each continuous processing shuffle task. Upstream tasks send their * shuffle output to the wrapped receivers in partitions of this RDD; each of the RDD's tasks * poll from their receiver until an epoch marker is sent. 
+ * + * @param sc the RDD context + * @param numPartitions the number of read partitions for this RDD + * @param queueSize the size of the row buffers to use + * @param numShuffleWriters the number of continuous shuffle writers feeding into this RDD + * @param checkpointIntervalMs the checkpoint interval of the streaming query */ class ContinuousShuffleReadRDD( sc: SparkContext, numPartitions: Int, queueSize: Int = 1024, - numShuffleWriters: Int = 1) + numShuffleWriters: Int = 1, + checkpointIntervalMs: Long = 1000) extends RDD[UnsafeRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { (0 until numPartitions).map { partIndex => - ContinuousShuffleReadPartition(partIndex, queueSize, numShuffleWriters) + ContinuousShuffleReadPartition(partIndex, queueSize, numShuffleWriters, checkpointIntervalMs) }.toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala index feb4bd5750f9c..6590b38a95e68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -49,6 +49,7 @@ private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends UnsafeRow private[shuffle] class UnsafeRowReceiver( queueSize: Int, numShuffleWriters: Int, + checkpointIntervalMs: Long, override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with ContinuousShuffleReader with Logging { // Note that this queue will be drained from the main task thread and populated in the RPC @@ -92,23 +93,40 @@ private[shuffle] class UnsafeRowReceiver( * task for it in this epoch. The iterator is over once all writers have sent an epoch marker. */ override def getNext(): UnsafeRow = { - completion.take().get() match { - case ReceiverRow(writerId, r) => - // Start reading the next element in the queue we just took from. - completion.submit(completionTask(writerId)) - r - // TODO use writerId - case ReceiverEpochMarker(writerId) => - // Don't read any more from this queue. If all the writers have sent epoch markers, - // the epoch is over; otherwise we need rows from one of the remaining writers. - val writersCompleted = numWriterEpochMarkers.incrementAndGet() - if (writersCompleted == numShuffleWriters) { - finished = true + var nextRow: UnsafeRow = null + while (nextRow == null) { + nextRow = completion.poll(checkpointIntervalMs, TimeUnit.MILLISECONDS) match { + case null => + // Try again if the poll didn't wait long enough to get a real result. + // But we should be getting at least an epoch marker every checkpoint interval. + logWarning( + s"Completion service failed to make progress after $checkpointIntervalMs ms") null - } else { - getNext() + + // The completion service guarantees this future will be available immediately. + case future => future.get() match { + case ReceiverRow(writerId, r) => + // Start reading the next element in the queue we just took from. + completion.submit(completionTask(writerId)) + r + // TODO use writerId + case ReceiverEpochMarker(writerId) => + // Don't read any more from this queue. If all the writers have sent epoch markers, + // the epoch is over; otherwise we need rows from one of the remaining writers. 
+ val writersCompleted = numWriterEpochMarkers.incrementAndGet() + if (writersCompleted == numShuffleWriters) { + finished = true + // Break out of the while loop and end the iterator. + return null + } else { + // Poll again for the next completion result. + null + } } + } } + + nextRow } override def close(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index d8dcce2f7b73f..aa8b3e65a27f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -188,7 +188,8 @@ class ContinuousShuffleReadSuite extends StreamTest { } test("blocks waiting for new rows") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val rdd = new ContinuousShuffleReadRDD( + sparkContext, numPartitions = 1, checkpointIntervalMs = Long.MaxValue) val epoch = rdd.compute(rdd.partitions(0), ctx) val readRowThread = new Thread { @@ -204,7 +205,7 @@ class ContinuousShuffleReadSuite extends StreamTest { try { readRowThread.start() eventually(timeout(streamingTimeout)) { - assert(readRowThread.getState == Thread.State.WAITING) + assert(readRowThread.getState == Thread.State.TIMED_WAITING) } } finally { readRowThread.interrupt() @@ -213,7 +214,8 @@ class ContinuousShuffleReadSuite extends StreamTest { } test("epoch only ends when all writers send markers") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) + val rdd = new ContinuousShuffleReadRDD( + sparkContext, numPartitions = 1, numShuffleWriters = 3, checkpointIntervalMs = Long.MaxValue) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint send( endpoint, @@ -240,7 +242,7 @@ class ContinuousShuffleReadSuite extends StreamTest { readEpochMarkerThread.start() eventually(timeout(streamingTimeout)) { - assert(readEpochMarkerThread.getState == Thread.State.WAITING) + assert(readEpochMarkerThread.getState == Thread.State.TIMED_WAITING) } // Send the last epoch marker - now the epoch should finish. 
From fc1c829caf46dda26eba5ad7a86ff8560de6e05b Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 22 May 2018 16:09:12 -0700 Subject: [PATCH 23/28] rearrange --- .../shuffle/ContinuousShuffleReadSuite.scala | 37 +++++++++---------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index aa8b3e65a27f9..9ae677870ee5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -124,24 +124,6 @@ class ContinuousShuffleReadSuite extends StreamTest { assert(secondEpoch.toSeq.map(_.getInt(0)) == Seq(222, 333)) } - test("multiple writers") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) - val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - send( - endpoint, - ReceiverRow(0, unsafeRow("writer0-row0")), - ReceiverRow(1, unsafeRow("writer1-row0")), - ReceiverRow(2, unsafeRow("writer2-row0")), - ReceiverEpochMarker(0), - ReceiverEpochMarker(1), - ReceiverEpochMarker(2) - ) - - val firstEpoch = rdd.compute(rdd.partitions(0), ctx) - assert(firstEpoch.toSeq.map(_.getUTF8String(0).toString).toSet == - Set("writer0-row0", "writer1-row0", "writer2-row0")) - } - test("empty epochs") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint @@ -213,6 +195,24 @@ class ContinuousShuffleReadSuite extends StreamTest { } } + test("multiple writers") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(0, unsafeRow("writer0-row0")), + ReceiverRow(1, unsafeRow("writer1-row0")), + ReceiverRow(2, unsafeRow("writer2-row0")), + ReceiverEpochMarker(0), + ReceiverEpochMarker(1), + ReceiverEpochMarker(2) + ) + + val firstEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(firstEpoch.toSeq.map(_.getUTF8String(0).toString).toSet == + Set("writer0-row0", "writer1-row0", "writer2-row0")) + } + test("epoch only ends when all writers send markers") { val rdd = new ContinuousShuffleReadRDD( sparkContext, numPartitions = 1, numShuffleWriters = 3, checkpointIntervalMs = Long.MaxValue) @@ -240,7 +240,6 @@ class ContinuousShuffleReadSuite extends StreamTest { } readEpochMarkerThread.start() - eventually(timeout(streamingTimeout)) { assert(readEpochMarkerThread.getState == Thread.State.TIMED_WAITING) } From ca3f5cd31617e6952abd29b935d480059677b67e Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 22 May 2018 16:54:52 -0700 Subject: [PATCH 24/28] use map --- .../continuous/shuffle/UnsafeRowReceiver.scala | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala index 6590b38a95e68..2f1d30737306c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle import java.util.concurrent._ -import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} +import java.util.concurrent.atomic.AtomicBoolean + +import scala.collection.mutable import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} @@ -73,7 +75,9 @@ private[shuffle] class UnsafeRowReceiver( override def read(): Iterator[UnsafeRow] = { new NextIterator[UnsafeRow] { - private val numWriterEpochMarkers = new AtomicInteger(0) + // An array of flags for whether each writer ID has gotten an epoch marker. + private val writerEpochMarkersReceived = + mutable.Map.empty[Int, Boolean].withDefaultValue(false) private val executor = Executors.newFixedThreadPool(numShuffleWriters) private val completion = new ExecutorCompletionService[UnsafeRowReceiverMessage](executor) @@ -112,9 +116,9 @@ private[shuffle] class UnsafeRowReceiver( // TODO use writerId case ReceiverEpochMarker(writerId) => // Don't read any more from this queue. If all the writers have sent epoch markers, - // the epoch is over; otherwise we need rows from one of the remaining writers. - val writersCompleted = numWriterEpochMarkers.incrementAndGet() - if (writersCompleted == numShuffleWriters) { + // the epoch is over; otherwise we need to poll from the remaining writers. + writerEpochMarkersReceived.put(writerId, true) + if ((0 until numShuffleWriters).forall(id => writerEpochMarkersReceived(id))) { finished = true // Break out of the while loop and end the iterator. return null From e02d714f6c6774eb275bf86cc81e09dd80f40b11 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Wed, 23 May 2018 08:58:13 -0700 Subject: [PATCH 25/28] use array instead --- .../streaming/continuous/shuffle/UnsafeRowReceiver.scala | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala index 2f1d30737306c..986a9de70f8e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -76,8 +76,7 @@ private[shuffle] class UnsafeRowReceiver( override def read(): Iterator[UnsafeRow] = { new NextIterator[UnsafeRow] { // An array of flags for whether each writer ID has gotten an epoch marker. - private val writerEpochMarkersReceived = - mutable.Map.empty[Int, Boolean].withDefaultValue(false) + private val writerEpochMarkersReceived = Array.fill(numShuffleWriters)(false) private val executor = Executors.newFixedThreadPool(numShuffleWriters) private val completion = new ExecutorCompletionService[UnsafeRowReceiverMessage](executor) @@ -113,12 +112,11 @@ private[shuffle] class UnsafeRowReceiver( // Start reading the next element in the queue we just took from. completion.submit(completionTask(writerId)) r - // TODO use writerId case ReceiverEpochMarker(writerId) => // Don't read any more from this queue. If all the writers have sent epoch markers, // the epoch is over; otherwise we need to poll from the remaining writers. 
- writerEpochMarkersReceived.put(writerId, true) - if ((0 until numShuffleWriters).forall(id => writerEpochMarkersReceived(id))) { + writerEpochMarkersReceived(writerId) = true + if (writerEpochMarkersReceived.forall(flag => flag)) { finished = true // Break out of the while loop and end the iterator. return null From b0c38781280288b3775fb087be7763934c1deb3c Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 24 May 2018 11:18:49 -0700 Subject: [PATCH 26/28] address comments --- .../continuous/shuffle/ContinuousShuffleReadRDD.scala | 10 +++++----- .../continuous/shuffle/UnsafeRowReceiver.scala | 10 +++++++--- .../shuffle/ContinuousShuffleReadSuite.scala | 4 ++-- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index 94f691bb29ef5..801b28b751bee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -29,12 +29,12 @@ case class ContinuousShuffleReadPartition( index: Int, queueSize: Int, numShuffleWriters: Int, - checkpointIntervalMs: Long) + epochIntervalMs: Long) extends Partition { // Initialized only on the executor, and only once even as we call compute() multiple times. lazy val (reader: ContinuousShuffleReader, endpoint) = { val env = SparkEnv.get.rpcEnv - val receiver = new UnsafeRowReceiver(queueSize, numShuffleWriters, checkpointIntervalMs, env) + val receiver = new UnsafeRowReceiver(queueSize, numShuffleWriters, epochIntervalMs, env) val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID()}", receiver) TaskContext.get().addTaskCompletionListener { ctx => env.stop(endpoint) @@ -52,19 +52,19 @@ case class ContinuousShuffleReadPartition( * @param numPartitions the number of read partitions for this RDD * @param queueSize the size of the row buffers to use * @param numShuffleWriters the number of continuous shuffle writers feeding into this RDD - * @param checkpointIntervalMs the checkpoint interval of the streaming query + * @param epochIntervalMs the checkpoint interval of the streaming query */ class ContinuousShuffleReadRDD( sc: SparkContext, numPartitions: Int, queueSize: Int = 1024, numShuffleWriters: Int = 1, - checkpointIntervalMs: Long = 1000) + epochIntervalMs: Long = 1000) extends RDD[UnsafeRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { (0 until numPartitions).map { partIndex => - ContinuousShuffleReadPartition(partIndex, queueSize, numShuffleWriters, checkpointIntervalMs) + ContinuousShuffleReadPartition(partIndex, queueSize, numShuffleWriters, epochIntervalMs) }.toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala index 986a9de70f8e7..b072a561f5a94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -51,7 +51,7 @@ private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends UnsafeRow private[shuffle] class 
UnsafeRowReceiver( queueSize: Int, numShuffleWriters: Int, - checkpointIntervalMs: Long, + epochIntervalMs: Long, override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with ContinuousShuffleReader with Logging { // Note that this queue will be drained from the main task thread and populated in the RPC @@ -98,12 +98,16 @@ private[shuffle] class UnsafeRowReceiver( override def getNext(): UnsafeRow = { var nextRow: UnsafeRow = null while (nextRow == null) { - nextRow = completion.poll(checkpointIntervalMs, TimeUnit.MILLISECONDS) match { + nextRow = completion.poll(epochIntervalMs, TimeUnit.MILLISECONDS) match { case null => // Try again if the poll didn't wait long enough to get a real result. // But we should be getting at least an epoch marker every checkpoint interval. + val writerIdsUncommitted = writerEpochMarkersReceived.zipWithIndex.collect { + case (flag, idx) if !flag => idx + } logWarning( - s"Completion service failed to make progress after $checkpointIntervalMs ms") + s"Completion service failed to make progress after $epochIntervalMs ms. Waiting " + + s"for writers $writerIdsUncommitted to send epoch markers.") null // The completion service guarantees this future will be available immediately. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index 9ae677870ee5e..2e4d607a403ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -171,7 +171,7 @@ class ContinuousShuffleReadSuite extends StreamTest { test("blocks waiting for new rows") { val rdd = new ContinuousShuffleReadRDD( - sparkContext, numPartitions = 1, checkpointIntervalMs = Long.MaxValue) + sparkContext, numPartitions = 1, epochIntervalMs = Long.MaxValue) val epoch = rdd.compute(rdd.partitions(0), ctx) val readRowThread = new Thread { @@ -215,7 +215,7 @@ class ContinuousShuffleReadSuite extends StreamTest { test("epoch only ends when all writers send markers") { val rdd = new ContinuousShuffleReadRDD( - sparkContext, numPartitions = 1, numShuffleWriters = 3, checkpointIntervalMs = Long.MaxValue) + sparkContext, numPartitions = 1, numShuffleWriters = 3, epochIntervalMs = Long.MaxValue) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint send( endpoint, From b9ad1957330e4f549b4e3ae4ec6819bc3628346f Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 24 May 2018 11:20:20 -0700 Subject: [PATCH 27/28] don't assign to block --- .../continuous/shuffle/UnsafeRowReceiver.scala | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala index b072a561f5a94..946649e861ef8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -97,8 +97,8 @@ private[shuffle] class UnsafeRowReceiver( */ override def getNext(): UnsafeRow = { var nextRow: UnsafeRow = null - while (nextRow == null) { - nextRow = 
completion.poll(epochIntervalMs, TimeUnit.MILLISECONDS) match { + while (!finished && nextRow == null) { + completion.poll(epochIntervalMs, TimeUnit.MILLISECONDS) match { case null => // Try again if the poll didn't wait long enough to get a real result. // But we should be getting at least an epoch marker every checkpoint interval. @@ -108,25 +108,20 @@ private[shuffle] class UnsafeRowReceiver( logWarning( s"Completion service failed to make progress after $epochIntervalMs ms. Waiting " + s"for writers $writerIdsUncommitted to send epoch markers.") - null // The completion service guarantees this future will be available immediately. case future => future.get() match { case ReceiverRow(writerId, r) => // Start reading the next element in the queue we just took from. completion.submit(completionTask(writerId)) - r + nextRow = r case ReceiverEpochMarker(writerId) => // Don't read any more from this queue. If all the writers have sent epoch markers, - // the epoch is over; otherwise we need to poll from the remaining writers. + // the epoch is over; otherwise we need to loop again to poll from the remaining + // writers. writerEpochMarkersReceived(writerId) = true if (writerEpochMarkersReceived.forall(flag => flag)) { finished = true - // Break out of the while loop and end the iterator. - return null - } else { - // Poll again for the next completion result. - null } } } From f2a1f48665ac8214c6943bdb4d363f5c1d9e61c0 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 24 May 2018 11:20:36 -0700 Subject: [PATCH 28/28] change forall --- .../streaming/continuous/shuffle/UnsafeRowReceiver.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala index 946649e861ef8..d81f552d56626 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -120,7 +120,7 @@ private[shuffle] class UnsafeRowReceiver( // the epoch is over; otherwise we need to loop again to poll from the remaining // writers. writerEpochMarkersReceived(writerId) = true - if (writerEpochMarkersReceived.forall(flag => flag)) { + if (writerEpochMarkersReceived.forall(_ == true)) { finished = true } }
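
The read loop that patches 22 through 28 converge on is easier to follow outside the diff hunks. Below is a minimal, standalone sketch of that shape: one bounded queue per writer, an ExecutorCompletionService racing a blocking take() per writer, a timed poll that reports which writers are still pending, and an epoch that only ends once every writer's marker has been collected. It uses plain Scala and java.util.concurrent only; the names SketchReader, Msg, Row, Marker, send and readEpoch are illustrative and do not appear in the patches, and Strings stand in for UnsafeRow so the example stays self-contained.

import java.util.concurrent.{ArrayBlockingQueue, Callable, ExecutorCompletionService, Executors, TimeUnit}

// Messages mirroring the ReceiverRow / ReceiverEpochMarker shapes, tagged with the
// sending writer's id.
sealed trait Msg { def writerId: Int }
case class Row(writerId: Int, value: String) extends Msg
case class Marker(writerId: Int) extends Msg

class SketchReader(numWriters: Int, queueSize: Int, pollMs: Long) {
  // One bounded FIFO queue per writer, so per-writer ordering is preserved.
  private val queues = Array.fill(numWriters)(new ArrayBlockingQueue[Msg](queueSize))
  private val executor = Executors.newFixedThreadPool(numWriters)
  private val completion = new ExecutorCompletionService[Msg](executor)

  // Stand-in for the RPC receive() path: each writer enqueues into its own queue.
  def send(msg: Msg): Unit = queues(msg.writerId).put(msg)

  // A task that blocks until the given writer's queue has a message.
  private def takeTask(writerId: Int): Callable[Msg] = new Callable[Msg] {
    override def call(): Msg = queues(writerId).take()
  }

  // Drain one epoch: collect rows until every writer has delivered an epoch marker.
  def readEpoch(): Seq[String] = {
    val markerSeen = Array.fill(numWriters)(false)
    (0 until numWriters).foreach(id => completion.submit(takeTask(id)))

    val rows = Seq.newBuilder[String]
    var finished = false
    while (!finished) {
      completion.poll(pollMs, TimeUnit.MILLISECONDS) match {
        case null =>
          // Timed out; the real patch logs which writers have not yet sent a marker.
          val pending = markerSeen.zipWithIndex.collect { case (false, id) => id }
          println(s"still waiting on writers ${pending.mkString(", ")}")
        case future => future.get() match {
          case Row(writerId, v) =>
            rows += v
            // Keep reading from the writer that just produced a row.
            completion.submit(takeTask(writerId))
          case Marker(writerId) =>
            // Stop polling this writer for the current epoch.
            markerSeen(writerId) = true
            finished = markerSeen.forall(identity)
        }
      }
    }
    rows.result()
  }

  def close(): Unit = executor.shutdownNow()
}

Because a take task is resubmitted only after a row and never after a marker, no task is left in flight when an epoch finishes, which is why the next readEpoch() call can safely submit a fresh round of takes; this is the same property that lets the real iterator end cleanly once all writers have sent markers, as pinned down by the "epoch only ends when all writers send markers" and "writer epochs non aligned" tests.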