From 1d6b71898e2a640e3c0809695d2b83f3f84eaa38 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 15 May 2018 11:07:54 -0700 Subject: [PATCH 01/18] continuous shuffle read RDD --- .../shuffle/ContinuousShuffleReadRDD.scala | 64 +++++++++ .../shuffle/UnsafeRowReceiver.scala | 56 ++++++++ .../shuffle/ContinuousShuffleReadSuite.scala | 122 ++++++++++++++++++ 3 files changed, 242 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala new file mode 100644 index 0000000000000..110b51f8922d8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import java.util.UUID + +import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.util.NextIterator + +case class ContinuousShuffleReadPartition(index: Int) extends Partition { + // Initialized only on the executor, and only once even as we call compute() multiple times. + lazy val (receiver, endpoint) = { + val env = SparkEnv.get.rpcEnv + val receiver = new UnsafeRowReceiver(env) + val endpoint = env.setupEndpoint(UUID.randomUUID().toString, receiver) + TaskContext.get().addTaskCompletionListener { ctx => + env.stop(endpoint) + } + (receiver, endpoint) + } +} + +/** + * RDD at the bottom of each continuous processing shuffle task, reading from the + */ +class ContinuousShuffleReadRDD(sc: SparkContext, numPartitions: Int) + extends RDD[UnsafeRow](sc, Nil) { + + override protected def getPartitions: Array[Partition] = { + (0 until numPartitions).map(ContinuousShuffleReadPartition).toArray + } + + override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { + val receiver = split.asInstanceOf[ContinuousShuffleReadPartition].receiver + + new NextIterator[UnsafeRow] { + override def getNext(): UnsafeRow = receiver.poll() match { + case ReceiverRow(r) => r + case ReceiverEpochMarker() => + finished = true + null + } + + override def close(): Unit = {} + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala new file mode 100644 index 0000000000000..637305e54519f --- /dev/null +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue} +import java.util.concurrent.atomic.AtomicBoolean + +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow + +/** + * Messages for the UnsafeRowReceiver endpoint. Either an incoming row or an epoch marker. + */ +private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable +private[shuffle] case class ReceiverRow(row: UnsafeRow) extends UnsafeRowReceiverMessage +private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessage + +/** + * RPC endpoint for receiving rows into a continuous processing shuffle task. 
+ */ +private[shuffle] class UnsafeRowReceiver(val rpcEnv: RpcEnv) + extends ThreadSafeRpcEndpoint with Logging { + private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](1024) + var stopped = new AtomicBoolean(false) + + override def onStop(): Unit = { + stopped.set(true) + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case r: UnsafeRowReceiverMessage => + queue.put(r) + context.reply(()) + } + + /** + * Polls until a new row is available. + */ + def poll(): UnsafeRowReceiverMessage = queue.poll() +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala new file mode 100644 index 0000000000000..ee0ff8ada23f4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import org.apache.spark.{TaskContext, TaskContextImpl} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types.{DataType, IntegerType} + +class ContinuousShuffleReadSuite extends StreamTest { + + private def unsafeRow(value: Int) = { + UnsafeProjection.create(Array(IntegerType : DataType))( + new GenericInternalRow(Array(value: Any))) + } + + var ctx: TaskContextImpl = _ + + override def beforeEach(): Unit = { + super.beforeEach() + ctx = TaskContext.empty() + TaskContext.setTaskContext(ctx) + } + + override def afterEach(): Unit = { + ctx.markTaskCompleted(None) + ctx = null + super.afterEach() + } + + test("receiver stopped with row last") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + endpoint.askSync[Unit](ReceiverEpochMarker()) + endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) + + ctx.markTaskCompleted(None) + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].receiver + eventually(timeout(streamingTimeout)) { + assert(receiver.stopped.get()) + } + } + + test("receiver stopped with marker last") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) + endpoint.askSync[Unit](ReceiverEpochMarker()) + + ctx.markTaskCompleted(None) + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].receiver + eventually(timeout(streamingTimeout)) { + assert(receiver.stopped.get()) + } + } + + test("one epoch") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = 
rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) + endpoint.askSync[Unit](ReceiverRow(unsafeRow(222))) + endpoint.askSync[Unit](ReceiverRow(unsafeRow(333))) + endpoint.askSync[Unit](ReceiverEpochMarker()) + + val iter = rdd.compute(rdd.partitions(0), ctx) + assert(iter.next().getInt(0) == 111) + assert(iter.next().getInt(0) == 222) + assert(iter.next().getInt(0) == 333) + assert(!iter.hasNext) + } + + test("multiple epochs") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) + endpoint.askSync[Unit](ReceiverEpochMarker()) + endpoint.askSync[Unit](ReceiverRow(unsafeRow(222))) + endpoint.askSync[Unit](ReceiverRow(unsafeRow(333))) + endpoint.askSync[Unit](ReceiverEpochMarker()) + + val firstEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(firstEpoch.next().getInt(0) == 111) + assert(!firstEpoch.hasNext) + + val secondEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(secondEpoch.next().getInt(0) == 222) + assert(secondEpoch.next().getInt(0) == 333) + assert(!secondEpoch.hasNext) + } + + test("empty epochs") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + endpoint.askSync[Unit](ReceiverEpochMarker()) + endpoint.askSync[Unit](ReceiverEpochMarker()) + endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) + endpoint.askSync[Unit](ReceiverEpochMarker()) + endpoint.askSync[Unit](ReceiverEpochMarker()) + + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + val thirdEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(thirdEpoch.next().getInt(0) == 111) + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) 
+ } } From b5d100875932bdfcb645c8f6b2cdb7b815d84c80 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Wed, 16 May 2018 20:11:11 -0700 Subject: [PATCH 02/18] docs --- .../continuous/shuffle/ContinuousShuffleReadRDD.scala | 4 +++- .../streaming/continuous/shuffle/UnsafeRowReceiver.scala | 8 +++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index 110b51f8922d8..01a7172357883 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -38,7 +38,9 @@ case class ContinuousShuffleReadPartition(index: Int) extends Partition { } /** - * RDD at the bottom of each continuous processing shuffle task, reading from the + * RDD at the reduce side of each continuous processing shuffle task. Upstream tasks send their + * shuffle output to the wrapped receivers in partitions of this RDD; each of the RDD's tasks + * poll from their receiver until an epoch marker is sent 
*/ class ContinuousShuffleReadRDD(sc: SparkContext, numPartitions: Int) extends RDD[UnsafeRow](sc, Nil) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala index 637305e54519f..addfe357f63f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -32,10 +32,16 @@ private[shuffle] case class ReceiverRow(row: UnsafeRow) extends UnsafeRowReceive private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessage /** - * RPC endpoint for receiving rows into a continuous processing shuffle task. + * RPC endpoint for receiving rows into a continuous processing shuffle task. Continuous shuffle + * writers will send rows here, with continuous shuffle readers polling for new rows as needed. + * + * TODO: Support multiple source tasks. We need to output a single epoch marker once all + * source tasks have sent one. */ private[shuffle] class UnsafeRowReceiver(val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { + // Note that this queue will be drained from the main task thread and populated in the RPC + // response thread. 
private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](1024) var stopped = new AtomicBoolean(false) From 46456dc75a6aec9659b18523c421999debd060eb Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Wed, 16 May 2018 20:22:49 -0700 Subject: [PATCH 03/18] fix ctor --- .../continuous/shuffle/ContinuousShuffleReadRDD.scala | 3 ++- .../streaming/continuous/shuffle/UnsafeRowReceiver.scala | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index 01a7172357883..f69977749de08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -22,13 +22,14 @@ import java.util.UUID import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.NextIterator case class ContinuousShuffleReadPartition(index: Int) extends Partition { // Initialized only on the executor, and only once even as we call compute() multiple times. 
lazy val (receiver, endpoint) = { val env = SparkEnv.get.rpcEnv - val receiver = new UnsafeRowReceiver(env) + val receiver = new UnsafeRowReceiver(SQLConf.get.continuousStreamingExecutorQueueSize, env) val endpoint = env.setupEndpoint(UUID.randomUUID().toString, receiver) TaskContext.get().addTaskCompletionListener { ctx => env.stop(endpoint) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala index addfe357f63f1..f50b89f276eaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -38,11 +38,13 @@ private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessa * TODO: Support multiple source tasks. We need to output a single epoch marker once all * source tasks have sent one. */ -private[shuffle] class UnsafeRowReceiver(val rpcEnv: RpcEnv) +private[shuffle] class UnsafeRowReceiver( + queueSize: Int, + override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { // Note that this queue will be drained from the main task thread and populated in the RPC // response thread. 
- private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](1024) + private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize) var stopped = new AtomicBoolean(false) override def onStop(): Unit = { From 2ea8a6f94216e8b184e5780ec3e6ffb2838de382 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Wed, 16 May 2018 20:43:10 -0700 Subject: [PATCH 04/18] multiple partition test --- .../shuffle/ContinuousShuffleReadSuite.scala | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index ee0ff8ada23f4..ad841626c8b10 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -119,4 +119,18 @@ class ContinuousShuffleReadSuite extends StreamTest { assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) } + + test("multiple partitions") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 5) + for (p <- rdd.partitions) { + val part = p.asInstanceOf[ContinuousShuffleReadPartition] + // Send index to ensure data doesn't somehow cross over between partitions. 
+ part.endpoint.askSync[Unit](ReceiverRow(unsafeRow(part.index))) + part.endpoint.askSync[Unit](ReceiverEpochMarker()) + + val iter = rdd.compute(part, ctx) + assert(iter.next().getInt(0) == part.index) + assert(!iter.hasNext) + } + } } From 955ac79eb05dc389e632d1aaa6c59396835c6ed5 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 17 May 2018 06:33:51 -0700 Subject: [PATCH 05/18] unset task context after test --- .../continuous/shuffle/ContinuousShuffleReadSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index ad841626c8b10..faec753768895 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -39,6 +39,7 @@ class ContinuousShuffleReadSuite extends StreamTest { override def afterEach(): Unit = { ctx.markTaskCompleted(None) + TaskContext.unset() ctx = null super.afterEach() } From 8cefb724512b51f2aa1fdd81fa8a2d4560e60ce3 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 17 May 2018 17:00:05 -0700 Subject: [PATCH 06/18] conf from RDD --- .../shuffle/ContinuousShuffleReadRDD.scala | 31 ++++++++++++------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index f69977749de08..a85ba28a8d528 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -20,22 +20,19 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle import java.util.UUID import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext} + import org.apache.spark.rdd.RDD +import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.NextIterator case class ContinuousShuffleReadPartition(index: Int) extends Partition { - // Initialized only on the executor, and only once even as we call compute() multiple times. - lazy val (receiver, endpoint) = { - val env = SparkEnv.get.rpcEnv - val receiver = new UnsafeRowReceiver(SQLConf.get.continuousStreamingExecutorQueueSize, env) - val endpoint = env.setupEndpoint(UUID.randomUUID().toString, receiver) - TaskContext.get().addTaskCompletionListener { ctx => - env.stop(endpoint) - } - (receiver, endpoint) - } + // Semantically lazy vals - initialized only on the executor, and only once even as we call + // compute() multiple times. We need to initialize them inside compute() so we have access to the + // RDD's conf. 
+ var receiver: UnsafeRowReceiver = _ + var endpoint: RpcEndpointRef = _ } /** @@ -46,14 +43,26 @@ case class ContinuousShuffleReadPartition(index: Int) extends Partition { class ContinuousShuffleReadRDD(sc: SparkContext, numPartitions: Int) extends RDD[UnsafeRow](sc, Nil) { + private val queueSize = sc.conf.get(SQLConf.CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE) + override protected def getPartitions: Array[Partition] = { (0 until numPartitions).map(ContinuousShuffleReadPartition).toArray } override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { - val receiver = split.asInstanceOf[ContinuousShuffleReadPartition].receiver + val part = split.asInstanceOf[ContinuousShuffleReadPartition] + if (part.receiver == null) { + val env = SparkEnv.get.rpcEnv + part.receiver = new UnsafeRowReceiver(queueSize, env) + part.endpoint = env.setupEndpoint(UUID.randomUUID().toString, part.receiver) + TaskContext.get().addTaskCompletionListener { _ => + env.stop(part.endpoint) + } + } new NextIterator[UnsafeRow] { + private val receiver = part.receiver + override def getNext(): UnsafeRow = receiver.poll() match { case ReceiverRow(r) => r case ReceiverEpochMarker() => From f91bfe7e3fc174202d7d5c7cde5a8fb7ce86bfd3 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 17 May 2018 17:00:44 -0700 Subject: [PATCH 07/18] endpoint name --- .../streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index a85ba28a8d528..d4128c5e274fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -54,7 +54,7 @@ class ContinuousShuffleReadRDD(sc: SparkContext, numPartitions: Int) if (part.receiver == null) { val env = SparkEnv.get.rpcEnv part.receiver = new UnsafeRowReceiver(queueSize, env) - part.endpoint = env.setupEndpoint(UUID.randomUUID().toString, part.receiver) + part.endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID()}", part.receiver) TaskContext.get().addTaskCompletionListener { _ => env.stop(part.endpoint) } From 259029298fc42a65e8ebb4d2effe49b7fafa96f1 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 17 May 2018 17:02:08 -0700 Subject: [PATCH 08/18] testing bool --- .../streaming/continuous/shuffle/UnsafeRowReceiver.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala index f50b89f276eaa..d6a6d978d2653 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -45,7 +45,9 @@ private[shuffle] class UnsafeRowReceiver( // Note that this queue will be drained from the main task thread and populated in the RPC // response thread. private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize) - var stopped = new AtomicBoolean(false) + + // Exposed for testing to determine if the endpoint gets stopped on task end. 
+ private[shuffle] val stopped = new AtomicBoolean(false) override def onStop(): Unit = { stopped.set(true) From 859e6e4dd4dd90ffd70fc9cbd243c94090d72506 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 17 May 2018 17:22:10 -0700 Subject: [PATCH 09/18] tests --- .../shuffle/ContinuousShuffleReadRDD.scala | 33 ++++++++----------- .../shuffle/ContinuousShuffleReadSuite.scala | 23 ++++++------- 2 files changed, 25 insertions(+), 31 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index d4128c5e274fe..7f5aaecabf5ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -20,19 +20,22 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle import java.util.UUID import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext} - import org.apache.spark.rdd.RDD -import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.NextIterator -case class ContinuousShuffleReadPartition(index: Int) extends Partition { - // Semantically lazy vals - initialized only on the executor, and only once even as we call - // compute() multiple times. We need to initialize them inside compute() so we have access to the - // RDD's conf. - var receiver: UnsafeRowReceiver = _ - var endpoint: RpcEndpointRef = _ +case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Partition { + // Initialized only on the executor, and only once even as we call compute() multiple times. 
+ lazy val (receiver, endpoint) = { + val env = SparkEnv.get.rpcEnv + val receiver = new UnsafeRowReceiver(queueSize, env) + val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID().toString}", receiver) + TaskContext.get().addTaskCompletionListener { ctx => + env.stop(endpoint) + } + (receiver, endpoint) + } } /** @@ -46,23 +49,13 @@ class ContinuousShuffleReadRDD(sc: SparkContext, numPartitions: Int) private val queueSize = sc.conf.get(SQLConf.CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE) override protected def getPartitions: Array[Partition] = { - (0 until numPartitions).map(ContinuousShuffleReadPartition).toArray + (0 until numPartitions).map(ContinuousShuffleReadPartition(_, queueSize)).toArray } override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { - val part = split.asInstanceOf[ContinuousShuffleReadPartition] - if (part.receiver == null) { - val env = SparkEnv.get.rpcEnv - part.receiver = new UnsafeRowReceiver(queueSize, env) - part.endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID()}", part.receiver) - TaskContext.get().addTaskCompletionListener { _ => - env.stop(part.endpoint) - } - } + val receiver = split.asInstanceOf[ContinuousShuffleReadPartition].receiver new NextIterator[UnsafeRow] { - private val receiver = part.receiver - override def getNext(): UnsafeRow = receiver.poll() match { case ReceiverRow(r) => r case ReceiverEpochMarker() => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index faec753768895..19558e1c49319 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -79,10 +79,7 @@ class ContinuousShuffleReadSuite extends 
StreamTest { endpoint.askSync[Unit](ReceiverEpochMarker()) val iter = rdd.compute(rdd.partitions(0), ctx) - assert(iter.next().getInt(0) == 111) - assert(iter.next().getInt(0) == 222) - assert(iter.next().getInt(0) == 333) - assert(!iter.hasNext) + assert(iter.toSeq.map(_.getInt(0)) == Seq(111, 222, 333)) } test("multiple epochs") { @@ -95,13 +92,10 @@ class ContinuousShuffleReadSuite extends StreamTest { endpoint.askSync[Unit](ReceiverEpochMarker()) val firstEpoch = rdd.compute(rdd.partitions(0), ctx) - assert(firstEpoch.next().getInt(0) == 111) - assert(!firstEpoch.hasNext) + assert(firstEpoch.toSeq.map(_.getInt(0)) == Seq(111)) val secondEpoch = rdd.compute(rdd.partitions(0), ctx) - assert(secondEpoch.next().getInt(0) == 222) - assert(secondEpoch.next().getInt(0) == 333) - assert(!secondEpoch.hasNext) + assert(secondEpoch.toSeq.map(_.getInt(0)) == Seq(222, 333)) } test("empty epochs") { @@ -112,23 +106,30 @@ class ContinuousShuffleReadSuite extends StreamTest { endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) endpoint.askSync[Unit](ReceiverEpochMarker()) endpoint.askSync[Unit](ReceiverEpochMarker()) + endpoint.askSync[Unit](ReceiverEpochMarker()) assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + val thirdEpoch = rdd.compute(rdd.partitions(0), ctx) - assert(thirdEpoch.next().getInt(0) == 111) + assert(thirdEpoch.toSeq.map(_.getInt(0)) == Seq(111)) + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) } test("multiple partitions") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 5) + // Send all data before processing to ensure there's no crossover. for (p <- rdd.partitions) { val part = p.asInstanceOf[ContinuousShuffleReadPartition] - // Send index to ensure data doesn't somehow cross over between partitions. + // Send index for identification. 
part.endpoint.askSync[Unit](ReceiverRow(unsafeRow(part.index))) part.endpoint.askSync[Unit](ReceiverEpochMarker()) + } + for (p <- rdd.partitions) { + val part = p.asInstanceOf[ContinuousShuffleReadPartition] val iter = rdd.compute(part, ctx) assert(iter.next().getInt(0) == part.index) assert(!iter.hasNext) From b23b7bb17abe3cbc873a3144c56d08c88bc0c963 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 17 May 2018 17:40:55 -0700 Subject: [PATCH 10/18] take instead of poll --- .../shuffle/ContinuousShuffleReadRDD.scala | 2 +- .../shuffle/UnsafeRowReceiver.scala | 4 ++-- .../shuffle/ContinuousShuffleReadSuite.scala | 21 +++++++++++++++++++ 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index 7f5aaecabf5ad..16d7f29ea31cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -56,7 +56,7 @@ class ContinuousShuffleReadRDD(sc: SparkContext, numPartitions: Int) val receiver = split.asInstanceOf[ContinuousShuffleReadPartition].receiver new NextIterator[UnsafeRow] { - override def getNext(): UnsafeRow = receiver.poll() match { + override def getNext(): UnsafeRow = receiver.take() match { case ReceiverRow(r) => r case ReceiverEpochMarker() => finished = true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala index d6a6d978d2653..37cee999b1bed 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -60,7 +60,7 @@ private[shuffle] class UnsafeRowReceiver( } /** - * Polls until a new row is available. + * Take the next row, blocking until it's ready. */ - def poll(): UnsafeRowReceiverMessage = queue.poll() + def take(): UnsafeRowReceiverMessage = queue.take() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index 19558e1c49319..6f201186d572c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -29,6 +29,9 @@ class ContinuousShuffleReadSuite extends StreamTest { new GenericInternalRow(Array(value: Any))) } + // In this unit test, we emulate that we're in the task thread where + // ContinuousShuffleReadRDD.compute() will be evaluated. This requires a task context + // thread local to be set. 
var ctx: TaskContextImpl = _ override def beforeEach(): Unit = { @@ -135,4 +138,22 @@ class ContinuousShuffleReadSuite extends StreamTest { assert(!iter.hasNext) } } + + test("blocks waiting for new rows") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + + val readRow = new Thread { + override def run(): Unit = { + // set the non-inheritable thread local + TaskContext.setTaskContext(ctx) + val epoch = rdd.compute(rdd.partitions(0), ctx) + epoch.next().getInt(0) + } + } + + readRow.start() + eventually(timeout(streamingTimeout)) { + assert(readRow.getState == Thread.State.WAITING) + } + } } From 97f7e8ff865e6054d0d70914ce9bb51880b161f6 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 17 May 2018 17:58:44 -0700 Subject: [PATCH 11/18] add interface --- .../shuffle/ContinuousShuffleReadRDD.scala | 15 ++------- .../shuffle/ContinuousShuffleReader.scala | 31 +++++++++++++++++++ .../shuffle/UnsafeRowReceiver.scala | 19 +++++++++--- .../shuffle/ContinuousShuffleReadSuite.scala | 8 ++--- 4 files changed, 51 insertions(+), 22 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index 16d7f29ea31cc..74419a5b35fdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -27,7 +27,7 @@ import org.apache.spark.util.NextIterator case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Partition { // Initialized only on the executor, and only once even as we call compute() multiple times. 
- lazy val (receiver, endpoint) = { + lazy val (reader: ContinuousShuffleReader, endpoint) = { val env = SparkEnv.get.rpcEnv val receiver = new UnsafeRowReceiver(queueSize, env) val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID().toString}", receiver) @@ -53,17 +53,6 @@ class ContinuousShuffleReadRDD(sc: SparkContext, numPartitions: Int) } override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { - val receiver = split.asInstanceOf[ContinuousShuffleReadPartition].receiver - - new NextIterator[UnsafeRow] { - override def getNext(): UnsafeRow = receiver.take() match { - case ReceiverRow(r) => r - case ReceiverEpochMarker() => - finished = true - null - } - - override def close(): Unit = {} - } + split.asInstanceOf[ContinuousShuffleReadPartition].reader.read() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala new file mode 100644 index 0000000000000..08091bbc12893 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow + +/** + * Trait for reading from a continuous processing shuffle. + */ +trait ContinuousShuffleReader { + /** + * Returns an iterator over the incoming rows in the current epoch. Note that this iterator can + * block waiting for new rows to arrive. + */ + def read(): Iterator[UnsafeRow] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala index 37cee999b1bed..b8adbb743c6c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -23,6 +23,7 @@ import java.util.concurrent.atomic.AtomicBoolean import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.util.NextIterator /** * Messages for the UnsafeRowReceiver endpoint. Either an incoming row or an epoch marker. @@ -41,7 +42,7 @@ private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessa private[shuffle] class UnsafeRowReceiver( queueSize: Int, override val rpcEnv: RpcEnv) - extends ThreadSafeRpcEndpoint with Logging { + extends ThreadSafeRpcEndpoint with ContinuousShuffleReader with Logging { // Note that this queue will be drained from the main task thread and populated in the RPC // response thread. 
private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize) @@ -59,8 +60,16 @@ private[shuffle] class UnsafeRowReceiver( context.reply(()) } - /** - * Take the next row, blocking until it's ready. - */ - def take(): UnsafeRowReceiverMessage = queue.take() + override def read(): Iterator[UnsafeRow] = { + new NextIterator[UnsafeRow] { + override def getNext(): UnsafeRow = queue.take() match { + case ReceiverRow(r) => r + case ReceiverEpochMarker() => + finished = true + null + } + + override def close(): Unit = {} + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index 6f201186d572c..718fb0740bbe7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -54,9 +54,9 @@ class ContinuousShuffleReadSuite extends StreamTest { endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) ctx.markTaskCompleted(None) - val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].receiver + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader eventually(timeout(streamingTimeout)) { - assert(receiver.stopped.get()) + assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get()) } } @@ -67,9 +67,9 @@ class ContinuousShuffleReadSuite extends StreamTest { endpoint.askSync[Unit](ReceiverEpochMarker()) ctx.markTaskCompleted(None) - val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].receiver + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader eventually(timeout(streamingTimeout)) { - assert(receiver.stopped.get()) + assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get()) } } From 
de21b1c25a333d44c0521fe151b468e51f0bdc47 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 17 May 2018 18:02:37 -0700 Subject: [PATCH 12/18] clarify comment --- .../continuous/shuffle/ContinuousShuffleReader.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala index 08091bbc12893..42631c90ebc55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala @@ -24,8 +24,9 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow */ trait ContinuousShuffleReader { /** - * Returns an iterator over the incoming rows in the current epoch. Note that this iterator can - * block waiting for new rows to arrive. + * Returns an iterator over the incoming rows in an epoch. Implementations should block waiting + * for new rows to arrive, and end the iterator once they've received epoch markers from all + * shuffle writers. 
*/ def read(): Iterator[UnsafeRow] } From 154843d799683c5cdfc035033475f223f85f0d66 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 18 May 2018 15:41:18 -0700 Subject: [PATCH 13/18] don't use spark conf for the sql conf --- .../shuffle/ContinuousShuffleReadRDD.scala | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index 74419a5b35fdb..b0a09180ccb60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -43,13 +43,16 @@ case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Pa * shuffle output to the wrapped receivers in partitions of this RDD; each of the RDD's tasks * poll from their receiver until an epoch marker is sent. 
*/ -class ContinuousShuffleReadRDD(sc: SparkContext, numPartitions: Int) - extends RDD[UnsafeRow](sc, Nil) { - - private val queueSize = sc.conf.get(SQLConf.CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE) +class ContinuousShuffleReadRDD( + sc: SparkContext, + numPartitions: Int, + queueSize: Int = 1024) + extends RDD[UnsafeRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { - (0 until numPartitions).map(ContinuousShuffleReadPartition(_, queueSize)).toArray + (0 until numPartitions).map { partIndex => + ContinuousShuffleReadPartition(partIndex, queueSize) + }.toArray } override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { From f0262d0a9d3539bcf8fbdbb248968fd704d1e690 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 18 May 2018 15:54:27 -0700 Subject: [PATCH 14/18] end thread --- .../shuffle/ContinuousShuffleReadSuite.scala | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index 718fb0740bbe7..7706a37834d88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -151,9 +151,14 @@ class ContinuousShuffleReadSuite extends StreamTest { } } - readRow.start() - eventually(timeout(streamingTimeout)) { - assert(readRow.getState == Thread.State.WAITING) + try { + readRow.start() + eventually(timeout(streamingTimeout)) { + assert(readRow.getState == Thread.State.WAITING) + } + } finally { + readRow.interrupt() + readRow.join() } } } From 3e7a6f9d31967d9efc618c4d319a9dabd22ae4e5 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 18 May 2018 15:54:54 -0700 Subject: [PATCH 15/18] name thread --- 
.../shuffle/ContinuousShuffleReadSuite.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index 7706a37834d88..d6e6eeb127a60 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -142,7 +142,7 @@ class ContinuousShuffleReadSuite extends StreamTest { test("blocks waiting for new rows") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val readRow = new Thread { + val readRowThread = new Thread { override def run(): Unit = { // set the non-inheritable thread local TaskContext.setTaskContext(ctx) @@ -152,13 +152,13 @@ class ContinuousShuffleReadSuite extends StreamTest { } try { - readRow.start() + readRowThread.start() eventually(timeout(streamingTimeout)) { - assert(readRow.getState == Thread.State.WAITING) + assert(readRowThread.getState == Thread.State.WAITING) } } finally { - readRow.interrupt() - readRow.join() + readRowThread.interrupt() + readRowThread.join() } } } From 0a38ced23b7e1a6dfe9588ef0ebf7c071a08055d Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 18 May 2018 15:55:14 -0700 Subject: [PATCH 16/18] no toString --- .../streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index b0a09180ccb60..270b1a5c28dee 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -30,7 +30,7 @@ case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Pa lazy val (reader: ContinuousShuffleReader, endpoint) = { val env = SparkEnv.get.rpcEnv val receiver = new UnsafeRowReceiver(queueSize, env) - val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID().toString}", receiver) + val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID()}", receiver) TaskContext.get().addTaskCompletionListener { ctx => env.stop(endpoint) } From ef34e6e9817274df9378341bfb52105c591a5507 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 18 May 2018 16:00:01 -0700 Subject: [PATCH 17/18] send method --- .../shuffle/ContinuousShuffleReadSuite.scala | 57 ++++++++++++------- 1 file changed, 38 insertions(+), 19 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index d6e6eeb127a60..e71274365a759 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle import org.apache.spark.{TaskContext, TaskContextImpl} +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.types.{DataType, IntegerType} @@ -29,6 +30,10 @@ class ContinuousShuffleReadSuite extends StreamTest { new 
GenericInternalRow(Array(value: Any))) } + private def send(endpoint: RpcEndpointRef, messages: UnsafeRowReceiverMessage*) = { + messages.foreach(endpoint.askSync[Unit](_)) + } + // In this unit test, we emulate that we're in the task thread where // ContinuousShuffleReadRDD.compute() will be evaluated. This requires a task context // thread local to be set. @@ -50,8 +55,11 @@ class ContinuousShuffleReadSuite extends StreamTest { test("receiver stopped with row last") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - endpoint.askSync[Unit](ReceiverEpochMarker()) - endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) + send( + endpoint, + ReceiverEpochMarker(), + ReceiverRow(unsafeRow(111)) + ) ctx.markTaskCompleted(None) val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader @@ -76,10 +84,13 @@ class ContinuousShuffleReadSuite extends StreamTest { test("one epoch") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) - endpoint.askSync[Unit](ReceiverRow(unsafeRow(222))) - endpoint.askSync[Unit](ReceiverRow(unsafeRow(333))) - endpoint.askSync[Unit](ReceiverEpochMarker()) + send( + endpoint, + ReceiverRow(unsafeRow(111)), + ReceiverRow(unsafeRow(222)), + ReceiverRow(unsafeRow(333)), + ReceiverEpochMarker() + ) val iter = rdd.compute(rdd.partitions(0), ctx) assert(iter.toSeq.map(_.getInt(0)) == Seq(111, 222, 333)) @@ -88,11 +99,13 @@ class ContinuousShuffleReadSuite extends StreamTest { test("multiple epochs") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) - endpoint.askSync[Unit](ReceiverEpochMarker()) - 
endpoint.askSync[Unit](ReceiverRow(unsafeRow(222))) - endpoint.askSync[Unit](ReceiverRow(unsafeRow(333))) - endpoint.askSync[Unit](ReceiverEpochMarker()) + send( + endpoint, + ReceiverRow(unsafeRow(111)), + ReceiverRow(unsafeRow(222)), + ReceiverRow(unsafeRow(333)), + ReceiverEpochMarker() + ) val firstEpoch = rdd.compute(rdd.partitions(0), ctx) assert(firstEpoch.toSeq.map(_.getInt(0)) == Seq(111)) @@ -104,12 +117,15 @@ class ContinuousShuffleReadSuite extends StreamTest { test("empty epochs") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - endpoint.askSync[Unit](ReceiverEpochMarker()) - endpoint.askSync[Unit](ReceiverEpochMarker()) - endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) - endpoint.askSync[Unit](ReceiverEpochMarker()) - endpoint.askSync[Unit](ReceiverEpochMarker()) - endpoint.askSync[Unit](ReceiverEpochMarker()) + send( + endpoint, + ReceiverEpochMarker(), + ReceiverEpochMarker(), + ReceiverRow(unsafeRow(111)), + ReceiverEpochMarker(), + ReceiverEpochMarker(), + ReceiverEpochMarker() + ) assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) @@ -127,8 +143,11 @@ class ContinuousShuffleReadSuite extends StreamTest { for (p <- rdd.partitions) { val part = p.asInstanceOf[ContinuousShuffleReadPartition] // Send index for identification. 
- part.endpoint.askSync[Unit](ReceiverRow(unsafeRow(part.index))) - part.endpoint.askSync[Unit](ReceiverEpochMarker()) + send( + part.endpoint, + ReceiverRow(unsafeRow(part.index)), + ReceiverEpochMarker() + ) } for (p <- rdd.partitions) { From 00f910ea39b76a24e1e21acdf3d6a20fd7784fa9 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 18 May 2018 16:02:10 -0700 Subject: [PATCH 18/18] fix --- .../continuous/shuffle/ContinuousShuffleReadSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index e71274365a759..b25e75b3b37a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -102,6 +102,7 @@ class ContinuousShuffleReadSuite extends StreamTest { send( endpoint, ReceiverRow(unsafeRow(111)), + ReceiverEpochMarker(), ReceiverRow(unsafeRow(222)), ReceiverRow(unsafeRow(333)), ReceiverEpochMarker()