[SPARK-22912] v2 data source support in MicroBatchExecution

## What changes were proposed in this pull request?

Adds support for v2 data sources in microbatch streaming.
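
As a minimal sketch (not code from this patch) of the contract this enables: a source opts into the new microbatch path by implementing `MicroBatchReadSupport` and handing back a `MicroBatchReader`, which `MicroBatchExecution` then drives via `setOffsetRange`, `getEndOffset`, and `commit`. The provider class below is hypothetical and its reader body is elided.

```scala
import java.util.Optional

import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options}
import org.apache.spark.sql.sources.v2.streaming.MicroBatchReadSupport
import org.apache.spark.sql.sources.v2.streaming.reader.MicroBatchReader
import org.apache.spark.sql.types.StructType

// Hypothetical provider, shown only to outline the entry point MicroBatchExecution
// now recognizes; a real source would return a concrete MicroBatchReader here.
class MyMicroBatchProvider extends DataSourceV2 with MicroBatchReadSupport {
  override def createMicroBatchReader(
      schema: Optional[StructType],
      checkpointLocation: String,
      options: DataSourceV2Options): MicroBatchReader = {
    // Elided: build a reader that reports offsets via getEndOffset and serves
    // read tasks for the range set by setOffsetRange.
    ???
  }
}
```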

## How was this patch tested?

A very basic new unit test on the toy v2 implementation of the rate source. Once we have a v1 source fully migrated to v2, we'll need to do more detailed compatibility testing.
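
For flavor, a hedged sketch of the kind of end-to-end smoke check described above (the actual test added here may differ, and this alone doesn't prove the v2 path was selected, since that depends on which provider data source resolution picks): run the rate source through a microbatch query into the built-in memory sink and confirm the expected columns show up. The query name is illustrative.

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("rate-v2-smoke").getOrCreate()

// Read from the rate source and write each microbatch into an in-memory table.
val query = spark.readStream
  .format("rate")
  .option("rowsPerSecond", "10")
  .load()
  .writeStream
  .format("memory")
  .queryName("rate_v2_smoke")
  .start()

query.processAllAvailable()
// The rate source schema is (timestamp, value); check it surfaced in the sink table.
assert(spark.table("rate_v2_smoke").columns.toSet == Set("timestamp", "value"))
query.stop()
```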

Author: Jose Torres <jose@databricks.com>

Closes #20097 from jose-torres/v2-impl.
jose-torres authored and tdas committed Jan 8, 2018
1 parent eed82a0 commit 4f7e758
Showing 15 changed files with 241 additions and 79 deletions.
@@ -7,3 +7,4 @@ org.apache.spark.sql.execution.datasources.text.TextFileFormat
org.apache.spark.sql.execution.streaming.ConsoleSinkProvider
org.apache.spark.sql.execution.streaming.TextSocketSourceProvider
org.apache.spark.sql.execution.streaming.RateSourceProvider
org.apache.spark.sql.execution.streaming.sources.RateSourceProviderV2
@@ -35,6 +35,16 @@ case class DataSourceV2Relation(
}
}

/**
* A specialization of DataSourceV2Relation with the streaming bit set to true. Otherwise identical
* to the non-streaming relation.
*/
class StreamingDataSourceV2Relation(
fullOutput: Seq[AttributeReference],
reader: DataSourceV2Reader) extends DataSourceV2Relation(fullOutput, reader) {
override def isStreaming: Boolean = true
}

object DataSourceV2Relation {
def apply(reader: DataSourceV2Reader): DataSourceV2Relation = {
new DataSourceV2Relation(reader.readSchema().toAttributes, reader)
@@ -17,14 +17,20 @@

package org.apache.spark.sql.execution.streaming

import java.util.Optional

import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, Map => MutableMap}

import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.sources.v2.streaming.MicroBatchReadSupport
import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2}
import org.apache.spark.sql.sources.v2.DataSourceV2Options
import org.apache.spark.sql.sources.v2.streaming.{MicroBatchReadSupport, MicroBatchWriteSupport}
import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset => OffsetV2}
import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger}
import org.apache.spark.util.{Clock, Utils}

@@ -33,10 +39,11 @@ class MicroBatchExecution(
name: String,
checkpointRoot: String,
analyzedPlan: LogicalPlan,
sink: Sink,
sink: BaseStreamingSink,
trigger: Trigger,
triggerClock: Clock,
outputMode: OutputMode,
extraOptions: Map[String, String],
deleteCheckpointOnStop: Boolean)
extends StreamExecution(
sparkSession, name, checkpointRoot, analyzedPlan, sink,
@@ -57,26 +64,40 @@ class MicroBatchExecution(
var nextSourceId = 0L
val toExecutionRelationMap = MutableMap[StreamingRelation, StreamingExecutionRelation]()
val v2ToExecutionRelationMap = MutableMap[StreamingRelationV2, StreamingExecutionRelation]()
// We transform each distinct streaming relation into a StreamingExecutionRelation, keeping a
// map as we go to ensure each identical relation gets the same StreamingExecutionRelation
// object. For each microbatch, the StreamingExecutionRelation will be replaced with a logical
// plan for the data within that batch.
// Note that we have to use the previous `output` as attributes in StreamingExecutionRelation,
// since the existing logical plan has already used those attributes. The per-microbatch
// transformation is responsible for replacing attributes with their final values.
val _logicalPlan = analyzedPlan.transform {
case streamingRelation@StreamingRelation(dataSource, _, output) =>
toExecutionRelationMap.getOrElseUpdate(streamingRelation, {
// Materialize source to avoid creating it in every batch
val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId"
val source = dataSource.createSource(metadataPath)
nextSourceId += 1
// We still need to use the previous `output` instead of `source.schema` as attributes in
// "df.logicalPlan" has already used attributes of the previous `output`.
StreamingExecutionRelation(source, output)(sparkSession)
})
case s @ StreamingRelationV2(v2DataSource, _, _, output, v1DataSource)
if !v2DataSource.isInstanceOf[MicroBatchReadSupport] =>
case s @ StreamingRelationV2(source: MicroBatchReadSupport, _, options, output, _) =>
v2ToExecutionRelationMap.getOrElseUpdate(s, {
// Materialize source to avoid creating it in every batch
val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId"
val reader = source.createMicroBatchReader(
Optional.empty(), // user specified schema
metadataPath,
new DataSourceV2Options(options.asJava))
nextSourceId += 1
StreamingExecutionRelation(reader, output)(sparkSession)
})
case s @ StreamingRelationV2(_, _, _, output, v1Relation) =>
v2ToExecutionRelationMap.getOrElseUpdate(s, {
// Materialize source to avoid creating it in every batch
val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId"
val source = v1DataSource.createSource(metadataPath)
assert(v1Relation.isDefined, "v2 execution didn't match but v1 was unavailable")
val source = v1Relation.get.dataSource.createSource(metadataPath)
nextSourceId += 1
// We still need to use the previous `output` instead of `source.schema` as attributes in
// "df.logicalPlan" has already used attributes of the previous `output`.
StreamingExecutionRelation(source, output)(sparkSession)
})
}
@@ -192,7 +213,8 @@ class MicroBatchExecution(
source.getBatch(start, end)
}
case nonV1Tuple =>
throw new IllegalStateException(s"Unexpected V2 source in $nonV1Tuple")
// The V2 API does not have the same edge case requiring getBatch to be called
// here, so we do nothing here.
}
currentBatchId = latestCommittedBatchId + 1
committedOffsets ++= availableOffsets
@@ -236,14 +258,27 @@ class MicroBatchExecution(
val hasNewData = {
awaitProgressLock.lock()
try {
val latestOffsets: Map[Source, Option[Offset]] = uniqueSources.map {
// Generate a map from each unique source to the next available offset.
val latestOffsets: Map[BaseStreamingSource, Option[Offset]] = uniqueSources.map {
case s: Source =>
updateStatusMessage(s"Getting offsets from $s")
reportTimeTaken("getOffset") {
(s, s.getOffset)
}
case s: MicroBatchReader =>
updateStatusMessage(s"Getting offsets from $s")
reportTimeTaken("getOffset") {
// Once v1 streaming source execution is gone, we can refactor this away.
// For now, we set the range here to get the source to infer the available end offset,
// get that offset, and then set the range again when we later execute.
s.setOffsetRange(
toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))),
Optional.empty())

(s, Some(s.getEndOffset))
}
}.toMap
availableOffsets ++= latestOffsets.filter { case (s, o) => o.nonEmpty }.mapValues(_.get)
availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get)

if (dataAvailable) {
true
@@ -317,6 +352,8 @@ class MicroBatchExecution(
if (prevBatchOff.isDefined) {
prevBatchOff.get.toStreamProgress(sources).foreach {
case (src: Source, off) => src.commit(off)
case (reader: MicroBatchReader, off) =>
reader.commit(reader.deserializeOffset(off.json))
}
} else {
throw new IllegalStateException(s"batch $currentBatchId doesn't exist")
@@ -357,31 +394,39 @@ class MicroBatchExecution(
s"DataFrame returned by getBatch from $source did not have isStreaming=true\n" +
s"${batch.queryExecution.logical}")
logDebug(s"Retrieving data from $source: $current -> $available")
Some(source -> batch)
Some(source -> batch.logicalPlan)
case (reader: MicroBatchReader, available)
if committedOffsets.get(reader).map(_ != available).getOrElse(true) =>
val current = committedOffsets.get(reader).map(off => reader.deserializeOffset(off.json))
reader.setOffsetRange(
toJava(current),
Optional.of(available.asInstanceOf[OffsetV2]))
logDebug(s"Retrieving data from $reader: $current -> $available")
Some(reader ->
new StreamingDataSourceV2Relation(reader.readSchema().toAttributes, reader))
case _ => None
}
}

// A list of attributes that will need to be updated.
val replacements = new ArrayBuffer[(Attribute, Attribute)]
// Replace sources in the logical plan with data that has arrived since the last batch.
val withNewSources = logicalPlan transform {
val newBatchesPlan = logicalPlan transform {
case StreamingExecutionRelation(source, output) =>
newData.get(source).map { data =>
val newPlan = data.logicalPlan
assert(output.size == newPlan.output.size,
newData.get(source).map { dataPlan =>
assert(output.size == dataPlan.output.size,
s"Invalid batch: ${Utils.truncatedString(output, ",")} != " +
s"${Utils.truncatedString(newPlan.output, ",")}")
replacements ++= output.zip(newPlan.output)
newPlan
s"${Utils.truncatedString(dataPlan.output, ",")}")
replacements ++= output.zip(dataPlan.output)
dataPlan
}.getOrElse {
LocalRelation(output, isStreaming = true)
}
}

// Rewire the plan to use the new attributes that were returned by the source.
val replacementMap = AttributeMap(replacements)
val triggerLogicalPlan = withNewSources transformAllExpressions {
val newAttributePlan = newBatchesPlan transformAllExpressions {
case a: Attribute if replacementMap.contains(a) =>
replacementMap(a).withMetadata(a.metadata)
case ct: CurrentTimestamp =>
@@ -392,6 +437,20 @@ class MicroBatchExecution(
cd.dataType, cd.timeZoneId)
}

val triggerLogicalPlan = sink match {
case _: Sink => newAttributePlan
case s: MicroBatchWriteSupport =>
val writer = s.createMicroBatchWriter(
s"$runId",
currentBatchId,
newAttributePlan.schema,
outputMode,
new DataSourceV2Options(extraOptions.asJava))
assert(writer.isPresent, "microbatch writer must always be present")
WriteToDataSourceV2(writer.get, newAttributePlan)
case _ => throw new IllegalArgumentException(s"unknown sink type for $sink")
}

reportTimeTaken("queryPlanning") {
lastExecution = new IncrementalExecution(
sparkSessionToRunBatch,
@@ -409,7 +468,12 @@ class MicroBatchExecution(

reportTimeTaken("addBatch") {
SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) {
sink.addBatch(currentBatchId, nextBatch)
sink match {
case s: Sink => s.addBatch(currentBatchId, nextBatch)
case s: MicroBatchWriteSupport =>
// This doesn't accumulate any data - it just forces execution of the microbatch writer.
nextBatch.collect()
}
}
}

Expand All @@ -421,4 +485,8 @@ class MicroBatchExecution(
awaitProgressLock.unlock()
}
}

private def toJava(scalaOption: Option[OffsetV2]): Optional[OffsetV2] = {
Optional.ofNullable(scalaOption.orNull)
}
}
@@ -53,7 +53,7 @@ trait ProgressReporter extends Logging {
protected def triggerClock: Clock
protected def logicalPlan: LogicalPlan
protected def lastExecution: QueryExecution
protected def newData: Map[BaseStreamingSource, DataFrame]
protected def newData: Map[BaseStreamingSource, LogicalPlan]
protected def availableOffsets: StreamProgress
protected def committedOffsets: StreamProgress
protected def sources: Seq[BaseStreamingSource]
@@ -225,8 +225,8 @@ trait ProgressReporter extends Logging {
//
// 3. For each source, we sum the metrics of the associated execution plan leaves.
//
val logicalPlanLeafToSource = newData.flatMap { case (source, df) =>
df.logicalPlan.collectLeaves().map { leaf => leaf -> source }
val logicalPlanLeafToSource = newData.flatMap { case (source, logicalPlan) =>
logicalPlan.collectLeaves().map { leaf => leaf -> source }
}
val allLogicalPlanLeaves = lastExecution.logical.collectLeaves() // includes non-streaming
val allExecPlanLeaves = lastExecution.executedPlan.collectLeaves()
@@ -29,12 +29,12 @@ import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.execution.streaming.continuous.ContinuousRateStreamReader
import org.apache.spark.sql.execution.streaming.sources.RateStreamV2Reader
import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader
import org.apache.spark.sql.execution.streaming.sources.RateStreamMicroBatchReader
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
import org.apache.spark.sql.sources.v2._
import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport
import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader
import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport}
import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, MicroBatchReader}
import org.apache.spark.sql.types._
import org.apache.spark.util.{ManualClock, SystemClock}

@@ -112,7 +112,7 @@ class RateSourceProvider extends StreamSourceProvider with DataSourceRegister
schema: Optional[StructType],
checkpointLocation: String,
options: DataSourceV2Options): ContinuousReader = {
new ContinuousRateStreamReader(options)
new RateStreamContinuousReader(options)
}

override def shortName(): String = "rate"
@@ -163,7 +163,7 @@ abstract class StreamExecution(
var lastExecution: IncrementalExecution = _

/** Holds the most recent input data for each source. */
protected var newData: Map[BaseStreamingSource, DataFrame] = _
protected var newData: Map[BaseStreamingSource, LogicalPlan] = _

@volatile
protected var streamDeathCause: StreamingQueryException = null
@@ -418,7 +418,7 @@ abstract class StreamExecution(
* Blocks the current thread until processing for data from the given `source` has reached at
* least the given `Offset`. This method is intended for use primarily when writing tests.
*/
private[sql] def awaitOffset(source: Source, newOffset: Offset): Unit = {
private[sql] def awaitOffset(source: BaseStreamingSource, newOffset: Offset): Unit = {
assertAwaitThread()
def notDone = {
val localCommittedOffsets = committedOffsets
@@ -61,7 +61,7 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output:
* [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]].
*/
case class StreamingExecutionRelation(
source: Source,
source: BaseStreamingSource,
output: Seq[Attribute])(session: SparkSession)
extends LeafNode {

@@ -92,7 +92,7 @@ case class StreamingRelationV2(
sourceName: String,
extraOptions: Map[String, String],
output: Seq[Attribute],
v1DataSource: DataSource)(session: SparkSession)
v1Relation: Option[StreamingRelation])(session: SparkSession)
extends LeafNode {
override def isStreaming: Boolean = true
override def toString: String = sourceName
@@ -27,7 +27,7 @@ import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2}
import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _}
import org.apache.spark.sql.sources.v2.DataSourceV2Options
import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport}
@@ -174,7 +174,7 @@ class ContinuousExecution(
val loggedOffset = offsets.offsets(0)
val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json))
reader.setOffset(java.util.Optional.ofNullable(realOffset.orNull))
DataSourceV2Relation(newOutput, reader)
new StreamingDataSourceV2Relation(newOutput, reader)
}

// Rewire the plan to use the new attributes that were returned by the source.
