Store added data as rows not datasets
tdas committed Feb 2, 2018
1 parent 3f50f33 commit c713048
Showing 2 changed files with 42 additions and 39 deletions.
@@ -29,10 +29,10 @@ import scala.util.control.NonFatal
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.encoders.encoderFor
- import org.apache.spark.sql.catalyst.expressions.Attribute
- import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, Statistics}
+ import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
+ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
- import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory}
+ import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow}
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
@@ -53,7 +53,7 @@ object MemoryStream {
* available.
*/
case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
- extends MicroBatchReader with Logging {
+ extends MicroBatchReader with SupportsScanUnsafeRow with Logging {
protected val encoder = encoderFor[A]
private val attributes = encoder.schema.toAttributes
protected val logicalPlan = StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession)
@@ -64,7 +64,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
* Stored in a ListBuffer to facilitate removing committed batches.
*/
@GuardedBy("this")
- protected val batches = new ListBuffer[Dataset[A]]
+ protected val batches = new ListBuffer[Array[UnsafeRow]]

@GuardedBy("this")
protected var currentOffset: LongOffset = new LongOffset(-1)
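Note: the change is internal to MemoryStream; callers still pass plain objects to addData, and only the buffering switches from Dataset[A] to Array[UnsafeRow]. A minimal, REPL-style usage sketch for orientation (the local master, query name, and map transformation below are illustrative, not part of this commit):

import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.execution.streaming.MemoryStream

val spark = SparkSession.builder().master("local[2]").appName("memory-stream-demo").getOrCreate()
import spark.implicits._
implicit val sqlContext: SQLContext = spark.sqlContext

// Feed data into the in-memory streaming source and run a trivial query over it.
val input = MemoryStream[Int]
val query = input.toDS()
  .map(_ * 2)
  .writeStream
  .format("memory")        // in-memory sink, handy for tests
  .queryName("doubled")
  .outputMode("append")
  .start()

input.addData(1, 2, 3)     // after this patch, buffered internally as Array[UnsafeRow]
query.processAllAvailable()
spark.table("doubled").show()
query.stop()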
@@ -95,13 +95,12 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
}

def addData(data: TraversableOnce[A]): Offset = {
- val encoded = data.toVector.map(d => encoder.toRow(d).copy())
- val plan = new LocalRelation(attributes, encoded, isStreaming = false)
- val ds = Dataset[A](sqlContext.sparkSession, plan)
- logDebug(s"Adding ds: $ds")
+ val objects = data.toSeq
+ val rows = objects.iterator.map(d => encoder.toRow(d).copy().asInstanceOf[UnsafeRow]).toArray
+ logDebug(s"Adding: $objects")
this.synchronized {
currentOffset = currentOffset + 1
- batches += ds
+ batches += rows
currentOffset
}
}
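A small sketch of the encoding step the new addData performs, assuming the Spark 2.3-era ExpressionEncoder API (toRow/fromRow/resolveAndBind). Each row is copied because toRow reuses a single buffer across calls, and resolveAndBind is needed before decoding outside a query plan. The tuple type and sample values are illustrative:

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.UnsafeRow

// Encode plain objects to UnsafeRows, copying each row because toRow reuses its buffer.
val encoder = ExpressionEncoder[(Int, String)]()
val data = Seq((1, "a"), (2, "b"), (3, "c"))
val rows: Array[UnsafeRow] =
  data.iterator.map(d => encoder.toRow(d).copy().asInstanceOf[UnsafeRow]).toArray

// Decode them back, as generateDebugString does when logging a batch.
val boundEncoder = encoder.resolveAndBind()
val decoded = rows.map(boundEncoder.fromRow)
assert(decoded.toSeq == data)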
@@ -127,36 +126,38 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
if (endOffset.offset == -1) null else endOffset
}

- override def createDataReaderFactories(): ju.List[DataReaderFactory[Row]] = synchronized {
- // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal)
- val startOrdinal = startOffset.offset.toInt + 1
- val endOrdinal = endOffset.offset.toInt + 1
-
- // Internal buffer only holds the batches after lastCommittedOffset.
- val newBlocks = synchronized {
- val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1
- val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1
- assert(sliceStart <= sliceEnd, s"sliceStart: $sliceStart sliceEnd: $sliceEnd")
- batches.slice(sliceStart, sliceEnd)
- }
+ override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = {
+ synchronized {
+ // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal)
+ val startOrdinal = startOffset.offset.toInt + 1
+ val endOrdinal = endOffset.offset.toInt + 1
+
+ // Internal buffer only holds the batches after lastCommittedOffset.
+ val newBlocks = synchronized {
+ val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1
+ val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1
+ assert(sliceStart <= sliceEnd, s"sliceStart: $sliceStart sliceEnd: $sliceEnd")
+ batches.slice(sliceStart, sliceEnd)
+ }

- logDebug(generateDebugString(newBlocks, startOrdinal, endOrdinal))
+ logDebug(generateDebugString(newBlocks, startOrdinal, endOrdinal))

- newBlocks.map { ds =>
- new MemoryStreamDataReaderFactory(ds.toDF().collect()).asInstanceOf[DataReaderFactory[Row]]
- }.asJava
+ newBlocks.map { block =>
+ new MemoryStreamDataReaderFactory(block).asInstanceOf[DataReaderFactory[UnsafeRow]]
+ }.asJava
+ }
}
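The ordinal and slice arithmetic is carried over unchanged; here is a self-contained sketch of it with made-up offsets, showing which buffered blocks a batch reads:

// Reproduces the index math in createUnsafeRowReaderFactories with concrete numbers.
object SliceArithmeticExample extends App {
  val lastCommittedOffset = 1L   // blocks for offsets 0 and 1 were already committed and dropped
  val startOffset = 1L           // exclusive start of the range to read
  val endOffset = 3L             // inclusive end of the range to read

  val startOrdinal = startOffset.toInt + 1                         // 2
  val endOrdinal   = endOffset.toInt + 1                           // 4
  val sliceStart   = startOrdinal - lastCommittedOffset.toInt - 1  // 0
  val sliceEnd     = endOrdinal - lastCommittedOffset.toInt - 1    // 2

  // The buffer only holds uncommitted blocks, i.e. those added as offsets 2, 3, 4, ...
  val batches = Vector("rows-for-offset-2", "rows-for-offset-3", "rows-for-offset-4")
  println(batches.slice(sliceStart, sliceEnd))  // Vector(rows-for-offset-2, rows-for-offset-3)
}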

private def generateDebugString(
- blocks: TraversableOnce[Dataset[A]],
+ blocks: Iterable[Array[UnsafeRow]],
startOrdinal: Int,
endOrdinal: Int): String = {
val originalUnsupportedCheck =
sqlContext.getConf("spark.sql.streaming.unsupportedOperationCheck")
try {
sqlContext.setConf("spark.sql.streaming.unsupportedOperationCheck", "false")
s"MemoryBatch [$startOrdinal, $endOrdinal]: " +
s"${blocks.flatMap(_.collect()).mkString(", ")}"
s"${blocks.flatten.map(row => encoder.fromRow(row)).mkString(", ")}"
} finally {
sqlContext.setConf("spark.sql.streaming.unsupportedOperationCheck", originalUnsupportedCheck)
}
@@ -193,9 +194,10 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
}


- class MemoryStreamDataReaderFactory(records: Array[Row]) extends DataReaderFactory[Row] {
- override def createDataReader(): DataReader[Row] = {
- new DataReader[Row] {
+ class MemoryStreamDataReaderFactory(records: Array[UnsafeRow])
+ extends DataReaderFactory[UnsafeRow] {
+ override def createDataReader(): DataReader[UnsafeRow] = {
+ new DataReader[UnsafeRow] {
private var currentIndex = -1

override def next(): Boolean = {
@@ -204,7 +206,7 @@ class MemoryStreamDataReaderFactory(records: Array[Row]) extends DataReaderFactory[Row] {
currentIndex < records.length
}

- override def get(): Row = records(currentIndex)
+ override def get(): UnsafeRow = records(currentIndex)

override def close(): Unit = {}
}
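For reference, a sketch of how a DataReader like the one above is consumed: advance with next(), read with get(), close when done. Only DataReader and those three methods come from the v2 reader API (org.apache.spark.sql.sources.v2.reader in Spark 2.3); the drain helper is illustrative:

import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.sources.v2.reader.DataReader

// Pull every record out of a DataReader, mirroring the next()/get()/close() contract above.
def drain[T](reader: DataReader[T]): Seq[T] = {
  val out = ArrayBuffer.empty[T]
  try {
    while (reader.next()) {
      out += reader.get()
    }
  } finally {
    reader.close()
  }
  out
}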
@@ -29,7 +29,8 @@ import org.scalatest.mockito.MockitoSugar

import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
- import org.apache.spark.sql.{DataFrame, Dataset, Row}
+ import org.apache.spark.sql.{DataFrame, Dataset}
+ import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -220,16 +221,16 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi

// getEndOffset should take 100 ms the first time it is called after data is added
override def getEndOffset(): OffsetV2 = synchronized {
- if (currentOffset.offset != -1) { // no data available
- clock.waitTillTime(1150)
- }
+ if (dataAdded) clock.waitTillTime(1150)
super.getEndOffset()
}

// getBatch should take 100 ms the first time it is called
- override def createDataReaderFactories(): ju.List[DataReaderFactory[Row]] = synchronized {
- clock.waitTillTime(1350)
- super.createDataReaderFactories()
+ override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = {
+ synchronized {
+ clock.waitTillTime(1350)
+ super.createUnsafeRowReaderFactories()
+ }
}
}
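The test keeps the same manual-clock gating, just routed through createUnsafeRowReaderFactories. A sketch of how that gating behaves, assuming Spark's internal ManualClock (it is private[spark], so this only compiles inside Spark's own sources); the thread wiring here is illustrative:

import org.apache.spark.util.ManualClock

val clock = new ManualClock(1000)   // start the clock at t = 1000 ms

// The streaming thread blocks inside waitTillTime until the test advances the clock.
val streamingThread = new Thread {
  override def run(): Unit = {
    clock.waitTillTime(1150)        // returns once the clock reaches 1150 ms
    println(s"released at t=${clock.getTimeMillis()} ms")
  }
}
streamingThread.start()

clock.advance(150)                  // test thread: 1000 ms -> 1150 ms, releasing the waiter
streamingThread.join()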

