Commit 97f7e8f

add interface
jose-torres committed May 18, 2018
1 parent b23b7bb commit 97f7e8f
Showing 4 changed files with 51 additions and 22 deletions.

ContinuousShuffleReadRDD.scala
@@ -27,7 +27,7 @@ import org.apache.spark.util.NextIterator
 
 case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Partition {
   // Initialized only on the executor, and only once even as we call compute() multiple times.
-  lazy val (receiver, endpoint) = {
+  lazy val (reader: ContinuousShuffleReader, endpoint) = {
     val env = SparkEnv.get.rpcEnv
     val receiver = new UnsafeRowReceiver(queueSize, env)
     val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID().toString}", receiver)
@@ -53,17 +53,6 @@ class ContinuousShuffleReadRDD(sc: SparkContext, numPartitions: Int)
   }
 
   override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = {
-    val receiver = split.asInstanceOf[ContinuousShuffleReadPartition].receiver
-
-    new NextIterator[UnsafeRow] {
-      override def getNext(): UnsafeRow = receiver.take() match {
-        case ReceiverRow(r) => r
-        case ReceiverEpochMarker() =>
-          finished = true
-          null
-      }
-
-      override def close(): Unit = {}
-    }
+    split.asInstanceOf[ContinuousShuffleReadPartition].reader.read()
   }
 }
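
With the interface in place, compute() simply casts the partition and delegates to its reader; the blocking-iterator logic moves behind the trait. A minimal caller-side sketch, assuming a live SparkEnv on the executor so the partition's lazy reader/endpoint pair can initialize (drainEpoch is an illustrative name, not part of the commit):

    import org.apache.spark.TaskContext
    import org.apache.spark.sql.catalyst.expressions.UnsafeRow

    // Each compute() call builds a fresh iterator over the current epoch;
    // it ends when the partition's reader takes an epoch marker.
    def drainEpoch(rdd: ContinuousShuffleReadRDD, ctx: TaskContext): Seq[UnsafeRow] =
      rdd.compute(rdd.partitions(0), ctx).toSeq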

ContinuousShuffleReader.scala (new file)
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous.shuffle
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+
+/**
+ * Trait for reading from a continuous processing shuffle.
+ */
+trait ContinuousShuffleReader {
+  /**
+   * Returns an iterator over the incoming rows in the current epoch. Note that this iterator can
+   * block waiting for new rows to arrive.
+   */
+  def read(): Iterator[UnsafeRow]
+}
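
The trait keeps ContinuousShuffleReadRDD agnostic of the transport; UnsafeRowReceiver (below) is the only implementation this commit adds. A hypothetical in-memory implementation, purely to illustrate the contract:

    // Hypothetical: replays one fixed epoch of rows without ever blocking.
    // A real reader may block inside the returned iterator while rows are in flight.
    class InMemoryShuffleReader(epoch: Seq[UnsafeRow]) extends ContinuousShuffleReader {
      override def read(): Iterator[UnsafeRow] = epoch.iterator
    }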

UnsafeRowReceiver.scala
@@ -23,6 +23,7 @@ import java.util.concurrent.atomic.AtomicBoolean
 import org.apache.spark.internal.Logging
 import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.util.NextIterator
 
 /**
  * Messages for the UnsafeRowReceiver endpoint. Either an incoming row or an epoch marker.
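
A writer drives this two-message protocol over RPC: any number of rows, then a marker to end the epoch. A sketch mirroring the suite's calls below (unsafeRow is the suite's test helper; endpoint is the ref returned by setupEndpoint in the partition above):

    // Producer side: two rows in the current epoch, then close it.
    endpoint.askSync[Unit](ReceiverRow(unsafeRow(1)))
    endpoint.askSync[Unit](ReceiverRow(unsafeRow(2)))
    endpoint.askSync[Unit](ReceiverEpochMarker())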
@@ -41,7 +42,7 @@ private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessage
 private[shuffle] class UnsafeRowReceiver(
     queueSize: Int,
     override val rpcEnv: RpcEnv)
-  extends ThreadSafeRpcEndpoint with Logging {
+  extends ThreadSafeRpcEndpoint with ContinuousShuffleReader with Logging {
   // Note that this queue will be drained from the main task thread and populated in the RPC
   // response thread.
   private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize)
@@ -59,8 +60,16 @@ private[shuffle] class UnsafeRowReceiver(
     context.reply(())
   }
 
-  /**
-   * Take the next row, blocking until it's ready.
-   */
-  def take(): UnsafeRowReceiverMessage = queue.take()
+  override def read(): Iterator[UnsafeRow] = {
+    new NextIterator[UnsafeRow] {
+      override def getNext(): UnsafeRow = queue.take() match {
+        case ReceiverRow(r) => r
+        case ReceiverEpochMarker() =>
+          finished = true
+          null
+      }
+
+      override def close(): Unit = {}
+    }
+  }
 }
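
read() wraps the queue in Spark's NextIterator: getNext() blocks in queue.take(), and an epoch marker sets finished so the iterator ends without ever emitting the null. A hedged consumer-side sketch (receiver is an UnsafeRowReceiver as constructed in the partition above):

    // Epochs are consumed back to back from the same queue, so each read()
    // resumes where the previous epoch's marker left off.
    val firstEpoch = receiver.read().toSeq   // blocks until a marker arrives
    val secondEpoch = receiver.read().toSeq  // rows enqueued after that marker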

ContinuousShuffleReadSuite.scala
@@ -54,9 +54,9 @@ class ContinuousShuffleReadSuite extends StreamTest {
     endpoint.askSync[Unit](ReceiverRow(unsafeRow(111)))
 
     ctx.markTaskCompleted(None)
-    val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].receiver
+    val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader
     eventually(timeout(streamingTimeout)) {
-      assert(receiver.stopped.get())
+      assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get())
     }
   }
 
@@ -67,9 +67,9 @@ class ContinuousShuffleReadSuite extends StreamTest {
     endpoint.askSync[Unit](ReceiverEpochMarker())
 
     ctx.markTaskCompleted(None)
-    val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].receiver
+    val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader
     eventually(timeout(streamingTimeout)) {
-      assert(receiver.stopped.get())
+      assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get())
     }
   }
 
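The suite now reaches the receiver through the partition's reader field and downcasts only to check its stopped flag. A hedged sketch of reading rows through the new interface in the same style (unsafeRow as above; assumes the helper builds a single-int row, so getInt(0) recovers the value):

    val reader = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader
    endpoint.askSync[Unit](ReceiverRow(unsafeRow(111)))
    endpoint.askSync[Unit](ReceiverEpochMarker())
    assert(reader.read().map(_.getInt(0)).toSeq == Seq(111))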