-
Notifications
You must be signed in to change notification settings - Fork 28.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-23092][SQL] Migrate MemoryStream to DataSourceV2 APIs #20445
Changes from 13 commits
7c09b37
78c50f8
2777b5b
50a541b
fd61724
7a0b564
a81c2ec
1a4f410
083e93c
a817c8d
35b8854
e66d809
5adf1fe
478ad17
6389d80
3f50f33
c713048
1204755
f0ce5df
c3508e9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -270,16 +270,17 @@ class MicroBatchExecution( | |
} | ||
case s: MicroBatchReader => | ||
updateStatusMessage(s"Getting offsets from $s") | ||
reportTimeTaken("getOffset") { | ||
// Once v1 streaming source execution is gone, we can refactor this away. | ||
// For now, we set the range here to get the source to infer the available end offset, | ||
// get that offset, and then set the range again when we later execute. | ||
s.setOffsetRange( | ||
toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))), | ||
Optional.empty()) | ||
|
||
(s, Some(s.getEndOffset)) | ||
reportTimeTaken("setOffsetRange") { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I agree that the old metric names don't make much sense anymore, but I worry about changing external-facing behavior as part of an API migration. |
||
// Once v1 streaming source execution is gone, we can refactor this away. | ||
// For now, we set the range here to get the source to infer the available end offset, | ||
// get that offset, and then set the range again when we later execute. | ||
s.setOffsetRange( | ||
toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))), | ||
Optional.empty()) | ||
} | ||
|
||
val currentOffset = reportTimeTaken("getEndOffset") { s.getEndOffset() } | ||
(s, Option(currentOffset)) | ||
}.toMap | ||
availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get) | ||
|
||
|
@@ -401,10 +402,14 @@ class MicroBatchExecution( | |
case (reader: MicroBatchReader, available) | ||
if committedOffsets.get(reader).map(_ != available).getOrElse(true) => | ||
val current = committedOffsets.get(reader).map(off => reader.deserializeOffset(off.json)) | ||
val availableV2: OffsetV2 = available match { | ||
case v1: SerializedOffset => reader.deserializeOffset(v1.json) | ||
case v2: OffsetV2 => v2 | ||
} | ||
reader.setOffsetRange( | ||
toJava(current), | ||
Optional.of(available.asInstanceOf[OffsetV2])) | ||
logDebug(s"Retrieving data from $reader: $current -> $available") | ||
Optional.of(availableV2)) | ||
logDebug(s"Retrieving data from $reader: $current -> $availableV2") | ||
Some(reader -> | ||
new StreamingDataSourceV2Relation(reader.readSchema().toAttributes, reader)) | ||
case _ => None | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,11 +17,12 @@ | |
|
||
package org.apache.spark.sql.execution.streaming | ||
|
||
import java.{util => ju} | ||
import java.util.Optional | ||
import java.util.concurrent.atomic.AtomicInteger | ||
import javax.annotation.concurrent.GuardedBy | ||
|
||
import scala.collection.JavaConverters._ | ||
import scala.collection.mutable | ||
import scala.collection.mutable.{ArrayBuffer, ListBuffer} | ||
import scala.util.control.NonFatal | ||
|
||
|
@@ -31,7 +32,8 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor | |
import org.apache.spark.sql.catalyst.expressions.Attribute | ||
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, Statistics} | ||
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ | ||
import org.apache.spark.sql.execution.SQLExecution | ||
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory} | ||
import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset => OffsetV2} | ||
import org.apache.spark.sql.streaming.OutputMode | ||
import org.apache.spark.sql.types.StructType | ||
import org.apache.spark.util.Utils | ||
|
@@ -51,9 +53,10 @@ object MemoryStream { | |
* available. | ||
*/ | ||
case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) | ||
extends Source with Logging { | ||
extends MicroBatchReader with Logging { | ||
protected val encoder = encoderFor[A] | ||
protected val logicalPlan = StreamingExecutionRelation(this, sqlContext.sparkSession) | ||
private val attributes = encoder.schema.toAttributes | ||
protected val logicalPlan = StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession) | ||
protected val output = logicalPlan.output | ||
|
||
/** | ||
|
@@ -66,15 +69,19 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) | |
@GuardedBy("this") | ||
protected var currentOffset: LongOffset = new LongOffset(-1) | ||
|
||
@GuardedBy("this") | ||
private var startOffset = new LongOffset(-1) | ||
|
||
@GuardedBy("this") | ||
private var endOffset = new LongOffset(-1) | ||
|
||
/** | ||
* Last offset that was discarded, or -1 if no commits have occurred. Note that the value | ||
* -1 is used in calculations below and isn't just an arbitrary constant. | ||
*/ | ||
@GuardedBy("this") | ||
protected var lastOffsetCommitted : LongOffset = new LongOffset(-1) | ||
|
||
def schema: StructType = encoder.schema | ||
|
||
def toDS(): Dataset[A] = { | ||
Dataset(sqlContext.sparkSession, logicalPlan) | ||
} | ||
|
@@ -89,7 +96,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) | |
|
||
def addData(data: TraversableOnce[A]): Offset = { | ||
val encoded = data.toVector.map(d => encoder.toRow(d).copy()) | ||
val plan = new LocalRelation(schema.toAttributes, encoded, isStreaming = true) | ||
val plan = new LocalRelation(attributes, encoded, isStreaming = false) | ||
val ds = Dataset[A](sqlContext.sparkSession, plan) | ||
logDebug(s"Adding ds: $ds") | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do we still need to store the batches as datasets, now that we're just collect()ing them back out in createDataReaderFactories()? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good point. |
||
this.synchronized { | ||
|
@@ -101,19 +108,29 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) | |
|
||
override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]" | ||
|
||
override def getOffset: Option[Offset] = synchronized { | ||
if (currentOffset.offset == -1) { | ||
None | ||
} else { | ||
Some(currentOffset) | ||
override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = { | ||
synchronized { | ||
startOffset = start.orElse(LongOffset(-1)).asInstanceOf[LongOffset] | ||
endOffset = end.orElse(currentOffset).asInstanceOf[LongOffset] | ||
} | ||
} | ||
|
||
override def getBatch(start: Option[Offset], end: Offset): DataFrame = { | ||
override def readSchema(): StructType = encoder.schema | ||
|
||
override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong) | ||
|
||
override def getStartOffset: OffsetV2 = synchronized { | ||
if (startOffset.offset == -1) null else startOffset | ||
} | ||
|
||
override def getEndOffset: OffsetV2 = synchronized { | ||
if (endOffset.offset == -1) null else endOffset | ||
} | ||
|
||
override def createDataReaderFactories(): ju.List[DataReaderFactory[Row]] = synchronized { | ||
// Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal) | ||
val startOrdinal = | ||
start.flatMap(LongOffset.convert).getOrElse(LongOffset(-1)).offset.toInt + 1 | ||
val endOrdinal = LongOffset.convert(end).getOrElse(LongOffset(-1)).offset.toInt + 1 | ||
val startOrdinal = startOffset.offset.toInt + 1 | ||
val endOrdinal = endOffset.offset.toInt + 1 | ||
|
||
// Internal buffer only holds the batches after lastCommittedOffset. | ||
val newBlocks = synchronized { | ||
|
@@ -123,19 +140,11 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) | |
batches.slice(sliceStart, sliceEnd) | ||
} | ||
|
||
if (newBlocks.isEmpty) { | ||
return sqlContext.internalCreateDataFrame( | ||
sqlContext.sparkContext.emptyRDD, schema, isStreaming = true) | ||
} | ||
|
||
logDebug(generateDebugString(newBlocks, startOrdinal, endOrdinal)) | ||
|
||
newBlocks | ||
.map(_.toDF()) | ||
.reduceOption(_ union _) | ||
.getOrElse { | ||
sys.error("No data selected!") | ||
} | ||
newBlocks.map { ds => | ||
new MemoryStreamDataReaderFactory(ds.toDF().collect()).asInstanceOf[DataReaderFactory[Row]] | ||
}.asJava | ||
} | ||
|
||
private def generateDebugString( | ||
|
@@ -153,7 +162,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) | |
} | ||
} | ||
|
||
override def commit(end: Offset): Unit = synchronized { | ||
override def commit(end: OffsetV2): Unit = synchronized { | ||
def check(newOffset: LongOffset): Unit = { | ||
val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt | ||
|
||
|
@@ -176,11 +185,32 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) | |
|
||
def reset(): Unit = synchronized { | ||
batches.clear() | ||
startOffset = LongOffset(-1) | ||
endOffset = LongOffset(-1) | ||
currentOffset = new LongOffset(-1) | ||
lastOffsetCommitted = new LongOffset(-1) | ||
} | ||
} | ||
|
||
|
||
class MemoryStreamDataReaderFactory(records: Array[Row]) extends DataReaderFactory[Row] { | ||
override def createDataReader(): DataReader[Row] = { | ||
new DataReader[Row] { | ||
private var currentIndex = -1 | ||
|
||
override def next(): Boolean = { | ||
// Return true as long as the new index is in the array. | ||
currentIndex += 1 | ||
currentIndex < records.length | ||
} | ||
|
||
override def get(): Row = records(currentIndex) | ||
|
||
override def close(): Unit = {} | ||
} | ||
} | ||
} | ||
|
||
/** | ||
* A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit | ||
* tests and does not provide durability. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@cloud-fan This fixes the bug I spoke to you offline about.
The target of this PR is only master, not 2.3.x. So if you want to have this fix in 2.3.0, please make a separate PR accordingly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this PR has to be merged to the 2.3.0 branch, does it require additional changes?