@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.streaming

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.catalog.{Identifier, Table, TableCatalog}
import org.apache.spark.sql.streaming.OutputMode

/**
@@ -31,7 +31,8 @@ case class WriteToStream(
sink: Table,
outputMode: OutputMode,
deleteCheckpointOnStop: Boolean,
inputQuery: LogicalPlan) extends UnaryNode {
inputQuery: LogicalPlan,
catalogAndIdent: Option[(TableCatalog, Identifier)] = None) extends UnaryNode {

override def isStreaming: Boolean = true

@@ -21,7 +21,7 @@ import org.apache.hadoop.conf.Configuration

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.catalog.{Identifier, Table, TableCatalog}
import org.apache.spark.sql.streaming.OutputMode

/**
@@ -40,6 +40,7 @@ import org.apache.spark.sql.streaming.OutputMode
* @param hadoopConf The Hadoop Configuration to get a FileSystem instance
* @param isContinuousTrigger Whether the statement is triggered by a continuous query or not.
* @param inputQuery The analyzed query plan from the streaming DataFrame.
* @param catalogAndIdent Catalog and identifier for the sink, set when it is a V2 catalog table
*/
case class WriteToStreamStatement(
userSpecifiedName: Option[String],
@@ -50,7 +51,8 @@ case class WriteToStreamStatement(
outputMode: OutputMode,
hadoopConf: Configuration,
isContinuousTrigger: Boolean,
inputQuery: LogicalPlan) extends UnaryNode {
inputQuery: LogicalPlan,
catalogAndIdent: Option[(TableCatalog, Identifier)] = None) extends UnaryNode {

override def isStreaming: Boolean = true

@@ -138,8 +138,12 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat

withProjection :: Nil

case WriteToDataSourceV2(writer, query) =>
WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil
case WriteToDataSourceV2(relationOpt, writer, query) =>
val invalidateCacheFunc: () => Unit = () => relationOpt match {
case Some(r) => session.sharedState.cacheManager.uncacheQuery(session, r, cascade = true)
case None => ()
}
WriteToDataSourceV2Exec(writer, invalidateCacheFunc, planLater(query)) :: Nil

case CreateV2Table(catalog, ident, schema, parts, props, ifNotExists) =>
val propsWithOwner = CatalogV2Util.withDefaultOwnership(props)
@@ -43,8 +43,10 @@ import org.apache.spark.util.{LongAccumulator, Utils}
* specific logical plans, like [[org.apache.spark.sql.catalyst.plans.logical.AppendData]].
*/
@deprecated("Use specific logical plans like AppendData instead", "2.4.0")
case class WriteToDataSourceV2(batchWrite: BatchWrite, query: LogicalPlan)
extends UnaryNode {
case class WriteToDataSourceV2(
relation: Option[DataSourceV2Relation],
batchWrite: BatchWrite,
query: LogicalPlan) extends UnaryNode {
override def child: LogicalPlan = query
override def output: Seq[Attribute] = Nil
}
@@ -250,10 +252,13 @@ case class OverwritePartitionsDynamicExec(

case class WriteToDataSourceV2Exec(
batchWrite: BatchWrite,
refreshCache: () => Unit,
query: SparkPlan) extends V2TableWriteExec {

override protected def run(): Seq[InternalRow] = {
writeWithV2(batchWrite)
val writtenRows = writeWithV2(batchWrite)
refreshCache()
writtenRows
Comment on lines 258 to +261
Member

Instead of refreshing/invalidating the table per trigger, why don't we just invalidate the cache before we start the streaming query that writes to the table?

Member Author

Yes, that should work too and would also require fewer code changes. I went this way to be consistent with the other V2 write commands. Also, in the future we may introduce a DataStreamWriterV2 that could pass a write node with an UnresolvedRelation to the analyzer and have it converted to an execution plan, and this approach may fit better in that case.

}
}
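For reference, the one-time invalidation raised in the thread above would look roughly like the sketch below, placed in DataStreamWriter.toTable before the query is started. This is only an illustration of the reviewer's alternative, not code from this PR: the names t, catalog, identifier, extraOptions, and df are the values already in scope in that method, and the cacheManager call mirrors the one added to DataSourceV2Strategy in this change.

```scala
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation

// Hypothetical alternative: drop the cached data for the sink table once,
// before the streaming query starts, instead of after every micro-batch.
val sinkRelation = DataSourceV2Relation.create(
  t, Some(catalog.asTableCatalog), Some(identifier))
df.sparkSession.sharedState.cacheManager.uncacheQuery(
  df.sparkSession, sinkRelation, cascade = true)

// The query would then start as before, without threading catalogAndIdent
// through WriteToStreamStatement / WriteToStream / WriteToDataSourceV2.
startQuery(t, extraOptions)
```

One tradeoff visible in the diff: the per-trigger callback only refreshes the cache after a successful write, whereas the up-front approach invalidates once regardless of whether any batch is ever written.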

@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, TableCapability}
import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset => OffsetV2, ReadLimit, SparkDataStream, SupportsAdmissionControl}
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, StreamWriterCommitProgress, WriteToDataSourceV2Exec}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, StreamWriterCommitProgress, WriteToDataSourceV2Exec}
import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSource
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.Trigger
@@ -137,7 +137,10 @@ class MicroBatchExecution(
sink match {
case s: SupportsWrite =>
val streamingWrite = createStreamingWrite(s, extraOptions, _logicalPlan)
WriteToMicroBatchDataSource(streamingWrite, _logicalPlan)
val relationOpt = plan.catalogAndIdent.map {
case (catalog, ident) => DataSourceV2Relation.create(s, Some(catalog), Some(ident))
}
WriteToMicroBatchDataSource(relationOpt, streamingWrite, _logicalPlan)

case _ => _logicalPlan
}
@@ -61,7 +61,8 @@ object ResolveWriteToStream extends Rule[LogicalPlan] with SQLConfHelper {
s.sink,
s.outputMode,
deleteCheckpointOnStop,
s.inputQuery)
s.inputQuery,
s.catalogAndIdent)
}

def resolveCheckpointLocation(s: WriteToStreamStatement): (String, Boolean) = {
@@ -20,20 +20,23 @@ package org.apache.spark.sql.execution.streaming.sources
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
import org.apache.spark.sql.connector.write.streaming.StreamingWrite
import org.apache.spark.sql.execution.datasources.v2.WriteToDataSourceV2
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2}

/**
* The logical plan for writing data to a micro-batch stream.
*
* Note that this logical plan does not have a corresponding physical plan, as it will be converted
* to [[WriteToDataSourceV2]] with [[MicroBatchWrite]] before execution.
*/
case class WriteToMicroBatchDataSource(write: StreamingWrite, query: LogicalPlan)
case class WriteToMicroBatchDataSource(
relation: Option[DataSourceV2Relation],
write: StreamingWrite,
query: LogicalPlan)
extends UnaryNode {
override def child: LogicalPlan = query
override def output: Seq[Attribute] = Nil

def createPlan(batchId: Long): WriteToDataSourceV2 = {
WriteToDataSourceV2(new MicroBatchWrite(batchId, write), query)
WriteToDataSourceV2(relation, new MicroBatchWrite(batchId, write), query)
}
}
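To connect this with the scaladoc above: each trigger, MicroBatchExecution swaps this logical node for the batch WriteToDataSourceV2 plan built by createPlan, so the relation (and therefore the cache refresh) is carried into every micro-batch write. The fragment below is a simplified sketch of that per-trigger rewrite under assumed names (planWithNewData is illustrative; currentBatchId refers to the execution's current batch id), not the literal MicroBatchExecution source.

```scala
// Simplified sketch of the per-trigger rewrite inside MicroBatchExecution:
// replace the streaming write node with a batch WriteToDataSourceV2 that
// wraps a MicroBatchWrite for this batch id.
val triggerLogicalPlan = planWithNewData transform {
  case w: WriteToMicroBatchDataSource => w.createPlan(currentBatchId)
}
```

The resulting WriteToDataSourceV2 is then matched by DataSourceV2Strategy (see the earlier hunk), which is where the carried relation becomes the refreshCache callback invoked after each successful write.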
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType}
import org.apache.spark.sql.catalyst.plans.logical.CreateTableStatement
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableProvider, V1Table, V2TableWithV1Fallback}
import org.apache.spark.sql.connector.catalog.{Identifier, SupportsWrite, Table, TableCatalog, TableProvider, V1Table, V2TableWithV1Fallback}
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.DataSource
@@ -374,7 +374,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {

import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
tableInstance match {
case t: SupportsWrite if t.supports(STREAMING_WRITE) => startQuery(t, extraOptions)
case t: SupportsWrite if t.supports(STREAMING_WRITE) =>
startQuery(t, extraOptions, catalogAndIdent = Some(catalog.asTableCatalog, identifier))
case t: V2TableWithV1Fallback =>
writeToV1Table(t.v1Table)
case t: V1Table =>
@@ -460,7 +461,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
private def startQuery(
sink: Table,
newOptions: CaseInsensitiveMap[String],
recoverFromCheckpoint: Boolean = true): StreamingQuery = {
recoverFromCheckpoint: Boolean = true,
catalogAndIdent: Option[(TableCatalog, Identifier)] = None): StreamingQuery = {
val useTempCheckpointLocation = SOURCES_ALLOW_ONE_TIME_QUERY.contains(source)

df.sparkSession.sessionState.streamingQueryManager.startQuery(
@@ -472,7 +474,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
outputMode,
useTempCheckpointLocation = useTempCheckpointLocation,
recoverFromCheckpointLocation = recoverFromCheckpoint,
trigger = trigger)
trigger = trigger,
catalogAndIdent = catalogAndIdent)
}

private def createV1Sink(optionsWithPath: CaseInsensitiveMap[String]): Sink = {
@@ -29,7 +29,7 @@ import org.apache.spark.annotation.Evolving
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.streaming.{WriteToStream, WriteToStreamStatement}
import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table}
import org.apache.spark.sql.connector.catalog.{Identifier, SupportsWrite, Table, TableCatalog}
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef
@@ -226,6 +226,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
listenerBus.post(event)
}

// scalastyle:off argcount
private def createQuery(
userSpecifiedName: Option[String],
userSpecifiedCheckpointLocation: Option[String],
@@ -236,7 +237,8 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
useTempCheckpointLocation: Boolean,
recoverFromCheckpointLocation: Boolean,
trigger: Trigger,
triggerClock: Clock): StreamingQueryWrapper = {
triggerClock: Clock,
catalogAndIdent: Option[(TableCatalog, Identifier)] = None): StreamingQueryWrapper = {
val analyzedPlan = df.queryExecution.analyzed
df.queryExecution.assertAnalyzed()

@@ -249,7 +251,8 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
outputMode,
df.sparkSession.sessionState.newHadoopConf(),
trigger.isInstanceOf[ContinuousTrigger],
analyzedPlan)
analyzedPlan,
catalogAndIdent)

val analyzedStreamWritePlan =
sparkSession.sessionState.executePlan(dataStreamWritePlan).analyzed
@@ -272,7 +275,9 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
analyzedStreamWritePlan))
}
}
// scalastyle:on argcount

// scalastyle:off argcount
/**
* Start a [[StreamingQuery]].
*
@@ -288,6 +293,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
* will be thrown.
* @param trigger [[Trigger]] for the query.
* @param triggerClock [[Clock]] to use for the triggering.
* @param catalogAndIdent Catalog and identifier for the sink, set when it is a V2 catalog table
*/
@throws[TimeoutException]
private[sql] def startQuery(
@@ -300,7 +306,8 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
useTempCheckpointLocation: Boolean = false,
recoverFromCheckpointLocation: Boolean = true,
trigger: Trigger = Trigger.ProcessingTime(0),
triggerClock: Clock = new SystemClock()): StreamingQuery = {
triggerClock: Clock = new SystemClock(),
catalogAndIdent: Option[(TableCatalog, Identifier)] = None): StreamingQuery = {
val query = createQuery(
userSpecifiedName,
userSpecifiedCheckpointLocation,
@@ -311,7 +318,9 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
useTempCheckpointLocation,
recoverFromCheckpointLocation,
trigger,
triggerClock)
triggerClock,
catalogAndIdent)
// scalastyle:on argcount

// The following code block checks if a stream with the same name or id is running. Then it
// returns an Option of an already active stream to stop outside of the lock
@@ -31,6 +31,7 @@ import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME
import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
import org.apache.spark.sql.internal.SQLConf.{PARTITION_OVERWRITE_MODE, PartitionOverwriteMode, V2_SESSION_CATALOG_IMPLEMENTATION}
import org.apache.spark.sql.internal.connector.SimpleTableProvider
@@ -847,6 +848,34 @@ class DataSourceV2SQLSuite
}
}

test("SPARK-34947: micro batch streaming write should invalidate cache") {
import testImplicits._

val t = "testcat.ns.t"
withTable(t) {
withTempDir { checkpointDir =>
sql(s"CREATE TABLE $t (id bigint, data string) USING foo")
sql(s"INSERT INTO $t VALUES (1L, 'a')")
sql(s"CACHE TABLE $t")

val inputData = MemoryStream[(Long, String)]
val df = inputData.toDF().toDF("id", "data")
val query = df
.writeStream
.option("checkpointLocation", checkpointDir.getAbsolutePath)
.toTable(t)

val newData = Seq((2L, "b"))
inputData.addData(newData)
query.processAllAvailable()
query.stop()

assert(!spark.catalog.isCached("testcat.ns.t"))
checkAnswer(sql(s"SELECT * FROM $t"), Row(1L, "a") :: Row(2L, "b") :: Nil)
}
}
}

test("Relation: basic") {
val t1 = "testcat.ns1.ns2.tbl"
withTable(t1) {