[SPARK-23687][SS] Add a memory source for continuous processing.
## What changes were proposed in this pull request?

Add a memory source for continuous processing.

Note that only one of the ContinuousSuite tests is migrated to minimize the diff here. I'll submit a second PR for SPARK-23688 to change the rest and get rid of waitForRateSourceTriggers.

## How was this patch tested?

unit test
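
For illustration only (not code in this patch): a hedged sketch of how the new source might be driven with a continuous trigger in a test. The `spark` session, the query name, and the trigger interval below are assumptions, not part of this diff.

```scala
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream
import org.apache.spark.sql.streaming.Trigger

// Assumes an active SparkSession named `spark`, as in StreamTest-based suites.
implicit val sqlContext: SQLContext = spark.sqlContext
import spark.implicits._

// Driver-side source backed by in-memory partition lists.
val input = ContinuousMemoryStream[Int]

val query = input.toDF()
  .writeStream
  .format("memory")
  .queryName("continuous_memory_sketch")   // illustrative name
  .trigger(Trigger.Continuous("1 second")) // continuous processing trigger
  .start()

// Records are spread across the source's partitions and picked up by the
// long-running continuous readers.
input.addData(1, 2, 3)

// ... assertions against the in-memory sink table would go here ...
query.stop()
```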

Author: Jose Torres <torres.joseph.f+github@gmail.com>

Closes #20828 from jose-torres/continuousMemory.
jose-torres authored and tdas committed Apr 17, 2018
1 parent 14844a6 commit 1cc66a0
Showing 5 changed files with 266 additions and 44 deletions.
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2}
import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _}
import org.apache.spark.sql.sources.v2
import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport}
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset}
import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger}
@@ -317,8 +318,10 @@ class ContinuousExecution(
synchronized {
if (queryExecutionThread.isAlive) {
commitLog.add(epoch)
val offset = offsetLog.get(epoch).get.offsets(0).get
val offset =
continuousSources(0).deserializeOffset(offsetLog.get(epoch).get.offsets(0).get.json)
committedOffsets ++= Seq(continuousSources(0) -> offset)
continuousSources(0).commit(offset.asInstanceOf[v2.reader.streaming.Offset])
} else {
return
}
@@ -24,17 +24,19 @@ import javax.annotation.concurrent.GuardedBy

import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import scala.reflect.ClassTag
import scala.util.control.NonFatal

import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow}
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.streaming.{OutputMode, Trigger}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils

@@ -47,16 +49,43 @@ object MemoryStream {
new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext)
}

/**
* A base class for memory stream implementations. Supports adding data and resetting.
*/
abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends BaseStreamingSource {
protected val encoder = encoderFor[A]
protected val attributes = encoder.schema.toAttributes

def toDS(): Dataset[A] = {
Dataset[A](sqlContext.sparkSession, logicalPlan)
}

def toDF(): DataFrame = {
Dataset.ofRows(sqlContext.sparkSession, logicalPlan)
}

def addData(data: A*): Offset = {
addData(data.toTraversable)
}

def readSchema(): StructType = encoder.schema

protected def logicalPlan: LogicalPlan

def addData(data: TraversableOnce[A]): Offset
}

/**
* A [[Source]] that produces values stored in memory as they are added by the user. This [[Source]]
* is intended for use in unit tests as it can only replay data when the object is still
* available.
*/
case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
extends MicroBatchReader with SupportsScanUnsafeRow with Logging {
protected val encoder = encoderFor[A]
private val attributes = encoder.schema.toAttributes
protected val logicalPlan = StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession)
extends MemoryStreamBase[A](sqlContext)
with MicroBatchReader with SupportsScanUnsafeRow with Logging {

protected val logicalPlan: LogicalPlan =
StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession)
protected val output = logicalPlan.output

/**
@@ -70,7 +99,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
protected var currentOffset: LongOffset = new LongOffset(-1)

@GuardedBy("this")
private var startOffset = new LongOffset(-1)
protected var startOffset = new LongOffset(-1)

@GuardedBy("this")
private var endOffset = new LongOffset(-1)
@@ -82,18 +111,6 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
@GuardedBy("this")
protected var lastOffsetCommitted : LongOffset = new LongOffset(-1)

def toDS(): Dataset[A] = {
Dataset(sqlContext.sparkSession, logicalPlan)
}

def toDF(): DataFrame = {
Dataset.ofRows(sqlContext.sparkSession, logicalPlan)
}

def addData(data: A*): Offset = {
addData(data.toTraversable)
}

def addData(data: TraversableOnce[A]): Offset = {
val objects = data.toSeq
val rows = objects.iterator.map(d => encoder.toRow(d).copy().asInstanceOf[UnsafeRow]).toArray
@@ -114,8 +131,6 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
}
}

override def readSchema(): StructType = encoder.schema

override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong)

override def getStartOffset: OffsetV2 = synchronized {
@@ -0,0 +1,211 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming.sources

import java.{util => ju}
import java.util.Optional
import java.util.concurrent.atomic.AtomicInteger
import javax.annotation.concurrent.GuardedBy

import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer

import org.json4s.NoTypeHints
import org.json4s.jackson.Serialization

import org.apache.spark.SparkEnv
import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.sql.{Encoder, Row, SQLContext}
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream.GetRecord
import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions}
import org.apache.spark.sql.sources.v2.reader.DataReaderFactory
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.RpcUtils

/**
* The overall strategy here is:
* * ContinuousMemoryStream maintains a list of records for each partition. addData() will
* distribute records evenly-ish across partitions.
* * RecordEndpoint is set up as an endpoint for executor-side
* ContinuousMemoryStreamDataReader instances to poll. It returns the record at the specified
* offset within the list, or None if that offset doesn't yet have a record.
*/
class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport {
private implicit val formats = Serialization.formats(NoTypeHints)
private val NUM_PARTITIONS = 2

protected val logicalPlan =
StreamingRelationV2(this, "memory", Map(), attributes, None)(sqlContext.sparkSession)

// ContinuousReader implementation

@GuardedBy("this")
private val records = Seq.fill(NUM_PARTITIONS)(new ListBuffer[A])

@GuardedBy("this")
private var startOffset: ContinuousMemoryStreamOffset = _

private val recordEndpoint = new RecordEndpoint()
@volatile private var endpointRef: RpcEndpointRef = _

def addData(data: TraversableOnce[A]): Offset = synchronized {
// Distribute data evenly among partition lists.
data.toSeq.zipWithIndex.map {
case (item, index) => records(index % NUM_PARTITIONS) += item
}

// The new target offset is the offset where all records in all partitions have been processed.
ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, records(i).size)).toMap)
}

override def setStartOffset(start: Optional[Offset]): Unit = synchronized {
// Inferred initial offset is position 0 in each partition.
startOffset = start.orElse {
ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, 0)).toMap)
}.asInstanceOf[ContinuousMemoryStreamOffset]
}

override def getStartOffset: Offset = synchronized {
startOffset
}

override def deserializeOffset(json: String): ContinuousMemoryStreamOffset = {
ContinuousMemoryStreamOffset(Serialization.read[Map[Int, Int]](json))
}

override def mergeOffsets(offsets: Array[PartitionOffset]): ContinuousMemoryStreamOffset = {
ContinuousMemoryStreamOffset(
offsets.map {
case ContinuousMemoryStreamPartitionOffset(part, num) => (part, num)
}.toMap
)
}

override def createDataReaderFactories(): ju.List[DataReaderFactory[Row]] = {
synchronized {
val endpointName = s"ContinuousMemoryStreamRecordEndpoint-${java.util.UUID.randomUUID()}-$id"
endpointRef =
recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint)

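// One reader factory per partition, each starting at that partition's position in startOffset.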
startOffset.partitionNums.map {
case (part, index) =>
new ContinuousMemoryStreamDataReaderFactory(
endpointName, part, index): DataReaderFactory[Row]
}.toList.asJava
}
}

override def stop(): Unit = {
if (endpointRef != null) recordEndpoint.rpcEnv.stop(endpointRef)
}

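// Commit is a no-op: records stay in the in-memory partition lists.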
override def commit(end: Offset): Unit = {}

// ContinuousReadSupport implementation
// This is necessary because of how StreamTest finds the source for AddDataMemory steps.
def createContinuousReader(
schema: Optional[StructType],
checkpointLocation: String,
options: DataSourceOptions): ContinuousReader = {
this
}

/**
* Endpoint for executors to poll for records.
*/
private class RecordEndpoint extends ThreadSafeRpcEndpoint {
override val rpcEnv: RpcEnv = SparkEnv.get.rpcEnv

override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case GetRecord(ContinuousMemoryStreamPartitionOffset(part, index)) =>
ContinuousMemoryStream.this.synchronized {
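// Reply with the row at the requested offset, or None if that position has no record yet.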
val buf = records(part)
val record = if (buf.size <= index) None else Some(buf(index))

context.reply(record.map(Row(_)))
}
}
}
}

object ContinuousMemoryStream {
case class GetRecord(offset: ContinuousMemoryStreamPartitionOffset)
protected val memoryStreamId = new AtomicInteger(0)

def apply[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] =
new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext)
}

/**
* Data reader factory for continuous memory stream.
*/
class ContinuousMemoryStreamDataReaderFactory(
driverEndpointName: String,
partition: Int,
startOffset: Int) extends DataReaderFactory[Row] {
override def createDataReader: ContinuousMemoryStreamDataReader =
new ContinuousMemoryStreamDataReader(driverEndpointName, partition, startOffset)
}

/**
* Data reader for continuous memory stream.
*
* Polls the driver endpoint for new records.
*/
class ContinuousMemoryStreamDataReader(
driverEndpointName: String,
partition: Int,
startOffset: Int) extends ContinuousDataReader[Row] {
private val endpoint = RpcUtils.makeDriverRef(
driverEndpointName,
SparkEnv.get.conf,
SparkEnv.get.rpcEnv)

private var currentOffset = startOffset
private var current: Option[Row] = None

override def next(): Boolean = {
current = None
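// Poll the driver-side RecordEndpoint every 10 ms until a record exists at currentOffset.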
while (current.isEmpty) {
Thread.sleep(10)
current = endpoint.askSync[Option[Row]](
GetRecord(ContinuousMemoryStreamPartitionOffset(partition, currentOffset)))
}
currentOffset += 1
true
}

override def get(): Row = current.get

override def close(): Unit = {}

override def getOffset: ContinuousMemoryStreamPartitionOffset =
ContinuousMemoryStreamPartitionOffset(partition, currentOffset)
}

case class ContinuousMemoryStreamOffset(partitionNums: Map[Int, Int])
extends Offset {
private implicit val formats = Serialization.formats(NoTypeHints)
override def json(): String = Serialization.write(partitionNums)
}

case class ContinuousMemoryStreamPartitionOffset(partition: Int, numProcessed: Int)
extends PartitionOffset
@@ -99,7 +99,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
* been processed.
*/
object AddData {
def apply[A](source: MemoryStream[A], data: A*): AddDataMemory[A] =
def apply[A](source: MemoryStreamBase[A], data: A*): AddDataMemory[A] =
AddDataMemory(source, data)
}

@@ -131,7 +131,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
def runAction(): Unit
}

case class AddDataMemory[A](source: MemoryStream[A], data: Seq[A]) extends AddData {
case class AddDataMemory[A](source: MemoryStreamBase[A], data: Seq[A]) extends AddData {
override def toString: String = s"AddData to $source: ${data.mkString(",")}"

override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = {