[SPARK-23092][SQL] Migrate MemoryStream to DataSourceV2 APIs #20445

Closed
wants to merge 20 commits into from
Closed
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -28,9 +28,9 @@ import org.apache.spark.sql.sources.v2.reader._
trait DataSourceReaderHolder {

/**
* The full output of the data source reader, without column pruning.
* The output of the data source reader, without column pruning.
*/
def fullOutput: Seq[AttributeReference]
@tdas (Contributor, Author) commented on Jan 31, 2018:
@cloud-fan This fixes the bug I spoke to you offline about.
The target of this PR is only master, not 2.3.x. So if you want this fix in 2.3.0, please make a separate PR accordingly.

Reviewer reply:
If this PR has to be merged to the 2.3.0 branch, does it require additional changes?

def output: Seq[Attribute]

/**
* The held data source reader.
@@ -46,7 +46,7 @@ trait DataSourceReaderHolder {
case s: SupportsPushDownFilters => s.pushedFilters().toSet
case _ => Nil
}
Seq(fullOutput, reader.getClass, reader.readSchema(), filters)
Seq(output, reader.getClass, reader.readSchema(), filters)
}

def canEqual(other: Any): Boolean
@@ -61,8 +61,4 @@ trait DataSourceReaderHolder {
override def hashCode(): Int = {
metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b)
}

lazy val output: Seq[Attribute] = reader.readSchema().map(_.name).map { name =>
fullOutput.find(_.name == name).get
}
}
@@ -17,12 +17,12 @@

package org.apache.spark.sql.execution.datasources.v2

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
import org.apache.spark.sql.sources.v2.reader._

case class DataSourceV2Relation(
fullOutput: Seq[AttributeReference],
output: Seq[Attribute],
reader: DataSourceV2Reader) extends LeafNode with DataSourceReaderHolder {

override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation]
@@ -40,8 +40,8 @@ case class DataSourceV2Relation(
* to the non-streaming relation.
*/
class StreamingDataSourceV2Relation(
fullOutput: Seq[AttributeReference],
reader: DataSourceV2Reader) extends DataSourceV2Relation(fullOutput, reader) {
output: Seq[Attribute],
reader: DataSourceV2Reader) extends DataSourceV2Relation(output, reader) {
override def isStreaming: Boolean = true
}

@@ -35,14 +35,12 @@ import org.apache.spark.sql.types.StructType
* Physical plan node for scanning data from a data source.
*/
case class DataSourceV2ScanExec(
fullOutput: Seq[AttributeReference],
override val output: Seq[Attribute],
@transient reader: DataSourceV2Reader)
extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan {

override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec]

override def producedAttributes: AttributeSet = AttributeSet(fullOutput)

override def outputPartitioning: physical.Partitioning = reader match {
case s: SupportsReportPartitioning =>
new DataSourcePartitioning(
@@ -17,10 +17,12 @@

package org.apache.spark.sql.execution.streaming

import org.apache.spark.sql.sources.v2.streaming.reader.{Offset => OffsetV2}

/**
* A simple offset for sources that produce a single linear stream of data.
*/
case class LongOffset(offset: Long) extends Offset {
case class LongOffset(offset: Long) extends OffsetV2 {

override val json = offset.toString

@@ -268,16 +268,17 @@ class MicroBatchExecution(
}
case s: MicroBatchReader =>
updateStatusMessage(s"Getting offsets from $s")
reportTimeTaken("getOffset") {
// Once v1 streaming source execution is gone, we can refactor this away.
// For now, we set the range here to get the source to infer the available end offset,
// get that offset, and then set the range again when we later execute.
s.setOffsetRange(
toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))),
Optional.empty())

(s, Some(s.getEndOffset))
reportTimeTaken("setOffsetRange") {
Reviewer comment (Contributor):
I agree that the old metric names don't make much sense anymore, but I worry about changing external-facing behavior as part of an API migration.

// Once v1 streaming source execution is gone, we can refactor this away.
// For now, we set the range here to get the source to infer the available end offset,
// get that offset, and then set the range again when we later execute.
s.setOffsetRange(
toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))),
Optional.empty())
}

val currentOffset = reportTimeTaken("getEndOffset") { s.getEndOffset() }
(s, Option(currentOffset))
}.toMap
availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get)
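
For context, the pattern this hunk implements: the driver first calls setOffsetRange with an empty end bound so the reader infers the newest offset it can serve, then reads that bound back with getEndOffset (the v1 path collapsed both steps into a single getOffset call, hence the split reportTimeTaken blocks). A minimal sketch using only the MicroBatchReader calls visible in the diff; the helper name and its arguments are hypothetical:

```scala
import java.util.Optional

import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset => OffsetV2}

// Hypothetical helper: ask a v2 reader which end offset it can currently serve.
def latestEndOffset(reader: MicroBatchReader, lastKnown: Option[OffsetV2]): Option[OffsetV2] = {
  // An empty end bound tells the reader to infer the newest available offset.
  reader.setOffsetRange(Optional.ofNullable(lastKnown.orNull), Optional.empty[OffsetV2]())
  Option(reader.getEndOffset)
}
```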

@@ -399,10 +400,14 @@ class MicroBatchExecution(
case (reader: MicroBatchReader, available)
if committedOffsets.get(reader).map(_ != available).getOrElse(true) =>
val current = committedOffsets.get(reader).map(off => reader.deserializeOffset(off.json))
val availableV2: OffsetV2 = available match {
case v1: SerializedOffset => reader.deserializeOffset(v1.json)
case v2: OffsetV2 => v2
}
reader.setOffsetRange(
toJava(current),
Optional.of(available.asInstanceOf[OffsetV2]))
logDebug(s"Retrieving data from $reader: $current -> $available")
Optional.of(availableV2))
logDebug(s"Retrieving data from $reader: $current -> $availableV2")
Some(reader ->
new StreamingDataSourceV2Relation(reader.readSchema().toAttributes, reader))
case _ => None
@@ -17,11 +17,12 @@

package org.apache.spark.sql.execution.streaming

import java.{util => ju}
import java.util.Optional
import java.util.concurrent.atomic.AtomicInteger
import javax.annotation.concurrent.GuardedBy

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import scala.util.control.NonFatal

@@ -31,7 +32,8 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, Statistics}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.sources.v2.reader.{DataReader, ReadTask}
import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset => OffsetV2}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
@@ -51,9 +53,10 @@ object MemoryStream {
* available.
*/
case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
extends Source with Logging {
extends MicroBatchReader with Logging {
protected val encoder = encoderFor[A]
protected val logicalPlan = StreamingExecutionRelation(this, sqlContext.sparkSession)
private val attributes = encoder.schema.toAttributes
protected val logicalPlan = StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession)
protected val output = logicalPlan.output

/**
@@ -66,15 +69,19 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
@GuardedBy("this")
protected var currentOffset: LongOffset = new LongOffset(-1)

@GuardedBy("this")
private var startOffset = new LongOffset(-1)

@GuardedBy("this")
private var endOffset = new LongOffset(-1)

/**
* Last offset that was discarded, or -1 if no commits have occurred. Note that the value
* -1 is used in calculations below and isn't just an arbitrary constant.
*/
@GuardedBy("this")
protected var lastOffsetCommitted : LongOffset = new LongOffset(-1)

def schema: StructType = encoder.schema

def toDS(): Dataset[A] = {
Dataset(sqlContext.sparkSession, logicalPlan)
}
@@ -89,7 +96,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)

def addData(data: TraversableOnce[A]): Offset = {
val encoded = data.toVector.map(d => encoder.toRow(d).copy())
val plan = new LocalRelation(schema.toAttributes, encoded, isStreaming = true)
val plan = new LocalRelation(attributes, encoded, isStreaming = false)
val ds = Dataset[A](sqlContext.sparkSession, plan)
logDebug(s"Adding ds: $ds")
Reviewer comment (Contributor):
Do we still need to store the batches as datasets, now that we're just collect()ing them back out in createDataReaderFactories()?

Author reply (@tdas):
Good point.
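
To make the reviewer's suggestion concrete, a hypothetical variant of addData() that stores already-collected rows instead of Datasets (not what this PR does; `collectedBatches` is an invented replacement for the `batches` buffer):

```scala
// Hypothetical alternative: keep each batch as pre-collected rows rather than as a
// Dataset, since createReadTasks() only collect()s them back out anyway.
// Assumes a field `collectedBatches: ListBuffer[Array[Row]]` in place of `batches`.
def addDataCollected(data: TraversableOnce[A]): Offset = {
  val rows: Array[Row] = sqlContext.createDataset(data.toSeq).toDF().collect()
  this.synchronized {
    currentOffset = LongOffset(currentOffset.offset + 1)
    collectedBatches += rows
    currentOffset
  }
}
```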

this.synchronized {
@@ -101,19 +108,29 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)

override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]"

override def getOffset: Option[Offset] = synchronized {
if (currentOffset.offset == -1) {
None
} else {
Some(currentOffset)
override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = {
synchronized {
startOffset = start.orElse(LongOffset(-1)).asInstanceOf[LongOffset]
endOffset = end.orElse(currentOffset).asInstanceOf[LongOffset]
}
}

override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
override def readSchema(): StructType = encoder.schema

override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong)

override def getStartOffset: OffsetV2 = synchronized {
if (startOffset.offset == -1) null else startOffset
}

override def getEndOffset: OffsetV2 = synchronized {
if (endOffset.offset == -1) null else endOffset
}

override def createReadTasks(): ju.List[ReadTask[Row]] = synchronized {
// Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal)
val startOrdinal =
start.flatMap(LongOffset.convert).getOrElse(LongOffset(-1)).offset.toInt + 1
val endOrdinal = LongOffset.convert(end).getOrElse(LongOffset(-1)).offset.toInt + 1
val startOrdinal = startOffset.offset.toInt + 1
val endOrdinal = endOffset.offset.toInt + 1

// Internal buffer only holds the batches after lastCommittedOffset.
val newBlocks = synchronized {
@@ -123,19 +140,12 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
batches.slice(sliceStart, sliceEnd)
}

if (newBlocks.isEmpty) {
return sqlContext.internalCreateDataFrame(
sqlContext.sparkContext.emptyRDD, schema, isStreaming = true)
}

logDebug(generateDebugString(newBlocks, startOrdinal, endOrdinal))

newBlocks
.map(_.toDF())
.reduceOption(_ union _)
.getOrElse {
sys.error("No data selected!")
}
newBlocks.map { ds =>
val items = ds.toDF().collect()
new MemoryStreamReadTask(items).asInstanceOf[ReadTask[Row]]
}.asJava
}

private def generateDebugString(
Expand All @@ -153,7 +163,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
}
}

override def commit(end: Offset): Unit = synchronized {
override def commit(end: OffsetV2): Unit = synchronized {
def check(newOffset: LongOffset): Unit = {
val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt

@@ -176,11 +186,31 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)

def reset(): Unit = synchronized {
batches.clear()
startOffset = LongOffset(-1)
endOffset = LongOffset(-1)
currentOffset = new LongOffset(-1)
lastOffsetCommitted = new LongOffset(-1)
}
}

class MemoryStreamReadTask(records: Array[Row]) extends ReadTask[Row] {
override def createDataReader(): DataReader[Row] = new MemoryStreamDataReader(records)
}

class MemoryStreamDataReader(records: Array[Row]) extends DataReader[Row] {
private var currentIndex = -1

override def next(): Boolean = {
// Return true as long as the new index is in the array.
currentIndex += 1
currentIndex < records.length
}

override def get(): Row = records(currentIndex)

override def close(): Unit = {}
}
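
A short usage sketch tying the pieces together: drive the migrated MemoryStream purely through the MicroBatchReader methods above and pull the rows back out via the read tasks. Only the SparkSession named `spark` is assumed; everything else comes from this diff:

```scala
import java.util.Optional
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.streaming.MemoryStream

import spark.implicits._  // provides Encoder[Int]; `spark` is an assumed active SparkSession

val stream = MemoryStream[Int](0, spark.sqlContext)
stream.addData(Seq(1, 2, 3))

// Empty bounds: start from the beginning, end at whatever has been added so far.
stream.setOffsetRange(Optional.empty(), Optional.empty())

val rows = stream.createReadTasks().asScala.flatMap { task =>
  val reader = task.createDataReader()
  val buf = ArrayBuffer[Row]()
  while (reader.next()) buf += reader.get()
  reader.close()
  buf
}
// rows now holds Row(1), Row(2), Row(3)
```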

/**
* A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit
* tests and does not provide durability.
@@ -152,7 +152,7 @@ case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends ReadTask[Row] {
}

class RateStreamBatchReader(vals: Seq[(Long, Long)]) extends DataReader[Row] {
var currentIndex = -1
private var currentIndex = -1

override def next(): Boolean = {
// Return true as long as the new index is in the seq.
@@ -492,16 +492,16 @@ class StreamSuite extends StreamTest {

val explainWithoutExtended = q.explainInternal(false)
// `extended = false` only displays the physical plan.
assert("LocalRelation".r.findAllMatchIn(explainWithoutExtended).size === 0)
assert("LocalTableScan".r.findAllMatchIn(explainWithoutExtended).size === 1)
assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithoutExtended).size === 0)
assert("DataSourceV2Scan".r.findAllMatchIn(explainWithoutExtended).size === 1)
// Use "StateStoreRestore" to verify that it does output a streaming physical plan
assert(explainWithoutExtended.contains("StateStoreRestore"))

val explainWithExtended = q.explainInternal(true)
// `extended = true` displays 3 logical plans (Parsed/Analyzed/Optimized) and 1 physical
// plan.
assert("LocalRelation".r.findAllMatchIn(explainWithExtended).size === 3)
assert("LocalTableScan".r.findAllMatchIn(explainWithExtended).size === 1)
assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithExtended).size === 3)
assert("DataSourceV2Scan".r.findAllMatchIn(explainWithExtended).size === 1)
// Use "StateStoreRestore" to verify that it does output a streaming physical plan
assert(explainWithExtended.contains("StateStoreRestore"))
} finally {
@@ -120,7 +120,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
case class AddDataMemory[A](source: MemoryStream[A], data: Seq[A]) extends AddData {
override def toString: String = s"AddData to $source: ${data.mkString(",")}"

override def addData(query: Option[StreamExecution]): (Source, Offset) = {
override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = {
(source, source.addData(data))
}
}
@@ -33,6 +33,7 @@ import org.apache.spark.scheduler._
import org.apache.spark.sql.{Encoder, SparkSession}
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.v2.streaming.reader.{Offset => OffsetV2}
import org.apache.spark.sql.streaming.StreamingQueryListener._
import org.apache.spark.sql.streaming.util.StreamManualClock
import org.apache.spark.util.JsonProtocol
@@ -298,9 +299,9 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
try {
val input = new MemoryStream[Int](0, sqlContext) {
@volatile var numTriggers = 0
override def getOffset: Option[Offset] = {
override def getEndOffset: OffsetV2 = {
numTriggers += 1
super.getOffset
super.getEndOffset
}
}
val clock = new StreamManualClock()