diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/AnalysisWarning.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/AnalysisWarning.scala deleted file mode 100644 index 35b8185c255e..000000000000 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/AnalysisWarning.scala +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.pipelines - -/** Represents a warning generated as part of graph analysis. */ -sealed trait AnalysisWarning - -object AnalysisWarning { - - /** - * Warning that some streaming reader options are being dropped - * - * @param sourceName Source for which reader options are being dropped. - * @param droppedOptions Set of reader options that are being dropped for a specific source. - */ - case class StreamingReaderOptionsDropped(sourceName: String, droppedOptions: Seq[String]) - extends AnalysisWarning -} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala index b87c02d562cb..69a221538dd5 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala @@ -86,7 +86,8 @@ class CoreDataflowNodeProcessor(rawGraph: DataflowGraph) { identifier = table.identifier, specifiedSchema = table.specifiedSchema, incomingFlowIdentifiers = flowsToTable.map(_.identifier).toSet, - availableFlows = resolvedFlowsToTable + availableFlows = resolvedFlowsToTable, + isStreamingTable = table.isStreamingTable ) resolvedInputs.put(table.identifier, virtualTableInput) Seq(table) diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala index 91feee936170..578c2589e018 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala @@ -22,8 +22,6 @@ import scala.util.Try import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} import org.apache.spark.sql.classic.DataFrame -import org.apache.spark.sql.pipelines.AnalysisWarning -import org.apache.spark.sql.pipelines.util.InputReadOptions import org.apache.spark.sql.types.StructType /** @@ -99,8 +97,7 @@ case class FlowFunctionResult( streamingInputs: Set[ResolvedInput], usedExternalInputs: Set[TableIdentifier], dataFrame: Try[DataFrame], - sqlConf: Map[String, String], - analysisWarnings: Seq[AnalysisWarning] = Nil) { + 
sqlConf: Map[String, String]) { /** * Returns the names of all of the [[Input]]s used when resolving this [[Flow]]. If the @@ -165,7 +162,7 @@ trait ResolvedFlow extends ResolutionCompletedFlow with Input { /** Returns the schema of the output of this [[Flow]]. */ def schema: StructType = df.schema - override def load(readOptions: InputReadOptions): DataFrame = df + override def load: DataFrame = df def inputs: Set[TableIdentifier] = funcResult.inputs } diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala index 18ae45c4f340..779db45ea4f0 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala @@ -23,10 +23,8 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{CTESubstitution, UnresolvedRelation} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} -import org.apache.spark.sql.classic.{DataFrame, Dataset, DataStreamReader, SparkSession} -import org.apache.spark.sql.pipelines.AnalysisWarning +import org.apache.spark.sql.classic.{DataFrame, DataFrameReader, Dataset, DataStreamReader, SparkSession} import org.apache.spark.sql.pipelines.graph.GraphIdentifierManager.{ExternalDatasetIdentifier, InternalDatasetIdentifier} -import org.apache.spark.sql.pipelines.util.{BatchReadOptions, InputReadOptions, StreamingReadOptions} object FlowAnalysis { @@ -67,8 +65,7 @@ object FlowAnalysis { streamingInputs = ctx.streamingInputs.toSet, usedExternalInputs = ctx.externalInputs.toSet, dataFrame = df, - sqlConf = confs, - analysisWarnings = ctx.analysisWarnings.toList + sqlConf = confs ) } } @@ -116,8 +113,7 @@ object FlowAnalysis { val resolved = readStreamInput( context, name = IdentifierHelper.toQuotedString(u.multipartIdentifier), - spark.readStream, - streamingReadOptions = StreamingReadOptions() + streamReader = spark.readStream.options(u.options) ).queryExecution.analyzed // Spark Connect requires the PLAN_ID_TAG to be propagated to the resolved plan // to allow correct analysis of the parent plan that contains this subquery @@ -128,7 +124,7 @@ object FlowAnalysis { val resolved = readBatchInput( context, name = IdentifierHelper.toQuotedString(u.multipartIdentifier), - batchReadOptions = BatchReadOptions() + batchReader = spark.read.options(u.options) ).queryExecution.analyzed // Spark Connect requires the PLAN_ID_TAG to be propagated to the resolved plan // to allow correct analysis of the parent plan that contains this subquery @@ -147,23 +143,25 @@ object FlowAnalysis { * All the public APIs that read from a dataset should call this function to read the dataset. * * @param name the name of the Dataset to be read. - * @param batchReadOptions Options for this batch read + * @param batchReader the batch dataframe reader, possibly with options, to execute the read + * with. * @return batch DataFrame that represents data from the specified Dataset. 
*/ final private def readBatchInput( context: FlowAnalysisContext, name: String, - batchReadOptions: BatchReadOptions + batchReader: DataFrameReader ): DataFrame = { GraphIdentifierManager.parseAndQualifyInputIdentifier(context, name) match { case inputIdentifier: InternalDatasetIdentifier => - readGraphInput(context, inputIdentifier, batchReadOptions) + readGraphInput(context, inputIdentifier, isStreamingRead = false) case inputIdentifier: ExternalDatasetIdentifier => readExternalBatchInput( context, inputIdentifier = inputIdentifier, - name = name + name = name, + batchReader = batchReader ) } } @@ -177,21 +175,19 @@ object FlowAnalysis { * * @param name the name of the Dataset to be read. * @param streamReader The [[DataStreamReader]] that may hold read options specified by the user. - * @param streamingReadOptions Options for this streaming read. * @return streaming DataFrame that represents data from the specified Dataset. */ final private def readStreamInput( context: FlowAnalysisContext, name: String, - streamReader: DataStreamReader, - streamingReadOptions: StreamingReadOptions + streamReader: DataStreamReader ): DataFrame = { GraphIdentifierManager.parseAndQualifyInputIdentifier(context, name) match { case inputIdentifier: InternalDatasetIdentifier => readGraphInput( context, inputIdentifier, - streamingReadOptions + isStreamingRead = true ) case inputIdentifier: ExternalDatasetIdentifier => @@ -208,13 +204,13 @@ object FlowAnalysis { * Internal helper to reference dataset defined in the same [[DataflowGraph]]. * * @param inputIdentifier The identifier of the Dataset to be read. - * @param readOptions Options for this read (may be either streaming or batch options) + * @param isStreamingRead Whether this is a streaming read or batch read. * @return streaming or batch DataFrame that represents data from the specified Dataset. */ final private def readGraphInput( ctx: FlowAnalysisContext, inputIdentifier: InternalDatasetIdentifier, - readOptions: InputReadOptions + isStreamingRead: Boolean ): DataFrame = { val datasetIdentifier = inputIdentifier.identifier @@ -230,8 +226,15 @@ object FlowAnalysis { // Dataset is resolved, so we can read from it ctx.availableInput(datasetIdentifier) } + val inputDF = i match { + case vt: VirtualTableInput => + // Unlike temporary views (which would have been substituted into flows by this point), we + // allow tables to batch read a streaming dataset. We do not allow the opposite however, + // which is checked on the resolved graph during graph validation. + vt.load(asStreaming = isStreamingRead) + case _ => i.load + } - val inputDF = i.load(readOptions) i match { // If the referenced input is a [[Flow]], because the query plans will be fused // together, we also need to fuse their confs. 
@@ -252,30 +255,22 @@ object FlowAnalysis { qualifier = Seq(datasetIdentifier.catalog, datasetIdentifier.database).flatten ) - readOptions match { - case sro: StreamingReadOptions => - if (!inputDF.isStreaming && incompatibleViewReadCheck) { - throw new AnalysisException( - "INCOMPATIBLE_BATCH_VIEW_READ", - Map("datasetIdentifier" -> datasetIdentifier.toString) - ) - } - - if (sro.droppedUserOptions.nonEmpty) { - ctx.analysisWarnings += AnalysisWarning.StreamingReaderOptionsDropped( - sourceName = datasetIdentifier.unquotedString, - droppedOptions = sro.droppedUserOptions.keys.toSeq - ) - } - ctx.streamingInputs += ResolvedInput(i, aliasIdentifier) - case _ => - if (inputDF.isStreaming && incompatibleViewReadCheck) { - throw new AnalysisException( - "INCOMPATIBLE_STREAMING_VIEW_READ", - Map("datasetIdentifier" -> datasetIdentifier.toString) - ) - } - ctx.batchInputs += ResolvedInput(i, aliasIdentifier) + if (isStreamingRead) { + if (!inputDF.isStreaming && incompatibleViewReadCheck) { + throw new AnalysisException( + "INCOMPATIBLE_BATCH_VIEW_READ", + Map("datasetIdentifier" -> datasetIdentifier.toString) + ) + } + ctx.streamingInputs += ResolvedInput(i, aliasIdentifier) + } else { + if (inputDF.isStreaming && incompatibleViewReadCheck) { + throw new AnalysisException( + "INCOMPATIBLE_STREAMING_VIEW_READ", + Map("datasetIdentifier" -> datasetIdentifier.toString) + ) + } + ctx.batchInputs += ResolvedInput(i, aliasIdentifier) } Dataset.ofRows( ctx.spark, @@ -293,11 +288,11 @@ object FlowAnalysis { final private def readExternalBatchInput( context: FlowAnalysisContext, inputIdentifier: ExternalDatasetIdentifier, - name: String): DataFrame = { + name: String, + batchReader: DataFrameReader): DataFrame = { - val spark = context.spark context.externalInputs += inputIdentifier.identifier - spark.read.table(inputIdentifier.identifier.quotedString) + batchReader.table(inputIdentifier.identifier.quotedString) } /** diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala index 1139946df59a..e5f7cddc4d32 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala @@ -18,11 +18,9 @@ package org.apache.spark.sql.pipelines.graph import scala.collection.mutable -import scala.collection.mutable.ListBuffer import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.classic.SparkSession -import org.apache.spark.sql.pipelines.AnalysisWarning /** * A context used when evaluating a `Flow`'s query into a concrete DataFrame. 
@@ -44,7 +42,6 @@ private[pipelines] case class FlowAnalysisContext( streamingInputs: mutable.HashSet[ResolvedInput] = mutable.HashSet.empty, requestedInputs: mutable.HashSet[TableIdentifier] = mutable.HashSet.empty, shouldLowerCaseNames: Boolean = false, - analysisWarnings: mutable.Buffer[AnalysisWarning] = new ListBuffer[AnalysisWarning], spark: SparkSession, externalInputs: mutable.HashSet[TableIdentifier] = mutable.HashSet.empty ) { diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala index 824c15dc8791..0e48f52f8c96 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala @@ -212,7 +212,7 @@ trait GraphValidations extends Logging { } protected def validateUserSpecifiedSchemas(): Unit = { - flows.flatMap(f => table.get(f.identifier)).foreach { t: TableInput => + flows.flatMap(f => table.get(f.identifier)).foreach { t: TableElement => // The output inferred schema of a table is the declared schema merged with the // schema of all incoming flows. This must be equivalent to the declared schema. val inferredSchema = SchemaInferenceUtils diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/State.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/State.scala index efe5849d1cbd..1ab701fd2598 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/State.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/State.scala @@ -29,7 +29,9 @@ object State extends Logging { * @param graph The graph to reset. * @param env The current update context. */ - private def findElementsToReset(graph: DataflowGraph, env: PipelineUpdateContext): Seq[Input] = { + private def findElementsToReset( + graph: DataflowGraph, + env: PipelineUpdateContext): Seq[GraphElement] = { // If tableFilter is an instance of SomeTables, this is a refresh selection and all tables // to reset should be resettable; Otherwise, this is a full graph update, and we reset all // tables that are resettable. 
@@ -71,8 +73,8 @@ object State extends Logging { * - Clearing checkpoint data * - Truncating table data */ - def reset(resolvedGraph: DataflowGraph, env: PipelineUpdateContext): Seq[Input] = { - val elementsToReset: Seq[Input] = findElementsToReset(resolvedGraph, env) + def reset(resolvedGraph: DataflowGraph, env: PipelineUpdateContext): Seq[GraphElement] = { + val elementsToReset: Seq[GraphElement] = findElementsToReset(resolvedGraph, env) elementsToReset.foreach { case f: ResolvedFlow => reset(f, env, resolvedGraph) diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala index c762174e6725..d224b9f20c99 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.pipelines.graph import java.util -import scala.util.control.NonFatal - import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.Row @@ -29,12 +27,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.classic.{DataFrame, SparkSession} import org.apache.spark.sql.execution.streaming.runtime.MemoryStream import org.apache.spark.sql.pipelines.common.DatasetType -import org.apache.spark.sql.pipelines.util.{ - BatchReadOptions, - InputReadOptions, - SchemaInferenceUtils, - StreamingReadOptions -} +import org.apache.spark.sql.pipelines.util.SchemaInferenceUtils import org.apache.spark.sql.types.StructType /** An element in a [[DataflowGraph]]. */ @@ -68,10 +61,9 @@ trait Input extends GraphElement { /** * Returns a DataFrame that is a result of loading data from this [[Input]]. - * @param readOptions Type of input. Used to determine streaming/batch * @return Streaming or batch DataFrame of this Input's data. */ - def load(readOptions: InputReadOptions): DataFrame + def load: DataFrame } /** @@ -102,7 +94,7 @@ sealed trait Dataset extends Output { } /** A type of [[Input]] where data is loaded from a table. */ -sealed trait TableInput extends Input { +sealed trait TableElement extends GraphElement { /** The user-specified schema for this table. */ def specifiedSchema: Option[StructType] @@ -132,29 +124,9 @@ case class Table( override val origin: QueryOrigin, isStreamingTable: Boolean, format: Option[String] -) extends TableInput +) extends TableElement with Dataset { - // Load this table's data from underlying storage. - override def load(readOptions: InputReadOptions): DataFrame = { - try { - lazy val tableName = identifier.quotedString - - val df = readOptions match { - case sro: StreamingReadOptions => - spark.readStream.options(sro.userOptions).table(tableName) - case _: BatchReadOptions => - spark.read.table(tableName) - case _ => - throw new IllegalArgumentException("Unhandled `InputReadOptions` type when loading table") - } - - df - } catch { - case NonFatal(e) => throw LoadTableException(displayName, Option(e)) - } - } - /** Returns the normalized storage location to this [[Table]]. */ override def path: String = { if (!normalized) { @@ -176,42 +148,62 @@ case class Table( } /** - * A type of [[TableInput]] that returns data from a specified schema or from the inferred - * [[Flow]]s that write to the table. + * A virtual table is a representation of a pipeline table used during analysis. 
During analysis we + * only care about the schemas of declared tables, and it's possible the declared tables do not yet + * exist in the catalog. Hence we represent all tables in the graph with their "virtual" + * counterparts, which are simply empty dataframes but with the same schemas. + * + * We refer to the declared table that the virtual counterpart represents as the "parent" table + * below. + * + * @param identifier The identifier of the parent table. + * @param specifiedSchema The user-specified schema for the parent table. + * @param incomingFlowIdentifiers The identifiers of all flows that write to the parent table. + * @param availableFlows All resolved flows that write to the parent table. + * @param isStreamingTable Whether the parent table is a streaming table or not. */ case class VirtualTableInput( identifier: TableIdentifier, specifiedSchema: Option[StructType], incomingFlowIdentifiers: Set[TableIdentifier], - availableFlows: Seq[ResolvedFlow] = Nil -) extends TableInput + availableFlows: Seq[ResolvedFlow] = Nil, + isStreamingTable: Boolean +) extends TableElement with Input with Logging { override def origin: QueryOrigin = QueryOrigin() assert(availableFlows.forall(_.destinationIdentifier == identifier)) - override def load(readOptions: InputReadOptions): DataFrame = { - // Infer the schema for this virtual table - def getFinalSchema: StructType = { - specifiedSchema match { - // This is not a backing table, and we have a user-specified schema, so use it directly. - case Some(ss) => ss - // Otherwise infer the schema from a combination of the incoming flows and the - // user-specified schema, if provided. - case _ => - SchemaInferenceUtils.inferSchemaFromFlows(availableFlows, specifiedSchema) - } + + /** + * Loads this virtual table as a dataframe. + * + * @param asStreaming whether to load as a streaming DF or batch DF. There are cases where we may + * want to batch read from a streaming table, for example. + */ + def load(asStreaming: Boolean): DataFrame = { + val deducedSchema = specifiedSchema match { + // If the user specified a schema, use it directly. + case Some(ss) => ss + // Otherwise infer the schema from a combination of the incoming flows and the + // user-specified schema, if provided. + case _ => + SchemaInferenceUtils.inferSchemaFromFlows(availableFlows, specifiedSchema) } - // create empty streaming/batch df based on input type. - def createEmptyDF(schema: StructType): DataFrame = readOptions match { - case _: StreamingReadOptions => - MemoryStream[Row](ExpressionEncoder(schema, lenient = false), spark) - .toDF() - case _ => spark.createDataFrame(new util.ArrayList[Row](), schema) + // Produce either a streaming or batch dataframe, depending on whether this is a virtual + // representation of a streaming or non-streaming table. Return the [empty] dataframe with the + // deduced schema. + if (asStreaming) { + MemoryStream[Row](ExpressionEncoder(deducedSchema, lenient = false), spark) + .toDF() + } else { + spark.createDataFrame(new util.ArrayList[Row](), deducedSchema) } + } - val df = createEmptyDF(getFinalSchema) - df + /** Loads this virtual table into a dataframe, streaming or batch depending on whether the parent table is a streaming table. 
*/ + override def load: DataFrame = { + load(asStreaming = isStreamingTable) } } diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/InputReadInfo.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/InputReadInfo.scala deleted file mode 100644 index 070927aea295..000000000000 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/InputReadInfo.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.pipelines.util - -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.pipelines.util.StreamingReadOptions.EmptyUserOptions - -/** - * Generic options for a read of an input. - */ -sealed trait InputReadOptions - -/** - * Options for a batch read of an input. - */ -final case class BatchReadOptions() extends InputReadOptions - -/** - * Options for a streaming read of an input. - * - * @param userOptions Holds the user defined read options. - * @param droppedUserOptions Holds the options that were specified by the user but - * not actually used. This is a bug but we are preserving this behavior - * for now to avoid making a backwards incompatible change. - */ -final case class StreamingReadOptions( - userOptions: CaseInsensitiveMap[String] = EmptyUserOptions, - droppedUserOptions: CaseInsensitiveMap[String] = EmptyUserOptions -) extends InputReadOptions - -object StreamingReadOptions { - val EmptyUserOptions: CaseInsensitiveMap[String] = CaseInsensitiveMap(Map()) -} diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/analysis/ReadOptionsPropagationOnAnalysisSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/analysis/ReadOptionsPropagationOnAnalysisSuite.scala new file mode 100644 index 000000000000..908c036738b3 --- /dev/null +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/analysis/ReadOptionsPropagationOnAnalysisSuite.scala @@ -0,0 +1,252 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.analysis + +import scala.collection.mutable.{Map => MutableMap} +import scala.jdk.CollectionConverters._ + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias +import org.apache.spark.sql.classic.SparkSession +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.streaming.runtime.StreamingRelation +import org.apache.spark.sql.pipelines.graph.{FlowFunction, FlowFunctionResult, Input, QueryContext, QueryOrigin} +import org.apache.spark.sql.pipelines.utils.{ExecutionTest, TestGraphRegistrationContext} +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Tracker for flow function results. + * @param flowFunctionResults Mutable map storing the latest FlowFunctionResult per flow function + */ +case class FlowFunctionResultTracker( + flowFunctionResults: MutableMap[String, FlowFunctionResult] +) + +/** + * Instrumented FlowFunction implementation, used to track flow function results. + * @param flowName The name of the flow function being tracked + * @param flowFunction The flow function being tracked + * @param flowFunctionResultTracker The flow function results tracker instance + */ +class InstrumentedFlowFunction( + flowName: String, + flowFunction: FlowFunction, + flowFunctionResultTracker: FlowFunctionResultTracker +) + extends FlowFunction { + override def call( + allInputs: Set[TableIdentifier], + availableInputs: Seq[Input], + configuration: Map[String, String], + queryContext: QueryContext, + queryOrigin: QueryOrigin + ): FlowFunctionResult = { + val flowFunctionResult = flowFunction.call( + allInputs, + availableInputs, + configuration, + queryContext, + queryOrigin + ) + flowFunctionResultTracker.flowFunctionResults.put(flowName, flowFunctionResult) + flowFunctionResult + } +} + +class InstrumentedTestGraphRegistrationContext( + spark: SparkSession, + flowFunctionResultTracker: FlowFunctionResultTracker +) + extends TestGraphRegistrationContext(spark) { + + def readFlowFunc( + flowNameForTracking: String, + tableName: String, + extraOptions: CaseInsensitiveStringMap + ): FlowFunction = + new InstrumentedFlowFunction( + flowName = flowNameForTracking, + flowFunction = readFlowFunc(tableName, extraOptions), + flowFunctionResultTracker = flowFunctionResultTracker + ) + + def readStreamFlowFunc( + flowNameForTracking: String, + tableName: String, + extraOptions: CaseInsensitiveStringMap + ): FlowFunction = + new InstrumentedFlowFunction( + flowName = flowNameForTracking, + flowFunction = readStreamFlowFunc(tableName, extraOptions), + flowFunctionResultTracker = flowFunctionResultTracker + ) +} + +/** + * Test suite for verifying propagation of read options during pipelines analysis. 
+ */ +class ReadOptionsPropagationOnAnalysisSuite extends ExecutionTest with SharedSparkSession { + test("Internal pipeline batch read options are propagated during flow function analysis") { + val session = spark + import session.implicits._ + + val flowFunctionResultTracker = FlowFunctionResultTracker(MutableMap.empty) + + withTable("a", "b") { + val graphRegistrationContext = + new InstrumentedTestGraphRegistrationContext(spark, flowFunctionResultTracker) { + registerMaterializedView(name = "a", query = dfFlowFunc(Seq(1, 2).toDF("id"))) + registerMaterializedView( + name = "b", + query = readFlowFunc( + flowNameForTracking = "bFlow", + tableName = "a", + extraOptions = new CaseInsensitiveStringMap(Map("x" -> "y").asJava) + ) + ) + } + val unresolvedGraph = graphRegistrationContext.toDataflowGraph + + val updateContext = TestPipelineUpdateContext(spark, unresolvedGraph, storageRoot) + updateContext.pipelineExecution.runPipeline() + updateContext.pipelineExecution.awaitCompletion() + + val bFlow = flowFunctionResultTracker.flowFunctionResults.get("bFlow").get + + // Verify the flow function's analyzed DF logical plan contains specified options. + assert(bFlow.dataFrame.get.logicalPlan + .asInstanceOf[SubqueryAlias].child + .asInstanceOf[LogicalRelation].relation + .asInstanceOf[HadoopFsRelation].options.get("x").contains("y")) + } + } + + test("Internal pipeline stream read options are propagated during flow function analysis") { + val flowFunctionResultTracker = FlowFunctionResultTracker(MutableMap.empty) + + withTable("spark_catalog.default.a", "b", "c") { + // Create a regular external table that ST "b" can stream from, then have ST "c" stream from + // "b". + spark.range(10).write.saveAsTable("spark_catalog.default.a") + + val graphRegistrationContext = + new InstrumentedTestGraphRegistrationContext(spark, flowFunctionResultTracker) { + registerTable( + name = "b", + query = Option( + readStreamFlowFunc( + name = "spark_catalog.default.a" + ) + ) + ) + registerTable( + name = "c", + query = Option( + readStreamFlowFunc( + flowNameForTracking = "cFlow", + tableName = "b", + extraOptions = new CaseInsensitiveStringMap(Map("x" -> "y").asJava) + ) + ) + ) + } + val unresolvedGraph = graphRegistrationContext.toDataflowGraph + + val updateContext = TestPipelineUpdateContext(spark, unresolvedGraph, storageRoot) + updateContext.pipelineExecution.runPipeline() + updateContext.pipelineExecution.awaitCompletion() + + val cFlow = flowFunctionResultTracker.flowFunctionResults.get("cFlow").get + + // Verify the flow function's analyzed DF logical plan contains specified options. + assert(cFlow.dataFrame.get.logicalPlan + .asInstanceOf[SubqueryAlias].child + .asInstanceOf[StreamingRelation].dataSource.options.get("x").contains("y")) + } + } + + test("External pipeline batch read options are propagated during flow function analysis") { + val flowFunctionResultTracker = FlowFunctionResultTracker(MutableMap.empty) + + withTable("spark_catalog.default.a", "b") { + // Create regular external table to batch read from with options. 
+ spark.range(10).write.saveAsTable("spark_catalog.default.a") + + val graphRegistrationContext = + new InstrumentedTestGraphRegistrationContext(spark, flowFunctionResultTracker) { + registerMaterializedView( + name = "b", + query = readFlowFunc( + flowNameForTracking = "bFlow", + tableName = "spark_catalog.default.a", + extraOptions = new CaseInsensitiveStringMap(Map("x" -> "y").asJava) + ) + ) + } + val unresolvedGraph = graphRegistrationContext.toDataflowGraph + + val updateContext = TestPipelineUpdateContext(spark, unresolvedGraph, storageRoot) + updateContext.pipelineExecution.runPipeline() + updateContext.pipelineExecution.awaitCompletion() + + val bFlow = flowFunctionResultTracker.flowFunctionResults.get("bFlow").get + + // Verify the flow function's analyzed DF logical plan contains specified options. + assert(bFlow.dataFrame.get.logicalPlan + .asInstanceOf[SubqueryAlias].child + .asInstanceOf[LogicalRelation].relation + .asInstanceOf[HadoopFsRelation].options.get("x").contains("y")) + } + } + + test("External pipeline stream read options are propagated during flow function analysis") { + val flowFunctionResultTracker = FlowFunctionResultTracker(MutableMap.empty) + + withTable("spark_catalog.default.a", "b") { + // Create regular external table to stream from with read options. + spark.range(10).write.saveAsTable("spark_catalog.default.a") + + val graphRegistrationContext = + new InstrumentedTestGraphRegistrationContext(spark, flowFunctionResultTracker) { + registerTable( + name = "b", + query = Option( + readStreamFlowFunc( + flowNameForTracking = "bFlow", + tableName = "spark_catalog.default.a", + extraOptions = new CaseInsensitiveStringMap(Map("x" -> "y").asJava) + ) + ) + ) + } + val unresolvedGraph = graphRegistrationContext.toDataflowGraph + + val updateContext = TestPipelineUpdateContext(spark, unresolvedGraph, storageRoot) + updateContext.pipelineExecution.runPipeline() + updateContext.pipelineExecution.awaitCompletion() + + val bFlow = flowFunctionResultTracker.flowFunctionResults.get("bFlow").get + + // Verify the flow function's analyzed DF logical plan contains specified options. 
+ assert(bFlow.dataFrame.get.logicalPlan + .asInstanceOf[SubqueryAlias].child + .asInstanceOf[StreamingRelation].dataSource.options.get("x").contains("y")) + } + } +} diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala index a4bb7c067d87..8fff51059952 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala @@ -411,7 +411,7 @@ class ConnectValidPipelineSuite extends PipelineTest with SharedSparkSession { mem.addData(1, 2) registerPersistedView("complete-view", query = dfFlowFunc(Seq(1, 2).toDF("x"))) registerPersistedView("incremental-view", query = dfFlowFunc(mem.toDF())) - registerTable("`complete-table`", query = Option(readFlowFunc("complete-view"))) + registerTable("`complete-table`", query = Option(readFlowFunc("`complete-view`"))) registerTable("`incremental-table`") registerFlow( "`incremental-table`", diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala index d88432d68ca3..605433a60edc 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala @@ -36,6 +36,9 @@ class TestGraphRegistrationContext( defaultSqlConf = sqlConf ) { + /** Expose all registered flows for tests */ + def getFlows: List[UnresolvedFlow] = flows.toList + // scalastyle:off // Disable scalastyle to ignore argument count. /** Registers a streaming table in this [[TestGraphRegistrationContext]] */ @@ -355,19 +358,31 @@ class TestGraphRegistrationContext( /** * Creates a flow function from a logical plan that reads from a table with the given name. */ - def readFlowFunc(name: String): FlowFunction = { - FlowAnalysis.createFlowFunctionFromLogicalPlan(UnresolvedRelation(TableIdentifier(name))) + def readFlowFunc( + name: String, + extraOptions: CaseInsensitiveStringMap = CaseInsensitiveStringMap.empty() + ): FlowFunction = { + FlowAnalysis.createFlowFunctionFromLogicalPlan( + UnresolvedRelation( + tableIdentifier = GraphIdentifierManager.parseTableIdentifier(name, spark), + extraOptions = extraOptions, + isStreaming = false + ) + ) } /** * Creates a flow function from a logical plan that reads a stream from a table with the given * name. */ - def readStreamFlowFunc(name: String): FlowFunction = { + def readStreamFlowFunc( + name: String, + extraOptions: CaseInsensitiveStringMap = CaseInsensitiveStringMap.empty() + ): FlowFunction = { FlowAnalysis.createFlowFunctionFromLogicalPlan( UnresolvedRelation( - TableIdentifier(name), - extraOptions = CaseInsensitiveStringMap.empty(), + tableIdentifier = GraphIdentifierManager.parseTableIdentifier(name, spark), + extraOptions = extraOptions, isStreaming = true ) )
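
A note on the behavior the new tests rely on: options attached to an UnresolvedRelation are now forwarded to the DataFrameReader / DataStreamReader that performs the read, and they surface on the analyzed plan's relation. Below is a minimal standalone sketch of that underlying Spark behavior (not part of the diff above), using only public APIs, a local SparkSession, and a hypothetical table named "src"; the assertions in ReadOptionsPropagationOnAnalysisSuite check the same option propagation on the pipeline flow functions' analyzed plans.

import org.apache.spark.sql.SparkSession

object ReadOptionsPropagationSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("read-options-propagation-sketch")
      .getOrCreate()
    import spark.implicits._

    // Hypothetical source table, used only for illustration.
    Seq(1, 2, 3).toDF("id").write.mode("overwrite").saveAsTable("src")

    // Batch read: options set on the DataFrameReader are carried into the relation options of
    // the analyzed plan (the batch tests above check HadoopFsRelation.options for "x" -> "y").
    val batchDf = spark.read.option("x", "y").table("src")
    batchDf.show()

    // Streaming read: options set on the DataStreamReader end up on the StreamingRelation's
    // DataSource options (what the streaming tests above assert).
    val streamingDf = spark.readStream.option("x", "y").table("src")
    assert(streamingDf.isStreaming)

    spark.stop()
  }
}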