diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/WidenStatefulOperatorAttributeNullability.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/WidenStatefulOperatorAttributeNullability.scala new file mode 100644 index 0000000000000..f315408033148 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/WidenStatefulOperatorAttributeNullability.scala @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, ExprId} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType + +/** + * Shared helpers for the stateful-operator nullability fix. 
The fix has three
+ * independent components, all gated by
+ * [[SQLConf.STATEFUL_OPERATOR_ALWAYS_NULLABLE_OUTPUT]] (pinned per-query via the
+ * offset log so existing queries keep their pre-fix behavior on restart):
+ *
+ *  - (a) `widenStateSchema`: explicit `asNullable` at every state-schema construction
+ *    site in each stateful physical exec.
+ *  - (b) `widenOutputForStatefulOp`: a per-op `output` override on every stateful logical
+ *    and physical operator, used by the operator's `output` definition.
+ *  - (c) [[WidenStatefulOperatorAttributeNullability]] (defined below in this file): a
+ *    custom optimizer rule that widens `AttributeReference`s inside stateful ops'
+ *    internal expressions and propagates upward to ancestor expressions.
+ */
+object WidenStatefulOpNullability {
+
+  /** Whether the fix is enabled, read from the active [[SQLConf]] on every call. */
+  def isEnabled: Boolean =
+    SQLConf.get.getConf(SQLConf.STATEFUL_OPERATOR_ALWAYS_NULLABLE_OUTPUT)
+
+  /**
+   * Recursively widens an attribute to be fully nullable: outer `nullable = true` plus
+   * every nested `StructField.nullable`, `ArrayType.containsNull`, and
+   * `MapType.valueContainsNull` flipped to `true` via
+   * [[org.apache.spark.sql.types.DataType#asNullable]].
+   */
+  def deepWidenAttribute(a: Attribute): Attribute = a match {
+    case ref: AttributeReference =>
+      AttributeReference(
+        ref.name, ref.dataType.asNullable, nullable = true, ref.metadata)(
+        ref.exprId, ref.qualifier)
+    case other => other.withNullability(true)
+  }
+
+  /**
+   * Component (a): widens a state schema to fully nullable. Stateful physical execs apply
+   * this at every `validateAndMaybeEvolveStateSchema(...)` call site and every
+   * `mapPartitionsWith*StateStore(...)` call site. When the conf is off, returns the
+   * schema unchanged.
+   */
+  def widenStateSchema(schema: StructType): StructType =
+    if (isEnabled) schema.asNullable else schema
+
+  /**
+   * Component (b): wraps a stateful operator's `output` to be fully nullable. The caller
+   * is responsible for only calling this from within an `output` definition on a stateful
+   * operator; gating is handled here via [[isEnabled]].
+   */
+  def widenOutputForStatefulOp(base: Seq[Attribute]): Seq[Attribute] =
+    if (isEnabled) base.map(deepWidenAttribute) else base
+
+  /**
+   * Recursively walks a schema and replaces any nested `StructType` that
+   * structurally matches `original` (same field names and same field types at
+   * every nesting level, ignoring nullability) with `widened`. Used by
+   * TransformWithState execs to widen the grouping-key portion of col-family key
+   * schemas without touching user-defined key/value portions. No-op when the fix is off.
+   */
+  def widenGroupingKeyInSchema(
+      schema: StructType,
+      original: StructType,
+      widened: StructType): StructType = {
+    if (!isEnabled) {
+      schema
+    } else if (structurallyMatches(schema, original)) {
+      widened
+    } else {
+      StructType(schema.fields.map { field =>
+        field.dataType match {
+          // The recursive call itself returns `widened` for a matching nested struct.
+          case st: StructType =>
+            field.copy(dataType = widenGroupingKeyInSchema(st, original, widened))
+          case _ => field
+        }
+      })
+    }
+  }
+
+  // Same field names and, via a deep comparison that ignores nullability, same field types.
+  private def structurallyMatches(a: StructType, b: StructType): Boolean = {
+    a.length == b.length && a.zip(b).forall { case (fa, fb) =>
+      fa.name == fb.name &&
+        org.apache.spark.sql.types.DataType.equalsIgnoreNullability(fa.dataType, fb.dataType)
+    }
+  }
+}
+
+/**
+ * Component (c) of the stateful-operator nullability fix: a custom optimizer rule that
+ * widens `AttributeReference`s inside streaming-stateful operators' internal expressions
+ * and propagates the widening upward to ancestor operators' expressions.
+ *
+ * The rule does NOT introduce any new logical or physical node.
It is purely an
+ * attribute-rewrite pass using `resolveOperatorsUp` (bottom-up): for every node whose
+ * subtree contains a stateful operator, collect the widenable `exprId`s (see below), then
+ * deep-widen every `AttributeReference` in the node's expressions whose `exprId` is in
+ * that set via [[WidenStatefulOpNullability#deepWidenAttribute]].
+ *
+ * At a stateful operator itself, the operator's own `output` and all children's output
+ * attributes are included because its internal expressions (e.g. grouping keys) reference them.
+ * At non-stateful ancestor operators, only children whose subtrees contain a stateful
+ * operator are included, to avoid unnecessary widening of non-stateful siblings.
+ * The node's own `p.output` is not needed for non-stateful ancestors because the
+ * bottom-up traversal guarantees children are already transformed, so their output
+ * attributes are already nullable and the ancestor's expressions reference those
+ * children's `exprId`s.
+ *
+ * '''Scope.''' The walk only fires on nodes whose subtree contains a stateful operator.
+ *
+ * '''Ordering constraint.''' This rule must run AFTER every `UpdateAttributeNullability`
+ * invocation in both the main optimizer and AQE.
+ *
+ * '''Idempotence.''' [[WidenStatefulOpNullability#deepWidenAttribute]] is idempotent.
+ */
+object WidenStatefulOperatorAttributeNullability extends Rule[LogicalPlan] {
+
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    val enabled = conf.getConf(SQLConf.STATEFUL_OPERATOR_ALWAYS_NULLABLE_OUTPUT)
+    if (enabled && plan.containsStatefulOperator) widenBottomUp(plan) else plan
+  }
+
+  // Bottom-up rewrite; unresolved nodes, leaves and stateless subtrees are returned as-is.
+  // NOTE(review): `resolveOperatorsUp` is documented to skip subtrees already marked as
+  // analyzed - confirm that is the intended behavior for a rule run by the optimizer and AQE.
+  private def widenBottomUp(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
+    case node if !node.resolved => node
+    case leaf: LeafNode => leaf
+    case node if !node.containsStatefulOperator => node
+    case node =>
+      // A stateful node contributes its own output plus its children's; any other node only
+      // contributes the output of children whose subtree contains a stateful operator.
+      val sources =
+        if (node.isStateful) node +: node.children
+        else node.children.filter(_.containsStatefulOperator)
+      val targetIds: Set[ExprId] =
+        sources.flatMap(_.output).collect { case ref: AttributeReference => ref.exprId }.toSet
+      if (targetIds.isEmpty) {
+        node
+      } else {
+        node.transformExpressions {
+          case ref: AttributeReference if targetIds(ref.exprId) =>
+            val widened = WidenStatefulOpNullability.deepWidenAttribute(ref)
+            // Keep the original instance when widening would not change anything.
+            val alreadyWide = ref.dataType == widened.dataType && ref.nullable == widened.nullable
+            if (alreadyWide) ref else widened
+        }
+      }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index a7ad11848c3f5..9184c5ef412b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.sql.catalyst.{AliasIdentifier, InternalRow, SQLConfHelper}
-import org.apache.spark.sql.catalyst.analysis.{Analyzer, AnsiTypeCoercion, MultiInstanceRelation, Resolver, TypeCoercion, TypeCoercionBase, UnresolvedUnaryNode}
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, AnsiTypeCoercion, MultiInstanceRelation, Resolver, TypeCoercion, TypeCoercionBase, UnresolvedUnaryNode,
WidenStatefulOpNullability} import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable} import org.apache.spark.sql.catalyst.catalog.CatalogTable.VIEW_STORING_ANALYZED_PLAN import org.apache.spark.sql.catalyst.expressions._ @@ -746,7 +746,10 @@ case class Join( } } - override def output: Seq[Attribute] = Join.computeOutput(joinType, left.output, right.output) + override def output: Seq[Attribute] = { + val base = Join.computeOutput(joinType, left.output, right.output) + if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) else base + } override def metadataOutput: Seq[Attribute] = { joinType match { @@ -1225,7 +1228,10 @@ case class Aggregate( expressions.forall(_.resolved) && childrenResolved && !hasWindowExpressions } - override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) + override def output: Seq[Attribute] = { + val base = aggregateExpressions.map(_.toAttribute) + if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) else base + } override def metadataOutput: Seq[Attribute] = Nil override def maxRows: Option[Long] = { if (groupingExpressions.isEmpty) { @@ -1749,7 +1755,10 @@ object Limit { * order. 
*/ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output + override def output: Seq[Attribute] = { + val base = child.output + if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) else base + } override def maxRows: Option[Long] = { limitExpr match { case IntegerLiteral(limit) => Some(limit) @@ -2004,7 +2013,10 @@ case class Sample( */ case class Distinct(child: LogicalPlan) extends UnaryNode { override def maxRows: Option[Long] = child.maxRows - override def output: Seq[Attribute] = child.output + override def output: Seq[Attribute] = { + val base = child.output + if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) else base + } final override val nodePatterns: Seq[TreePattern] = Seq(DISTINCT_LIKE) override protected def withNewChildInternal(newChild: LogicalPlan): Distinct = copy(child = newChild) @@ -2174,7 +2186,10 @@ case class Deduplicate( keys: Seq[Attribute], child: LogicalPlan) extends UnaryNode { override def maxRows: Option[Long] = child.maxRows - override def output: Seq[Attribute] = child.output + override def output: Seq[Attribute] = { + val base = child.output + if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) else base + } final override val nodePatterns: Seq[TreePattern] = Seq(DISTINCT_LIKE) override protected def withNewChildInternal(newChild: LogicalPlan): Deduplicate = copy(child = newChild) @@ -2186,7 +2201,10 @@ case class DeduplicateWithinWatermark(keys: Seq[Attribute], child: LogicalPlan) override def references: AttributeSet = AttributeSet(keys) ++ AttributeSet(child.output.filter(_.metadata.contains(EventTimeWatermark.delayKey))) override def maxRows: Option[Long] = child.maxRows - override def output: Seq[Attribute] = child.output + override def output: Seq[Attribute] = { + val base = child.output + if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) else base + } 
final override val nodePatterns: Seq[TreePattern] = Seq(DISTINCT_LIKE) override protected def withNewChildInternal(newChild: LogicalPlan): DeduplicateWithinWatermark = copy(child = newChild) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 0c6f59073559f..720b0dd640d00 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.api.java.function.FilterFunction import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.{catalyst, Encoder, Row} -import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedDeserializer} +import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedDeserializer, WidenStatefulOpNullability} import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke @@ -568,6 +568,11 @@ case class FlatMapGroupsWithState( newLeft: LogicalPlan, newRight: LogicalPlan): FlatMapGroupsWithState = copy(child = newLeft, initialState = newRight) override def isStateful: Boolean = child.isStreaming + + override def output: Seq[Attribute] = { + val base = super.output + if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) else base + } } object TransformWithState { @@ -657,6 +662,11 @@ case class TransformWithState( newLeft: LogicalPlan, newRight: LogicalPlan): TransformWithState = copy(child = newLeft, initialState = newRight) override def isStateful: Boolean = child.isStreaming + + override def output: Seq[Attribute] = { + val base = super.output + if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) else base + } } /** Factory for 
constructing new `FlatMapGroupsInR` nodes. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala index 56dc2f6de0437..31e7d94029687 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.resource.ResourceProfile import org.apache.spark.sql.catalyst.SQLConfHelper -import org.apache.spark.sql.catalyst.analysis.{FunctionRegistryBase, MultiInstanceRelation, UnresolvedAttribute, UnresolvedStar} +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistryBase, MultiInstanceRelation, UnresolvedAttribute, UnresolvedStar, WidenStatefulOpNullability} import org.apache.spark.sql.catalyst.analysis.TableFunctionRegistry.TableFunctionBuilder import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Expression, ExpressionDescription, ExpressionInfo, JsonToStructs, PythonUDF, PythonUDTF} import org.apache.spark.sql.catalyst.trees.TreePattern._ @@ -159,7 +159,9 @@ case class FlatMapGroupsInPandasWithState( timeout: GroupStateTimeout, child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = outputAttrs + override def output: Seq[Attribute] = + if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(outputAttrs) + else outputAttrs override def producedAttributes: AttributeSet = AttributeSet(outputAttrs) @@ -206,7 +208,9 @@ case class TransformWithStateInPySpark( override def right: LogicalPlan = initialState - override def output: Seq[Attribute] = outputAttrs + override def output: Seq[Attribute] = + if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(outputAttrs) + else outputAttrs 
override def producedAttributes: AttributeSet = AttributeSet(outputAttrs) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 270b8aa31a565..25e4c5134b0e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -3440,6 +3440,24 @@ object SQLConf { .booleanConf .createWithDefault(true) + val STATEFUL_OPERATOR_ALWAYS_NULLABLE_OUTPUT = + buildConf("spark.sql.streaming.statefulOperator.alwaysNullableOutput.enabled") + .internal() + .withBindingPolicy(ConfigBindingPolicy.SESSION) + .doc("When true, every streaming stateful operator reports its output schema with " + + "nullable=true on all columns (including nested struct fields, array elements, and " + + "map values), and the state schema is widened at every construction site, so the " + + "existing state schema " + + "compatibility check trivially passes regardless of input nullability. " + + "This prevents query-optimizer decisions (e.g., PropagateEmptyRelation dropping a " + + "Union branch) from flipping the state schema nullability across microbatches or " + + "restarts. 
The effective value is pinned per query via the offset log at batch 0, " + + "so pre-existing queries keep their original behavior; only newly started queries " + + "pick this up.") + .version("4.3.0") + .booleanConf + .createWithDefault(true) + val FILESTREAM_SINK_METADATA_IGNORED = buildConf("spark.sql.streaming.fileStreamSink.ignoreMetadata") .internal() diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/ClientStreamingQuerySuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/ClientStreamingQuerySuite.scala index c8a25652dacbd..057e2fdc47750 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/ClientStreamingQuerySuite.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/ClientStreamingQuerySuite.scala @@ -86,7 +86,7 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L .count() .selectExpr("window.start as timestamp", "count as num_events") - assert(countsDF.schema.toDDL == "timestamp TIMESTAMP,num_events BIGINT NOT NULL") + assert(countsDF.schema.toDDL == "timestamp TIMESTAMP,num_events BIGINT") // Start the query val queryName = "sparkConnectStreamingQuery" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala index f16c6d9cfe6dd..3c23930090ab1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.adaptive import org.apache.spark.internal.LogKeys.{BATCH_NAME, RULE_NAME} -import org.apache.spark.sql.catalyst.analysis.UpdateAttributeNullability +import org.apache.spark.sql.catalyst.analysis.{UpdateAttributeNullability, WidenStatefulOperatorAttributeNullability} import 
org.apache.spark.sql.catalyst.optimizer.{ConvertToLocalRelation, EliminateLimits, OptimizeOneRowPlan} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LogicalPlanIntegrity} import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} @@ -44,7 +44,8 @@ class AQEOptimizer(conf: SQLConf, extendedRuntimeOptimizerRules: Seq[Rule[Logica Batch("Dynamic Join Selection", Once, DynamicJoinSelection), Batch("Eliminate Limits", fixedPoint, EliminateLimits), Batch("Optimize One Row Plan", fixedPoint, OptimizeOneRowPlan)) :+ - Batch("User Provided Runtime Optimizers", fixedPoint, extendedRuntimeOptimizerRules: _*) + Batch("User Provided Runtime Optimizers", fixedPoint, extendedRuntimeOptimizerRules: _*) :+ + Batch("Widen Stateful Op Nullability", Once, WidenStatefulOperatorAttributeNullability) final override protected def batches: Seq[Batch] = { val excludedRules = conf.getConf(SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/FlatMapGroupsInPandasWithStateExec.scala index e9430ed9f9b7f..a61f905158361 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/FlatMapGroupsInPandasWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/FlatMapGroupsInPandasWithStateExec.scala @@ -20,6 +20,7 @@ import org.apache.spark.{JobArtifactSet, SparkException, SparkUnsupportedOperati import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.WidenStatefulOpNullability import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, ProcessingTimeTimeout} 
@@ -81,7 +82,8 @@ case class FlatMapGroupsInPandasWithStateExec( override protected val stateEncoder: ExpressionEncoder[Any] = ExpressionEncoder(stateType).resolveAndBind().asInstanceOf[ExpressionEncoder[Any]] - override def output: Seq[Attribute] = outAttributes + override def output: Seq[Attribute] = + WidenStatefulOpNullability.widenOutputForStatefulOp(outAttributes) private val sessionLocalTimeZone = conf.sessionLocalTimeZone private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala index 45f2af5c1dfe8..16f7e1232463b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.WidenStatefulOpNullability import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, PythonUDF} import org.apache.spark.sql.catalyst.plans.logical.TransformWithStateInPySpark @@ -39,7 +40,7 @@ import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOper import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper.StateStoreAwareZipPartitionsHelper import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{TransformWithStateExecBase, TransformWithStateVariableInfo} import 
org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.statefulprocessor.{DriverStatefulProcessorHandleImpl, StatefulProcessorHandleImpl} -import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, RocksDBStateStoreProvider, StateSchemaValidationResult, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreOps, StateStoreProvider, StateStoreProviderId} +import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, RangeKeyScanStateEncoderSpec, RocksDBStateStoreProvider, StateSchemaValidationResult, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreOps, StateStoreProvider, StateStoreProviderId, TimestampAsPostfixKeyStateEncoderSpec, TimestampAsPrefixKeyStateEncoderSpec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.{OutputMode, TimeMode} import org.apache.spark.sql.types.{BinaryType, StructField, StructType} @@ -51,7 +52,7 @@ import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, Uti * * @param functionExpr function called on each group * @param groupingAttributes used to group the data - * @param output used to define the output rows + * @param outputAttrs used to define the output rows * @param outputMode defines the output mode for the statefulProcessor * @param timeMode The time mode semantics of the stateful processor for timers and TTL. * @param stateInfo Used to identify the state store for a given operator. 
@@ -69,7 +70,7 @@ import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, Uti
 case class TransformWithStateInPySparkExec(
     functionExpr: Expression,
     groupingAttributes: Seq[Attribute],
-    output: Seq[Attribute],
+    outputAttrs: Seq[Attribute],
     outputMode: OutputMode,
     timeMode: TimeMode,
     stateInfo: Option[StatefulOperatorStateInfo],
@@ -94,6 +95,9 @@ case class TransformWithStateInPySparkExec(
     initialStateGroupingAttrs,
     initialState) {
 
+  override def output: Seq[Attribute] =
+    WidenStatefulOpNullability.widenOutputForStatefulOp(outputAttrs)
+
   // NOTE: This is needed to comply with existing release of transformWithStateInPandas.
   override def shortName: String = if (
     userFacingDataType == TransformWithStateInPySpark.UserFacingDataType.PANDAS
@@ -127,14 +131,55 @@ case class TransformWithStateInPySparkExec(
   // Each state variable has its own schema, this is a dummy one.
   protected val schemaForValueRow: StructType = new StructType().add("value", BinaryType)
 
+  private lazy val widenedGroupingKeySchema: StructType =
+    WidenStatefulOpNullability.widenStateSchema(groupingKeySchema)
+
   override def getColFamilySchemas(
       shouldBeNullable: Boolean): Map[String, StateStoreColFamilySchema] = {
     // For Python, the user can explicitly set nullability on schema, so
     // we need to throw an error if the schema is nullable
-    driverProcessorHandle.getColumnFamilySchemas(
+    val schemas = driverProcessorHandle.getColumnFamilySchemas(
       shouldCheckNullable = shouldBeNullable,
       shouldSetNullable = shouldBeNullable
     )
+    widenColFamilyGroupingKeys(schemas)
+  }
+
+  /**
+   * Widens every column family schema: the grouping-key portion of the key schema and of the
+   * key encoder spec, plus the whole value schema. Returns `schemas` unchanged when the fix is
+   * disabled. Widening is not skipped when the grouping key is already fully nullable, so the
+   * value schema's nullability never depends on the key's.
+   */
+  private def widenColFamilyGroupingKeys(
+      schemas: Map[String, StateStoreColFamilySchema])
+    : Map[String, StateStoreColFamilySchema] = {
+    if (!WidenStatefulOpNullability.isEnabled) {
+      schemas
+    } else {
+      val original = groupingKeySchema
+      val widened = widenedGroupingKeySchema
+      def widenKey(ks: StructType): StructType =
+        WidenStatefulOpNullability.widenGroupingKeyInSchema(ks, original, widened)
+      schemas.map { case (name, cf) =>
+        val widenedSpec = cf.keyStateEncoderSpec.map {
+          case NoPrefixKeyStateEncoderSpec(ks) =>
+            NoPrefixKeyStateEncoderSpec(widenKey(ks))
+          case PrefixKeyScanStateEncoderSpec(ks, n) =>
+            PrefixKeyScanStateEncoderSpec(widenKey(ks), n)
+          case RangeKeyScanStateEncoderSpec(ks, o) =>
+            RangeKeyScanStateEncoderSpec(widenKey(ks), o)
+          case TimestampAsPrefixKeyStateEncoderSpec(ks) =>
+            TimestampAsPrefixKeyStateEncoderSpec(widenKey(ks))
+          case TimestampAsPostfixKeyStateEncoderSpec(ks) =>
+            TimestampAsPostfixKeyStateEncoderSpec(widenKey(ks))
+        }
+        name -> cf.copy(
+          keySchema = widenKey(cf.keySchema),
+          valueSchema = WidenStatefulOpNullability.widenStateSchema(cf.valueSchema),
+          keyStateEncoderSpec = widenedSpec)
+      }
+    }
   }
 
   override def getStateVariableInfos(): Map[String, TransformWithStateVariableInfo] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala
index bf2278b814922..9ba99ac2c036e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala
@@ -204,7 +204,8 @@ object OffsetSeqMetadata extends Logging {
     STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION, PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN,
     STREAMING_STATE_STORE_ENCODING_FORMAT, STATE_STORE_ROW_CHECKSUM_ENABLED,
     PROTOBUF_EXTENSIONS_SUPPORT_ENABLED,
-    ENABLE_STREAMING_SOURCE_EVOLUTION
+    ENABLE_STREAMING_SOURCE_EVOLUTION,
+    STATEFUL_OPERATOR_ALWAYS_NULLABLE_OUTPUT
   )
 
   /**
@@ -254,7 +255,8 @@ object OffsetSeqMetadata extends Logging {
     STATE_STORE_ROW_CHECKSUM_ENABLED.key -> "false",
     STATE_STORE_ROCKSDB_MERGE_OPERATOR_VERSION.key -> "1",
     PROTOBUF_EXTENSIONS_SUPPORT_ENABLED.key -> "false",
-    ENABLE_STREAMING_SOURCE_EVOLUTION.key -> "false"
+    ENABLE_STREAMING_SOURCE_EVOLUTION.key -> "false",
+    STATEFUL_OPERATOR_ALWAYS_NULLABLE_OUTPUT.key ->
"false" ) def readValue[T](metadataLog: OffsetSeqMetadataBase, confKey: ConfigEntry[T]): String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/flatmapgroupswithstate/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/flatmapgroupswithstate/FlatMapGroupsWithStateExec.scala index 6b9f90a9ab5ca..48d1dad70f5e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/flatmapgroupswithstate/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/flatmapgroupswithstate/FlatMapGroupsWithStateExec.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.{SparkException, SparkThrowable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.WidenStatefulOpNullability import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical._ @@ -36,6 +37,7 @@ import org.apache.spark.sql.execution.streaming.operators.stateful.join.Streamin import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} import org.apache.spark.sql.streaming.GroupStateTimeout.NoTimeout +import org.apache.spark.sql.types.StructType import org.apache.spark.util.{CompletionIterator, SerializableConfiguration} /** @@ -72,6 +74,11 @@ trait FlatMapGroupsWithStateExecBase lazy val stateManager: StateManager = createStateManager(stateEncoder, isTimeoutEnabled, stateFormatVersion) + private lazy val stateKeySchema: StructType = + WidenStatefulOpNullability.widenStateSchema(groupingAttributes.toStructType) + private lazy val stateValueSchema: StructType = + 
WidenStatefulOpNullability.widenStateSchema(stateManager.stateSchema) + /** * Distribute by grouping attributes - We need the underlying data and the initial state data * to have the same grouping so that the data are co-lacated on the same task. @@ -200,7 +207,7 @@ trait FlatMapGroupsWithStateExecBase batchId: Long, stateSchemaVersion: Int): List[StateSchemaValidationResult] = { val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, 0, - groupingAttributes.toStructType, 0, stateManager.stateSchema)) + stateKeySchema, 0, stateValueSchema)) List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, newStateSchema, session.sessionState, stateSchemaVersion)) } @@ -243,9 +250,9 @@ trait FlatMapGroupsWithStateExecBase val storeProviderId = StateStoreProviderId(stateStoreId, stateInfo.get.queryRunId) val store = StateStore.get( storeProviderId, - groupingAttributes.toStructType, - stateManager.stateSchema, - NoPrefixKeyStateEncoderSpec(groupingAttributes.toStructType), + stateKeySchema, + stateValueSchema, + NoPrefixKeyStateEncoderSpec(stateKeySchema), stateInfo.get.storeVersion, stateInfo.get.getStateStoreCkptId(partitionId).map(_.head), None, @@ -257,9 +264,9 @@ trait FlatMapGroupsWithStateExecBase } else { child.execute().mapPartitionsWithStateStore[InternalRow]( getStateInfo, - groupingAttributes.toStructType, - stateManager.stateSchema, - NoPrefixKeyStateEncoderSpec(groupingAttributes.toStructType), + stateKeySchema, + stateValueSchema, + NoPrefixKeyStateEncoderSpec(stateKeySchema), session.sessionState, Some(session.streams.stateStoreCoordinator) ) { case (store: StateStore, singleIterator: Iterator[InternalRow]) => @@ -425,6 +432,9 @@ case class FlatMapGroupsWithStateExec( skipEmittingInitialStateKeys: Boolean, child: SparkPlan) extends FlatMapGroupsWithStateExecBase with BinaryExecNode with ObjectProducerExec { + override def output: Seq[Attribute] = + 
WidenStatefulOpNullability.widenOutputForStatefulOp(super.output) + import GroupStateImpl._ import FlatMapGroupsWithStateExecHelper._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala index 9eca04c985913..8f90a603c7efb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala @@ -24,6 +24,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper +import org.apache.spark.sql.catalyst.analysis.WidenStatefulOpNullability import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, GenericInternalRow, JoinedRow, Literal, Predicate, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ @@ -231,13 +232,16 @@ case class StreamingSymmetricHashJoinExec( StatefulOpClusteredDistribution(leftKeys, getStateInfo.numPartitions) :: StatefulOpClusteredDistribution(rightKeys, getStateInfo.numPartitions) :: Nil - override def output: Seq[Attribute] = joinType match { - case _: InnerLike => left.output ++ right.output - case LeftOuter => left.output ++ right.output.map(_.withNullability(true)) - case RightOuter => left.output.map(_.withNullability(true)) ++ right.output - case FullOuter => (left.output ++ right.output).map(_.withNullability(true)) - case LeftSemi => left.output - case _ => throwBadJoinTypeException() + override def output: Seq[Attribute] = { + val base = joinType match { + case _: InnerLike => left.output ++ 
right.output + case LeftOuter => left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => (left.output ++ right.output).map(_.withNullability(true)) + case LeftSemi => left.output + case _ => throwBadJoinTypeException() + } + WidenStatefulOpNullability.widenOutputForStatefulOp(base) } override def outputPartitioning: Partitioning = joinType match { @@ -279,11 +283,16 @@ case class StreamingSymmetricHashJoinExec( override def getColFamilySchemas( shouldBeNullable: Boolean): Map[String, StateStoreColFamilySchema] = { assert(useVirtualColumnFamilies) - // We only have one state store for the join, but there are four distinct schemas - SymmetricHashJoinStateManager + val raw = SymmetricHashJoinStateManager .getSchemasForStateStoreWithColFamily(LeftSide, left.output, leftKeys, stateFormatVersion) ++ - SymmetricHashJoinStateManager - .getSchemasForStateStoreWithColFamily(RightSide, right.output, rightKeys, stateFormatVersion) + SymmetricHashJoinStateManager + .getSchemasForStateStoreWithColFamily( + RightSide, right.output, rightKeys, stateFormatVersion) + raw.map { case (name, cf) => + name -> cf.copy( + keySchema = WidenStatefulOpNullability.widenStateSchema(cf.keySchema), + valueSchema = WidenStatefulOpNullability.widenStateSchema(cf.valueSchema)) + } } override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = { @@ -328,7 +337,8 @@ case class StreamingSymmetricHashJoinExec( // we have to add the default column family schema because the RocksDBStateEncoder // expects this entry to be present in the stateSchemaProvider. 
val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, 0, - keySchema, 0, valueSchema)) + WidenStatefulOpNullability.widenStateSchema(keySchema), 0, + WidenStatefulOpNullability.widenStateSchema(valueSchema))) StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, newStateSchema, session.sessionState, stateSchemaVersion, storeName = stateStoreName) }.toList diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala index 59a2b9ee74f85..022fa3469eea5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala @@ -31,6 +31,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.WidenStatefulOpNullability import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers @@ -767,11 +768,16 @@ case class StateStoreRestoreExec( private[sql] val stateManager = StreamingAggregationStateManager.createStateManager( keyExpressions, child.output, stateFormatVersion) + private val stateKeySchema: StructType = + WidenStatefulOpNullability.widenStateSchema(keyExpressions.toStructType) + private val stateValueSchema: StructType = + WidenStatefulOpNullability.widenStateSchema(stateManager.getStateValueSchema) + override def validateAndMaybeEvolveStateSchema( hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int): List[StateSchemaValidationResult] = 
{ val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, - 0, keyExpressions.toStructType, 0, stateManager.getStateValueSchema)) + 0, stateKeySchema, 0, stateValueSchema)) List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, newStateSchema, session.sessionState, stateSchemaVersion)) } @@ -781,9 +787,9 @@ case class StateStoreRestoreExec( child.execute().mapPartitionsWithReadStateStore( getStateInfo, - keyExpressions.toStructType, - stateManager.getStateValueSchema, - NoPrefixKeyStateEncoderSpec(keyExpressions.toStructType), + stateKeySchema, + stateValueSchema, + NoPrefixKeyStateEncoderSpec(stateKeySchema), session.sessionState, Some(session.streams.stateStoreCoordinator)) { case (store, iter) => val hasInput = iter.hasNext @@ -805,7 +811,8 @@ case class StateStoreRestoreExec( } } - override def output: Seq[Attribute] = child.output + override def output: Seq[Attribute] = + WidenStatefulOpNullability.widenOutputForStatefulOp(child.output) override def outputPartitioning: Partitioning = child.outputPartitioning @@ -838,13 +845,18 @@ case class StateStoreSaveExec( private[sql] val stateManager = StreamingAggregationStateManager.createStateManager( keyExpressions, child.output, stateFormatVersion) + private val stateKeySchema: StructType = + WidenStatefulOpNullability.widenStateSchema(keyExpressions.toStructType) + private val stateValueSchema: StructType = + WidenStatefulOpNullability.widenStateSchema(stateManager.getStateValueSchema) + override def validateAndMaybeEvolveStateSchema( hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int): List[StateSchemaValidationResult] = { val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, - keySchemaId = 0, keyExpressions.toStructType, valueSchemaId = 0, - stateManager.getStateValueSchema)) + keySchemaId = 0, stateKeySchema, valueSchemaId = 0, + stateValueSchema)) 
List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, newStateSchema, session.sessionState, stateSchemaVersion)) } @@ -856,9 +868,9 @@ case class StateStoreSaveExec( child.execute().mapPartitionsWithStateStore( getStateInfo, - keyExpressions.toStructType, - stateManager.getStateValueSchema, - NoPrefixKeyStateEncoderSpec(keyExpressions.toStructType), + stateKeySchema, + stateValueSchema, + NoPrefixKeyStateEncoderSpec(stateKeySchema), session.sessionState, Some(session.streams.stateStoreCoordinator)) { (store, iter) => val numOutputRows = longMetric("numOutputRows") @@ -1000,7 +1012,8 @@ case class StateStoreSaveExec( } } - override def output: Seq[Attribute] = child.output + override def output: Seq[Attribute] = + WidenStatefulOpNullability.widenOutputForStatefulOp(child.output) override def outputPartitioning: Partitioning = child.outputPartitioning @@ -1054,12 +1067,17 @@ case class SessionWindowStateStoreRestoreExec( private val stateManager = StreamingSessionWindowStateManager.createStateManager( keyWithoutSessionExpressions, sessionExpression, child.output, stateFormatVersion) + private val stateKeySchema: StructType = + WidenStatefulOpNullability.widenStateSchema(stateManager.getStateKeySchema) + private val stateValueSchema: StructType = + WidenStatefulOpNullability.widenStateSchema(stateManager.getStateValueSchema) + override def validateAndMaybeEvolveStateSchema( hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int): List[StateSchemaValidationResult] = { val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, - keySchemaId = 0, stateManager.getStateKeySchema, valueSchemaId = 0, - stateManager.getStateValueSchema)) + keySchemaId = 0, stateKeySchema, valueSchemaId = 0, + stateValueSchema)) List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, newStateSchema, session.sessionState, stateSchemaVersion)) } @@ -1069,9 +1087,9 @@ case 
class SessionWindowStateStoreRestoreExec( child.execute().mapPartitionsWithReadStateStore( getStateInfo, - stateManager.getStateKeySchema, - stateManager.getStateValueSchema, - PrefixKeyScanStateEncoderSpec(stateManager.getStateKeySchema, + stateKeySchema, + stateValueSchema, + PrefixKeyScanStateEncoderSpec(stateKeySchema, stateManager.getNumColsForPrefixKey), session.sessionState, Some(session.streams.stateStoreCoordinator)) { case (store, iter) => @@ -1099,7 +1117,8 @@ case class SessionWindowStateStoreRestoreExec( } } - override def output: Seq[Attribute] = child.output + override def output: Seq[Attribute] = + WidenStatefulOpNullability.widenOutputForStatefulOp(child.output) override def outputPartitioning: Partitioning = child.outputPartitioning @@ -1147,11 +1166,16 @@ case class SessionWindowStateStoreSaveExec( private val stateManager = StreamingSessionWindowStateManager.createStateManager( keyWithoutSessionExpressions, sessionExpression, child.output, stateFormatVersion) + private val stateKeySchema: StructType = + WidenStatefulOpNullability.widenStateSchema(stateManager.getStateKeySchema) + private val stateValueSchema: StructType = + WidenStatefulOpNullability.widenStateSchema(stateManager.getStateValueSchema) + override def validateAndMaybeEvolveStateSchema( hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int): List[StateSchemaValidationResult] = { val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, 0, - stateManager.getStateKeySchema, 0, stateManager.getStateValueSchema)) + stateKeySchema, 0, stateValueSchema)) List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, newStateSchema, session.sessionState, stateSchemaVersion)) } @@ -1165,9 +1189,9 @@ case class SessionWindowStateStoreSaveExec( child.execute().mapPartitionsWithStateStore( getStateInfo, - stateManager.getStateKeySchema, - stateManager.getStateValueSchema, - 
PrefixKeyScanStateEncoderSpec(stateManager.getStateKeySchema, + stateKeySchema, + stateValueSchema, + PrefixKeyScanStateEncoderSpec(stateKeySchema, stateManager.getNumColsForPrefixKey), session.sessionState, Some(session.streams.stateStoreCoordinator)) { case (store, iter) => @@ -1251,7 +1275,8 @@ case class SessionWindowStateStoreSaveExec( } } - override def output: Seq[Attribute] = child.output + override def output: Seq[Attribute] = + WidenStatefulOpNullability.widenOutputForStatefulOp(child.output) override def outputPartitioning: Partitioning = child.outputPartitioning @@ -1355,14 +1380,19 @@ abstract class BaseStreamingDeduplicateExec protected val schemaForValueRow: StructType protected val extraOptionOnStateStore: Map[String, String] + protected lazy val stateKeySchema: StructType = + WidenStatefulOpNullability.widenStateSchema(keyExpressions.toStructType) + protected lazy val stateValueSchema: StructType = + WidenStatefulOpNullability.widenStateSchema(schemaForValueRow) + override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver child.execute().mapPartitionsWithStateStore( getStateInfo, - keyExpressions.toStructType, - schemaForValueRow, - NoPrefixKeyStateEncoderSpec(keyExpressions.toStructType), + stateKeySchema, + stateValueSchema, + NoPrefixKeyStateEncoderSpec(stateKeySchema), session.sessionState, Some(session.streams.stateStoreCoordinator), extraOptions = extraOptionOnStateStore) { (store, iter) => @@ -1422,7 +1452,8 @@ abstract class BaseStreamingDeduplicateExec protected def evictDupInfoFromState(store: StateStore): Unit - override def output: Seq[Attribute] = child.output + override def output: Seq[Attribute] = + WidenStatefulOpNullability.widenOutputForStatefulOp(child.output) override def outputPartitioning: Partitioning = child.outputPartitioning @@ -1476,7 +1507,7 @@ case class StreamingDeduplicateExec( hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int): List[StateSchemaValidationResult] = { 
val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, 0, - keyExpressions.toStructType, 0, schemaForValueRow)) + stateKeySchema, 0, stateValueSchema)) List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, newStateSchema, session.sessionState, stateSchemaVersion, extraOptions = extraOptionOnStateStore)) @@ -1562,7 +1593,7 @@ case class StreamingDeduplicateWithinWatermarkExec( hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int): List[StateSchemaValidationResult] = { val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, 0, - keyExpressions.toStructType, 0, schemaForValueRow)) + stateKeySchema, 0, stateValueSchema)) List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, newStateSchema, session.sessionState, stateSchemaVersion, extraOptions = extraOptionOnStateStore)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/streamingLimits.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/streamingLimits.scala index 6816be103f6e2..da54c0ce0fe6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/streamingLimits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/streamingLimits.scala @@ -22,6 +22,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.WidenStatefulOpNullability import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow, SortOrder, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution, Partitioning} import org.apache.spark.sql.execution.{LimitExec, SparkPlan, UnaryExecNode} @@ -98,7 +99,8 @@ case class StreamingGlobalLimitExec( } } 
- override def output: Seq[Attribute] = child.output + override def output: Seq[Attribute] = + WidenStatefulOpNullability.widenOutputForStatefulOp(child.output) override def outputPartitioning: Partitioning = child.outputPartitioning diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala index b200bde96cbc9..d4d1ecb6d9b88 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala @@ -24,6 +24,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.WidenStatefulOpNullability import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical._ @@ -35,6 +36,7 @@ import org.apache.spark.sql.execution.streaming.operators.stateful.transformwith import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming._ +import org.apache.spark.sql.types.StructType import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, Utils} /** @@ -88,6 +90,12 @@ case class TransformWithStateExec( initialState) with ObjectProducerExec { + override def output: Seq[Attribute] = + WidenStatefulOpNullability.widenOutputForStatefulOp(super.output) + + private lazy val stateKeySchema: StructType = + WidenStatefulOpNullability.widenStateSchema(keyExpressions.toStructType) + 
override def shortName: String = StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME // We need to just initialize key and value deserializer once per partition. @@ -133,12 +141,11 @@ case class TransformWithStateExec( override def getColFamilySchemas( shouldBeNullable: Boolean ): Map[String, StateStoreColFamilySchema] = { - val keySchema = keyExpressions.toStructType // we have to add the default column family schema because the RocksDBStateEncoder // expects this entry to be present in the stateSchemaProvider. val defaultSchema = StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, - 0, keyExpressions.toStructType, 0, DUMMY_VALUE_ROW_SCHEMA, - Some(NoPrefixKeyStateEncoderSpec(keySchema))) + 0, stateKeySchema, 0, DUMMY_VALUE_ROW_SCHEMA, + Some(NoPrefixKeyStateEncoderSpec(stateKeySchema))) // For Scala, the user can't explicitly set nullability on schema, so there is // no reason to throw an error, and we can simply set the schema to nullable. @@ -147,7 +154,35 @@ case class TransformWithStateExec( shouldCheckNullable = false, shouldSetNullable = shouldBeNullable) ++ Map(StateStore.DEFAULT_COL_FAMILY_NAME -> defaultSchema) closeProcessorHandle() - columnFamilySchemas + widenColFamilyGroupingKeys(columnFamilySchemas) + } + + private def widenColFamilyGroupingKeys( + schemas: Map[String, StateStoreColFamilySchema]) + : Map[String, StateStoreColFamilySchema] = { + val original = keyEncoder.schema + val widened = stateKeySchema + if (original == widened) return schemas + def widenKey(ks: StructType): StructType = + WidenStatefulOpNullability.widenGroupingKeyInSchema(ks, original, widened) + schemas.map { case (name, cf) => + val widenedSpec = cf.keyStateEncoderSpec.map { + case NoPrefixKeyStateEncoderSpec(ks) => + NoPrefixKeyStateEncoderSpec(widenKey(ks)) + case PrefixKeyScanStateEncoderSpec(ks, n) => + PrefixKeyScanStateEncoderSpec(widenKey(ks), n) + case RangeKeyScanStateEncoderSpec(ks, o) => + RangeKeyScanStateEncoderSpec(widenKey(ks), o) + case 
TimestampAsPrefixKeyStateEncoderSpec(ks) => + TimestampAsPrefixKeyStateEncoderSpec(widenKey(ks)) + case TimestampAsPostfixKeyStateEncoderSpec(ks) => + TimestampAsPostfixKeyStateEncoderSpec(widenKey(ks)) + } + name -> cf.copy( + keySchema = widenKey(cf.keySchema), + valueSchema = WidenStatefulOpNullability.widenStateSchema(cf.valueSchema), + keyStateEncoderSpec = widenedSpec) + } } override def getStateVariableInfos(): Map[String, TransformWithStateVariableInfo] = { @@ -401,9 +436,9 @@ case class TransformWithStateExec( val storeProviderId = StateStoreProviderId(stateStoreId, stateInfo.get.queryRunId) val store = StateStore.get( storeProviderId = storeProviderId, - keyEncoder.schema, + stateKeySchema, DUMMY_VALUE_ROW_SCHEMA, - NoPrefixKeyStateEncoderSpec(keyEncoder.schema), + NoPrefixKeyStateEncoderSpec(stateKeySchema), version = stateInfo.get.storeVersion, stateStoreCkptId = stateInfo.get.getStateStoreCkptId(partitionId).map(_.head), stateSchemaBroadcast = stateInfo.get.stateSchemaMetadata, @@ -423,9 +458,9 @@ case class TransformWithStateExec( if (isStreaming) { child.execute().mapPartitionsWithStateStore[InternalRow]( getStateInfo, - keyEncoder.schema, + stateKeySchema, DUMMY_VALUE_ROW_SCHEMA, - NoPrefixKeyStateEncoderSpec(keyEncoder.schema), + NoPrefixKeyStateEncoderSpec(stateKeySchema), session.sessionState, Some(session.streams.stateStoreCoordinator), useColumnFamilies = true @@ -473,9 +508,9 @@ case class TransformWithStateExec( // Create StateStoreProvider for this partition val stateStoreProvider = StateStoreProvider.createAndInit( providerId, - keyEncoder.schema, + stateKeySchema, DUMMY_VALUE_ROW_SCHEMA, - NoPrefixKeyStateEncoderSpec(keyEncoder.schema), + NoPrefixKeyStateEncoderSpec(stateKeySchema), useColumnFamilies = true, storeConf = storeConf, hadoopConf = hadoopConfBroadcast.value.value, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala index 9fc72241e83b0..0d2e4a6941a00 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys.{BATCH_TIMESTAMP, ERROR} import org.apache.spark.sql.catalyst.QueryPlanningTracker +import org.apache.spark.sql.catalyst.analysis.WidenStatefulOperatorAttributeNullability import org.apache.spark.sql.catalyst.expressions.{CurrentBatchTimestamp, ExpressionWithRandomSeed} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule @@ -133,7 +134,7 @@ class IncrementalExecution( // of sink information. case w: WriteToMicroBatchDataSourceV1 => w.child } - sparkSession.sessionState.optimizer.executeAndTrack(preOptimized, + val optimized = sparkSession.sessionState.optimizer.executeAndTrack(preOptimized, tracker).transformAllExpressionsWithPruning( _.containsAnyPattern(CURRENT_LIKE, EXPRESSION_WITH_RANDOM_SEED)) { case ts @ CurrentBatchTimestamp(timestamp, _, _) => @@ -141,6 +142,7 @@ class IncrementalExecution( ts.toLiteral case e: ExpressionWithRandomSeed => e.withNewSeed(Utils.random.nextLong()) } + WidenStatefulOperatorAttributeNullability(optimized) } // Use `this` for explain so the already-open transaction and executedPlan are reused. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 1e1aa451a0aeb..c46f0076721b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -1181,16 +1181,16 @@ abstract class StreamingInnerJoinSuite extends StreamingInnerJoinBase { val hadoopConf = spark.sessionState.newHadoopConf() val fm = CheckpointFileManager.create(stateSchemaPath, hadoopConf) - val keySchemaForNums = new StructType().add("field0", IntegerType, nullable = false) + val keySchemaForNums = new StructType().add("field0", IntegerType) val keySchemaForIndex = keySchemaForNums.add("index", LongType) val numSchema: StructType = new StructType().add("value", LongType) val leftIndexSchema: StructType = new StructType() - .add("key", IntegerType, nullable = false) - .add("leftValue", IntegerType, nullable = false) + .add("key", IntegerType) + .add("leftValue", IntegerType) .add("matched", BooleanType) val rightIndexSchema: StructType = new StructType() - .add("key", IntegerType, nullable = false) - .add("rightValue", IntegerType, nullable = false) + .add("key", IntegerType) + .add("rightValue", IntegerType) .add("matched", BooleanType) val schemaLeftIndex = StateStoreColFamilySchema( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinV4Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinV4Suite.scala index e58af3b2bf651..6d4a97861efef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinV4Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinV4Suite.scala @@ -112,16 +112,16 @@ class StreamingInnerJoinV4Suite CheckpointFileManager.create(stateSchemaPath, hadoopConf) val keySchemaWithTimestamp = new StructType() - .add("field0", 
IntegerType, nullable = false) - .add("__event_time", LongType, nullable = false) + .add("field0", IntegerType) + .add("__event_time", LongType) val leftValueSchema: StructType = new StructType() - .add("key", IntegerType, nullable = false) - .add("leftValue", IntegerType, nullable = false) + .add("key", IntegerType) + .add("leftValue", IntegerType) .add("matched", BooleanType) val rightValueSchema: StructType = new StructType() - .add("key", IntegerType, nullable = false) - .add("rightValue", IntegerType, nullable = false) + .add("key", IntegerType) + .add("rightValue", IntegerType) .add("matched", BooleanType) val dummyValueSchema = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingStatefulOperatorNullabilityDriftSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingStatefulOperatorNullabilityDriftSuite.scala new file mode 100644 index 0000000000000..5278f68fbb0ad --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingStatefulOperatorNullabilityDriftSuite.scala @@ -0,0 +1,534 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkUnsupportedOperationException +import org.apache.spark.sql.{DataFrame, Encoders} +import org.apache.spark.sql.catalyst.analysis.WidenStatefulOpNullability +import org.apache.spark.sql.execution.streaming.checkpointing.CheckpointFileManager +import org.apache.spark.sql.execution.streaming.runtime.MemoryStream +import org.apache.spark.sql.execution.streaming.state.{RocksDBStateStoreProvider, StateSchemaCompatibilityChecker, StateStore} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} + +/** + * Regression suite for stateful-operator nullability drift. + * + * Driver: `PropagateEmptyRelation` drops empty `Union` branches without a streaming + * guard, so the surviving branch's per-column nullability becomes the Union's + * nullability and propagates into a stateful operator above -- across microbatches or + * restarts. + * + * Coverage: + * - New-query (default conf): originally-failing scenarios now complete cleanly. + * - Existing-query (conf forced false): pre-fix behavior preserved verbatim. + * - Helper invariant: `WidenStatefulOpNullability.deepWidenAttribute` recurses into + * nested types. 
+ */ +class StreamingStatefulOperatorNullabilityDriftSuite extends StreamTest { + + import testImplicits._ + + private def buildTwoSources(): (MemoryStream[Int], MemoryStream[Int], DataFrame, DataFrame) = { + val inputA = MemoryStream[Int] + val inputB = MemoryStream[Int] + + val dfA = inputA.toDF().select($"value".as("key")) + val dfB = inputB.toDF() + .select(when($"value" > Int.MinValue, $"value") + .otherwise(lit(null).cast("int")) + .as("key")) + + (inputA, inputB, dfA, dfB) + } + + private def buildTwoSourcesWithWatermark() + : (MemoryStream[Int], MemoryStream[Int], DataFrame, DataFrame) = { + val inputA = MemoryStream[Int] + val inputB = MemoryStream[Int] + + val dfA = inputA.toDF() + .select($"value".as("key"), + timestamp_seconds($"value").as("ts")) + .withWatermark("ts", "1 minute") + val dfB = inputB.toDF() + .select(when($"value" > Int.MinValue, $"value") + .otherwise(lit(null).cast("int")).as("key"), + timestamp_seconds($"value").as("ts")) + .withWatermark("ts", "1 minute") + + (inputA, inputB, dfA, dfB) + } + + private def runUnionBranchDropRestart( + buildSources: () => (MemoryStream[Int], MemoryStream[Int], DataFrame, DataFrame), + buildQuery: (DataFrame, DataFrame) => DataFrame, + outputMode: OutputMode, + nullableToNonNullable: Boolean): Unit = { + withTempDir { checkpointDir => + val checkpointPath = checkpointDir.getAbsolutePath + + val (inputA, inputB, dfA, dfB) = buildSources() + val q = buildQuery(dfA, dfB) + + if (nullableToNonNullable) { + testStream(q, outputMode)( + StartStream(checkpointLocation = checkpointPath), + MultiAddData(inputA, 1, 2, 3)(inputB, 4, 5), + ProcessAllAvailable(), + StopStream + ) + } else { + testStream(q, outputMode)( + StartStream(checkpointLocation = checkpointPath), + AddData(inputA, 1, 2, 3), + ProcessAllAvailable(), + StopStream + ) + } + + assertJournaledStateSchemaAllNullable(checkpointPath) + + if (nullableToNonNullable) { + testStream(q, outputMode)( + StartStream(checkpointLocation = checkpointPath), + 
AddData(inputA, 6), + ProcessAllAvailable() + ) + } else { + testStream(q, outputMode)( + StartStream(checkpointLocation = checkpointPath), + MultiAddData(inputA, 6)(inputB, 7), + ProcessAllAvailable() + ) + } + } + } + + private def assertJournaledStateSchemaAllNullable(checkpointPath: String): Unit = { + val partId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA + val operatorRoot = new Path(checkpointPath, "state/0") + val partitionRoot = new Path(operatorRoot, s"$partId") + val hadoopConf = spark.sessionState.newHadoopConf() + val fm = CheckpointFileManager.create(operatorRoot, hadoopConf) + val fs = operatorRoot.getFileSystem(hadoopConf) + + def collectSchemaFiles(dir: Path): Seq[Path] = { + if (!fm.exists(dir)) return Seq.empty + if (fs.getFileStatus(dir).isDirectory) { + fs.listStatus(dir).filter(_.isFile).map(_.getPath).toSeq + } else { + Seq(dir) + } + } + + val schemaFiles = scala.collection.mutable.ArrayBuffer.empty[Path] + + val storeDirs = scala.collection.mutable.ArrayBuffer(partitionRoot) + if (fs.exists(partitionRoot)) { + fs.listStatus(partitionRoot) + .filter(_.isDirectory) + .filterNot(_.getPath.getName.startsWith("_")) + .foreach(d => storeDirs += d.getPath) + } + storeDirs.foreach { storeDir => + schemaFiles ++= collectSchemaFiles( + new Path(storeDir, "_metadata/schema")) + } + + val stateSchemaRoot = new Path(operatorRoot, "_stateSchema") + if (fs.exists(stateSchemaRoot)) { + fs.listStatus(stateSchemaRoot) + .filter(_.isDirectory) + .foreach { storeDir => + schemaFiles ++= collectSchemaFiles(storeDir.getPath) + } + } + + assert(schemaFiles.nonEmpty, + s"expected at least one schema file under $operatorRoot") + schemaFiles.foreach { schemaFile => + val inStream = fm.open(schemaFile) + try { + val schemas = StateSchemaCompatibilityChecker.readSchemaFile(inStream) + schemas.foreach { s => + assertSchemaAllNullable(s.keySchema, + s"$schemaFile: key schema for col family ${s.colFamilyName}") + } + } finally inStream.close() + } + } + + private def 
assertSchemaAllNullable(schema: StructType, label: String): Unit = { + schema.fields.foreach { f => + assert(f.nullable, s"$label: field ${f.name} should be nullable") + assertDataTypeAllNullable(f.dataType, s"$label.${f.name}") + } + } + + private def assertDataTypeAllNullable(dataType: DataType, label: String): Unit = dataType match { + case s: StructType => assertSchemaAllNullable(s, label) + case ArrayType(elementType, containsNull) => + assert(containsNull, s"$label: array element should be nullable") + assertDataTypeAllNullable(elementType, s"$label[]") + case MapType(keyType, valueType, valueContainsNull) => + assert(valueContainsNull, s"$label: map value should be nullable") + assertDataTypeAllNullable(keyType, s"$label.key") + assertDataTypeAllNullable(valueType, s"$label.value") + case _ => + } + + test("streaming aggregate: non-nullable -> nullable widening remains restart-compatible") { + runUnionBranchDropRestart( + buildSources = () => buildTwoSources(), + buildQuery = (dfA, dfB) => dfA.union(dfB).groupBy($"key").count(), + outputMode = OutputMode.Update(), + nullableToNonNullable = false) + } + + test("streaming aggregate: nullable -> non-nullable narrowing remains restart-compatible") { + runUnionBranchDropRestart( + buildSources = () => buildTwoSources(), + buildQuery = (dfA, dfB) => dfA.union(dfB).groupBy($"key").count(), + outputMode = OutputMode.Update(), + nullableToNonNullable = true) + } + + test("streaming dropDuplicates: non-nullable -> nullable widening remains restart-compatible") { + runUnionBranchDropRestart( + buildSources = () => buildTwoSources(), + buildQuery = (dfA, dfB) => dfA.union(dfB).dropDuplicates(Seq("key")), + outputMode = OutputMode.Append(), + nullableToNonNullable = false) + } + + test("streaming dropDuplicatesWithinWatermark: " + + "non-nullable -> nullable widening remains restart-compatible") { + runUnionBranchDropRestart( + buildSources = () => buildTwoSourcesWithWatermark(), + buildQuery = (dfA, dfB) => 
dfA.union(dfB).dropDuplicatesWithinWatermark(Seq("key")), + outputMode = OutputMode.Append(), + nullableToNonNullable = false) + } + + test("streaming aggregate (Complete mode): no codegen NPE on state-restored null " + + "struct grouping key after fix") { + import org.apache.spark.sql.functions.struct + + def mkQuery(inNullableK: MemoryStream[Int], inNonNullK: MemoryStream[Int]): DataFrame = { + val dfNullable = inNullableK.toDF() + .select( + when($"value" > 0, struct($"value".as("v"))) + .otherwise(lit(null).cast("struct")) + .as("key"), + lit(1).as("metric")) + + val dfNonNull = inNonNullK.toDF() + .select( + struct($"value".as("v")).as("key"), + lit(1).as("metric")) + + dfNullable.union(dfNonNull) + .groupBy($"key") + .agg(sum($"metric").as("c")) + .select($"key.v".as("v"), $"c") + } + + withTempDir { checkpointDir => + withSQLConf( + SQLConf.STATE_SCHEMA_CHECK_ENABLED.key -> "false", + SQLConf.STATE_STORE_FORMAT_VALIDATION_ENABLED.key -> "false", + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val inNullable = MemoryStream[Int] + val inNonNull = MemoryStream[Int] + val q = mkQuery(inNullable, inNonNull) + testStream(q, OutputMode.Complete())( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inNullable, 0), + ProcessAllAvailable(), + StopStream + ) + + testStream(q, OutputMode.Complete())( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inNonNull, 1), + ProcessAllAvailable() + ) + } + } + } + + test("streaming aggregate: with widening forced off (existing-query path), " + + "STATE_STORE_KEY_SCHEMA_NOT_COMPATIBLE still triggers on restart") { + withTempDir { checkpointDir => + withSQLConf( + SQLConf.STATEFUL_OPERATOR_ALWAYS_NULLABLE_OUTPUT.key -> "false") { + val (inputA, inputB, dfA, dfB) = buildTwoSources() + val aggregated = dfA.union(dfB).groupBy($"key").count() + testStream(aggregated, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputA, 1, 2, 3), + 
ProcessAllAvailable(), + StopStream + ) + + inputA.addData(4) + inputB.addData(5) + + val ex = intercept[SparkUnsupportedOperationException] { + testStream(aggregated, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + ProcessAllAvailable() + ) + } + + checkError( + ex, + condition = "STATE_STORE_KEY_SCHEMA_NOT_COMPATIBLE", + parameters = Map( + "storedKeySchema" -> ".*", + "newKeySchema" -> ".*"), + matchPVals = true + ) + } + } + } + + test("stream-stream join: non-nullable -> nullable widening remains restart-compatible") { + withTempDir { checkpointDir => + val checkpointPath = checkpointDir.getAbsolutePath + + def buildJoinQuery(): (MemoryStream[Int], MemoryStream[Int], DataFrame) = { + val leftInput = MemoryStream[Int] + val rightInput = MemoryStream[Int] + + val left = leftInput.toDF() + .select($"value".as("key"), + timestamp_seconds($"value").as("leftTime")) + .withWatermark("leftTime", "10 seconds") + val right = rightInput.toDF() + .select( + when($"value" > Int.MinValue, $"value") + .otherwise(lit(null).cast("int")).as("key"), + timestamp_seconds($"value").as("rightTime")) + .withWatermark("rightTime", "10 seconds") + + val joined = left.join(right, + left("key") === right("key") && + left("leftTime") > right("rightTime") - expr("INTERVAL 10 SECONDS") && + left("leftTime") < right("rightTime") + expr("INTERVAL 10 SECONDS"), + "inner") + (leftInput, rightInput, joined) + } + + val (leftInput1, rightInput1, joined1) = buildJoinQuery() + testStream(joined1, OutputMode.Append())( + StartStream(checkpointLocation = checkpointPath), + MultiAddData(leftInput1, 1, 2, 3)(rightInput1, 1, 2), + ProcessAllAvailable(), + StopStream + ) + + assertJournaledStateSchemaAllNullable(checkpointPath) + + val (leftInput2, rightInput2, joined2) = buildJoinQuery() + testStream(joined2, OutputMode.Append())( + StartStream(checkpointLocation = checkpointPath), + MultiAddData(leftInput2, 4)(rightInput2, 5), + ProcessAllAvailable() + ) + } + 
} + + test("streaming flatMapGroupsWithState: " + + "non-nullable -> nullable widening remains restart-compatible") { + val stateFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { + val sum = values.sum + state.getOption.getOrElse(0) + state.update(sum) + Iterator((key, sum)) + } + + withTempDir { checkpointDir => + val checkpointPath = checkpointDir.getAbsolutePath + + def buildFmgwsQuery() + : (MemoryStream[Int], MemoryStream[Int], DataFrame) = { + val (inputA, inputB, dfA, dfB) = buildTwoSources() + val result = dfA.union(dfB) + .as[Int] + .groupByKey(identity) + .flatMapGroupsWithState( + OutputMode.Update(), GroupStateTimeout.NoTimeout())(stateFunc) + .toDF("key", "sum") + (inputA, inputB, result) + } + + val (inputA1, inputB1, q1) = buildFmgwsQuery() + testStream(q1, OutputMode.Update())( + StartStream(checkpointLocation = checkpointPath), + AddData(inputA1, 1, 2, 3), + ProcessAllAvailable(), + StopStream + ) + + assertJournaledStateSchemaAllNullable(checkpointPath) + + val (inputA2, inputB2, q2) = buildFmgwsQuery() + testStream(q2, OutputMode.Update())( + StartStream(checkpointLocation = checkpointPath), + MultiAddData(inputA2, 4)(inputB2, 5), + ProcessAllAvailable() + ) + } + } + + test("streaming transformWithState: " + + "non-nullable -> nullable widening remains restart-compatible") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName) { + withTempDir { checkpointDir => + val checkpointPath = checkpointDir.getAbsolutePath + + def buildTwsQuery() + : (MemoryStream[Int], MemoryStream[Int], DataFrame) = { + val (inputA, inputB, dfA, dfB) = buildTwoSources() + val result = dfA.union(dfB) + .as[Int] + .groupByKey(identity) + .transformWithState( + new NullabilityDriftCountProcessor(), + TimeMode.None(), + OutputMode.Update()) + (inputA, inputB, result.toDF()) + } + + val (inputA1, inputB1, q1) = buildTwsQuery() + testStream(q1, OutputMode.Update())( + StartStream(checkpointLocation = 
checkpointPath), + AddData(inputA1, 1, 2, 3), + ProcessAllAvailable(), + StopStream + ) + + assertJournaledStateSchemaAllNullable(checkpointPath) + + val (inputA2, inputB2, q2) = buildTwsQuery() + testStream(q2, OutputMode.Update())( + StartStream(checkpointLocation = checkpointPath), + MultiAddData(inputA2, 4)(inputB2, 5), + ProcessAllAvailable() + ) + } + } + } + + test("rule skips non-stateful nodes whose subtree has no stateful operator") { + import org.apache.spark.sql.catalyst.analysis.WidenStatefulOperatorAttributeNullability + import org.apache.spark.sql.catalyst.expressions.{AttributeReference, NamedExpression} + import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation, Project} + import org.apache.spark.sql.types.IntegerType + + val key = AttributeReference("key", IntegerType, nullable = false)() + val payload = AttributeReference("payload", IntegerType, nullable = false)() + val source = LocalRelation(Seq(key, payload), isStreaming = true) + val project = Project(Seq(key, payload), source) + val agg = Aggregate( + groupingExpressions = Seq(key), + aggregateExpressions = Seq(key.asInstanceOf[NamedExpression]), + child = project) + + val widened = WidenStatefulOperatorAttributeNullability(agg) + + val projectAfter = widened.collectFirst { case p: Project => p }.getOrElse( + fail(s"expected to find a Project node in the rewritten plan: $widened")) + assert(projectAfter.projectList.forall { + case ar: AttributeReference => !ar.nullable + case _ => true + }, s"Project.projectList below a stateful op should remain non-nullable: " + + s"${projectAfter.projectList}") + + val aggAfter = widened.asInstanceOf[Aggregate] + assert(aggAfter.aggregateExpressions.forall { + case ar: AttributeReference => ar.nullable + case _ => true + }, s"Aggregate.aggregateExpressions should be widened to nullable: " + + s"${aggAfter.aggregateExpressions}") + assert(aggAfter.groupingExpressions.forall { + case ar: AttributeReference => ar.nullable + case _ => true 
+ }, s"Aggregate.groupingExpressions should be widened to nullable: " + + s"${aggAfter.groupingExpressions}") + } + + test("deepWidenAttribute recurses into struct fields, array elements, map values") { + import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExprId} + import org.apache.spark.sql.types._ + + val nestedStruct = StructType(Seq( + StructField("inner_nn", IntegerType, nullable = false), + StructField("inner_nl", StringType, nullable = true))) + val arrayOfNonNull = ArrayType(IntegerType, containsNull = false) + val mapWithNonNullValue = MapType(StringType, IntegerType, valueContainsNull = false) + val combined = StructType(Seq( + StructField("s", nestedStruct, nullable = false), + StructField("a", arrayOfNonNull, nullable = false), + StructField("m", mapWithNonNullValue, nullable = false))) + + val attr = AttributeReference("complex", combined, nullable = false)(ExprId(42L)) + val widened = WidenStatefulOpNullability.deepWidenAttribute(attr) + + assert(widened.nullable, "outer attribute should be widened to nullable") + val widenedStruct = widened.dataType.asInstanceOf[StructType] + val widenedNested = widenedStruct("s").dataType.asInstanceOf[StructType] + assert( + widenedStruct("s").nullable && widenedStruct("a").nullable && widenedStruct("m").nullable, + "all top-level fields should be widened to nullable") + assert(widenedNested("inner_nn").nullable && widenedNested("inner_nl").nullable, + "nested struct fields should be widened to nullable") + val widenedArray = widenedStruct("a").dataType.asInstanceOf[ArrayType] + assert(widenedArray.containsNull, "array element nullability should be widened") + val widenedMap = widenedStruct("m").dataType.asInstanceOf[MapType] + assert(widenedMap.valueContainsNull, "map value nullability should be widened") + + assert(widened.exprId == attr.exprId) + assert(widened.name == attr.name) + assert(widened.qualifier == attr.qualifier) + } +} + +class NullabilityDriftCountProcessor + extends 
StatefulProcessor[Int, Int, (Int, Long)] { + @transient private var countState: ValueState[Long] = _ + + override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = { + countState = getHandle.getValueState[Long]( + "count", Encoders.scalaLong, TTLConfig.NONE) + } + + override def handleInputRows( + key: Int, + rows: Iterator[Int], + timerValues: TimerValues): Iterator[(Int, Long)] = { + val count = (if (countState.exists()) countState.get() else 0L) + rows.size + countState.update(count) + Iterator((key, count)) + } +}