From ae635f44aaee12b9209942b8d9238ca449852fce Mon Sep 17 00:00:00 2001 From: Anurag Mantripragada Date: Fri, 17 Apr 2026 09:26:10 -0700 Subject: [PATCH] [SPARK-56599][SQL] Add scan narrowing for column-level UPDATEs in DSv2 --- .../connector/write/RowLevelOperation.java | 41 ++ .../write/RowLevelOperationInfo.java | 17 + .../analysis/RewriteRowLevelCommand.scala | 42 +- .../analysis/RewriteUpdateTable.scala | 242 ++++++- .../catalyst/plans/logical/v2Commands.scala | 61 +- .../write/RowLevelOperationInfoImpl.scala | 8 +- .../InMemoryRowLevelOperationTable.scala | 354 ++++++++- ...wLevelOperationRuntimeGroupFiltering.scala | 13 +- .../DeltaBasedColumnUpdateTableSuite.scala | 677 ++++++++++++++++++ .../DeltaBasedUpdateTableSuiteBase.scala | 68 ++ .../RowLevelOperationSuiteBase.scala | 12 + 11 files changed, 1479 insertions(+), 56 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedColumnUpdateTableSuite.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java index 844734ff7ccb7..8c8affdd7c098 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java @@ -105,4 +105,45 @@ default String description() { default NamedReference[] requiredMetadataAttributes() { return new NamedReference[0]; } + + + /** + * Controls whether to send only the required data columns to the connector rather than the + * full row. + *
+   * <p>
+   * When true, Spark narrows the data column schema ({@link LogicalWriteInfo#schema()}) to the
+   * columns declared via {@link #requiredDataAttributes()}, or to the assigned columns if none
+   * are declared. Metadata columns (from {@link #requiredMetadataAttributes()}) and row ID
+   * columns (from {@link SupportsDelta#rowId()}) are unaffected and always projected separately.
+   *
+   * <p>
+ * If {@link #requiredDataAttributes()} returns a non-empty array, the write schema is exactly + * those columns in declared order. The connector must include all columns it wants to receive, + * including the columns being updated. If {@link #requiredDataAttributes()} returns an empty + * array, Spark sends only the non-identity assigned columns (heuristic path). + * + * @since 4.2.0 + */ + default boolean supportsColumnUpdates() { + return false; + } + + /** + * Returns data column references required to perform this row-level operation. + *
+   * <p>
+   * Spark consults this method only when {@link #supportsColumnUpdates()} returns {@code true};
+   * otherwise the returned array is ignored and the full table row is sent (the default
+   * behavior).
+   *
+   * <p>
+ * When non-empty, the returned columns become the write schema in declared order. + * The connector must declare all columns it wants to receive, including the columns being + * updated. Use {@link RowLevelOperationInfo#updatedColumns()} to learn which columns are being + * assigned, then add any extra columns needed for row lookup or routing (e.g., primary key). + *
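+   * <p>
+   * For example (not required by this API), a connector that always needs an assumed primary
+   * key column {@code pk} for row lookup could derive this array from
+   * {@link RowLevelOperationInfo#updatedColumns()} captured when the operation was built:
+   * <pre>{@code
+   * // updatedColumns is a connector field captured from RowLevelOperationInfo in the builder
+   * NamedReference pk = Expressions.column("pk");
+   * boolean pkUpdated = Arrays.stream(updatedColumns)
+   *     .anyMatch(ref -> ref.describe().equals("pk"));
+   * return pkUpdated
+   *     ? updatedColumns
+   *     : Stream.concat(Stream.of(pk), Arrays.stream(updatedColumns))
+   *         .toArray(NamedReference[]::new);
+   * }</pre>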
+   * <p>
+ * When empty (the default), Spark falls back to sending only the non-identity assigned columns. + * + * @since 4.2.0 + */ + default NamedReference[] requiredDataAttributes() { + return new NamedReference[0]; + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperationInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperationInfo.java index e3d7397aed91b..77bb5b31e28bc 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperationInfo.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperationInfo.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.connector.write; import org.apache.spark.annotation.Experimental; +import org.apache.spark.sql.connector.expressions.NamedReference; import org.apache.spark.sql.connector.write.RowLevelOperation.Command; import org.apache.spark.sql.util.CaseInsensitiveStringMap; @@ -37,4 +38,20 @@ public interface RowLevelOperationInfo { * Returns the row-level SQL command (e.g. DELETE, UPDATE, MERGE). */ Command command(); + + /** + * Returns the columns being updated in an UPDATE statement, as non-identity assignments. + * + *

+   * <p>For DELETE and MERGE, returns an empty array.
+   *

Connectors can use this to decide what {@link RowLevelOperation#requiredDataAttributes()} + * to declare. For instance, a connector that needs its primary key for row lookup can check + * whether pk is already in the updated columns list and, if not, add it to + * requiredDataAttributes(). + * + * @since 4.2.0 + */ + default NamedReference[] updatedColumns() { + return new NamedReference[0]; + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala index 48c48eb323bd7..98d73225515a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala @@ -22,12 +22,13 @@ import scala.collection.mutable import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ProjectingInternalRow import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, Expression, ExprId, If, Literal, MetadataAttribute, NamedExpression, V2ExpressionUtils} +import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral import org.apache.spark.sql.catalyst.plans.logical.{Assignment, Expand, LogicalPlan, MergeRows, Project, Union} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.{ReplaceDataProjections, WriteDeltaProjections} import org.apache.spark.sql.catalyst.util.RowDeltaUtils._ import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations -import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference} import org.apache.spark.sql.connector.write.{RowLevelOperation, RowLevelOperationInfoImpl, RowLevelOperationTable, SupportsDelta} import org.apache.spark.sql.connector.write.RowLevelOperation.Command import org.apache.spark.sql.errors.QueryCompilationErrors @@ -50,20 +51,35 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { protected def buildOperationTable( table: SupportsRowLevelOperations, command: Command, - options: CaseInsensitiveStringMap): RowLevelOperationTable = { - val info = RowLevelOperationInfoImpl(command, options) + options: CaseInsensitiveStringMap, + updatedColumns: Seq[NamedReference] = Nil): RowLevelOperationTable = { + val info = RowLevelOperationInfoImpl(command, options, updatedColumns) val operation = table.newRowLevelOperationBuilder(info).build() RowLevelOperationTable(table, operation) } + // Builds a DataSourceV2Relation for a row-level operation, optionally narrowing its output. + // + // When dataAttrs is non-empty, the relation output is narrowed to include only columns + // required for a column-update write. When dataAttrs is empty, the full relation.output is + // preserved. 
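+  // For example, with relation.output = [pk, id, dep], dataAttrs = [pk], rowIdAttrs = [pk] and
+  // a condition on dep, the narrowed data output is [pk, dep] before row ID and metadata
+  // columns are appended.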
protected def buildRelationWithAttrs( relation: DataSourceV2Relation, table: RowLevelOperationTable, metadataAttrs: Seq[AttributeReference], - rowIdAttrs: Seq[AttributeReference] = Nil): DataSourceV2Relation = { - - val attrs = dedupAttrs(relation.output ++ rowIdAttrs ++ metadataAttrs) - relation.copy(table = table, output = attrs) + rowIdAttrs: Seq[AttributeReference] = Nil, + dataAttrs: Seq[AttributeReference] = Nil, + cond: Expression = TrueLiteral): DataSourceV2Relation = { + + if (dataAttrs.nonEmpty) { + val required = + AttributeSet(dataAttrs) ++ AttributeSet(Seq(cond)) ++ AttributeSet(rowIdAttrs) + val narrowOutput = relation.output.filter(required.contains) + relation.copy(table = table, output = dedupAttrs(narrowOutput ++ rowIdAttrs ++ metadataAttrs)) + } else { + val attrs = dedupAttrs(relation.output ++ rowIdAttrs ++ metadataAttrs) + relation.copy(table = table, output = attrs) + } } protected def dedupAttrs(attrs: Seq[AttributeReference]): Seq[AttributeReference] = { @@ -87,6 +103,14 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { relation) } + protected def resolveRequiredDataAttrs( + relation: DataSourceV2Relation, + operation: RowLevelOperation): Seq[AttributeReference] = { + V2ExpressionUtils.resolveRefs[AttributeReference]( + operation.requiredDataAttributes.toImmutableArraySeq, + relation) + } + protected def resolveRowIdAttrs( relation: DataSourceV2Relation, operation: SupportsDelta): Seq[AttributeReference] = { @@ -211,11 +235,13 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { metadataAttrs: Seq[Attribute]): WriteDeltaProjections = { val outputs = extractOutputs(plan) + // Always produce Some(rowProjection) even for empty rowAttrs (identity-only column updates). + // Physical execution calls rowProjection.project(row) unconditionally; None causes NPE. 
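+    // ProjectingInternalRow(StructType(Nil), Nil) projects every input row to an empty row,
+    // which is the correct row payload when only row IDs and metadata columns are written.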
val rowProjection = if (rowAttrs.nonEmpty) { val outputsWithRow = filterOutputs(outputs, OPERATIONS_WITH_ROW) Some(newLazyProjection(plan, outputsWithRow, rowAttrs)) } else { - None + Some(ProjectingInternalRow(StructType(Nil), Nil)) } val outputsWithRowId = filterOutputs(outputs, OPERATIONS_WITH_ROW_ID) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala index 3c41b6bfa5683..da52b685ce1d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, EqualNullSafe, Expression, If, Literal, MetadataAttribute, Not, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference, AttributeSet, Cast, EqualNullSafe, Expression, If, Literal, MetadataAttribute, Not, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral import org.apache.spark.sql.catalyst.plans.logical.{Assignment, Expand, Filter, LogicalPlan, Project, ReplaceData, Union, UpdateTable, WriteDelta} import org.apache.spark.sql.catalyst.util.RowDeltaUtils._ import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations +import org.apache.spark.sql.connector.expressions.FieldReference import org.apache.spark.sql.connector.write.{RowLevelOperationTable, SupportsDelta} import org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPDATE import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, ExtractV2Table} @@ -41,7 +42,13 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { EliminateSubqueryAliases(aliasedTable) match { case r @ ExtractV2Table(tbl: SupportsRowLevelOperations) => - val table = buildOperationTable(tbl, UPDATE, CaseInsensitiveStringMap.empty()) + val updatedCols = assignments.collect { + case Assignment(key: AttributeReference, value) + if !isIdentityAssignment(key, value) => + FieldReference(key.name) + } + val table = buildOperationTable(tbl, UPDATE, CaseInsensitiveStringMap.empty(), + updatedCols) val updateCond = cond.getOrElse(TrueLiteral) table.operation match { case _: SupportsDelta => @@ -65,18 +72,15 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { assignments: Seq[Assignment], cond: Expression): ReplaceData = { - // resolve all required metadata attrs that may be used for grouping data on write - val metadataAttrs = resolveRequiredMetadataAttrs(relation, operationTable.operation) - - // construct a read relation and include all required metadata columns - val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs) + val (readRelation, rowAttrs) = buildCoWReadSetup(relation, operationTable, assignments, cond) - // build a plan with updated and copied over records - val query = buildReplaceDataUpdateProjection(readRelation, assignments, cond) + val updatedAndRemainingRowsPlan = buildReplaceDataUpdateProjection( + readRelation, assignments, cond) - // build a plan to replace read groups in the table val writeRelation = relation.copy(table = operationTable) - val projections = buildReplaceDataProjections(query, relation.output, metadataAttrs) + val query = updatedAndRemainingRowsPlan + val metadataAttrs = 
resolveRequiredMetadataAttrs(relation, operationTable.operation) + val projections = buildReplaceDataProjections(query, rowAttrs, metadataAttrs) val groupFilterCond = if (groupFilterEnabled) Some(cond) else None ReplaceData(writeRelation, cond, query, relation, projections, groupFilterCond) } @@ -89,13 +93,7 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { assignments: Seq[Assignment], cond: Expression): ReplaceData = { - // resolve all required metadata attrs that may be used for grouping data on write - val metadataAttrs = resolveRequiredMetadataAttrs(relation, operationTable.operation) - - // construct a read relation and include all required metadata columns - // the same read relation will be used to read records that must be updated and copied over - // the analyzer will take care of duplicated attr IDs - val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs) + val (readRelation, rowAttrs) = buildCoWReadSetup(relation, operationTable, assignments, cond) // build a plan for updated records that match the condition val matchedRowsPlan = Filter(cond, readRelation) @@ -106,38 +104,92 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { val remainingRowsPlan = addOperationColumn(COPY_OPERATION, Filter(remainingRowFilter, readRelation)) - // the new state is a union of updated and copied over records - val query = Union(updatedRowsPlan, remainingRowsPlan) + val updatedAndRemainingRowsPlan = Union(updatedRowsPlan, remainingRowsPlan) - // build a plan to replace read groups in the table val writeRelation = relation.copy(table = operationTable) - val projections = buildReplaceDataProjections(query, relation.output, metadataAttrs) + val query = updatedAndRemainingRowsPlan + val metadataAttrs = resolveRequiredMetadataAttrs(relation, operationTable.operation) + val projections = buildReplaceDataProjections(query, rowAttrs, metadataAttrs) val groupFilterCond = if (groupFilterEnabled) Some(cond) else None ReplaceData(writeRelation, cond, query, relation, projections, groupFilterCond) } + // Common read-relation setup shared by both CoW plan builders. + // + // When the connector supports column updates and declares required data attributes, + // the read relation is narrowed at analysis time so that + // GroupBasedRowLevelOperationScanPlanning uses only the needed columns for the scan. + // Otherwise the full relation output is used. + private def buildCoWReadSetup( + relation: DataSourceV2Relation, + operationTable: RowLevelOperationTable, + assignments: Seq[Assignment], + cond: Expression): (DataSourceV2Relation, Seq[Attribute]) = { + + val operation = operationTable.operation + val metadataAttrs = resolveRequiredMetadataAttrs(relation, operation) + val connectorDataAttrs = resolveRequiredDataAttrs(relation, operation) + val isNarrow = operation.supportsColumnUpdates() && connectorDataAttrs.nonEmpty + + // CoW scan narrowing must be done manually at analysis time. + // GroupBasedRowLevelOperationScanPlanning (an optimizer rule that fires after analysis) + // always reads relation.output directly when building the physical scan -- it does not + // observe Project nodes above the relation, so optimizer-driven column pruning has no + // effect on CoW scans. We narrow DataSourceV2Relation.output here so that rule picks + // up the narrow set. 
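+    // For example, the in-memory CoW test connector declares requiredDataAttributes = [pk, dep]
+    // plus the updated columns, so for UPDATE t SET salary = -1 WHERE pk = 1 the scan reads only
+    // [pk, salary, dep] (plus metadata columns) and the write schema is [pk, dep, salary].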
+ val readRelation = if (isNarrow) { + val allRequired = (connectorDataAttrs ++ computeAssignedAttrs(assignments)).distinct + buildRelationWithAttrs(relation, operationTable, metadataAttrs, dataAttrs = allRequired, + cond = cond) + } else { + buildRelationWithAttrs(relation, operationTable, metadataAttrs) + } + + // CoW write schema (two paths only, no heuristic for CoW): + // - Narrow path (connectorDataAttrs declared): exactly connector-declared cols in declared + // order. The connector must declare ALL columns it wants to receive. + // - Full path (connectorDataAttrs empty OR supportsColumnUpdates=false): full table output. + // Unlike MOR, CoW does not have a heuristic assigned-only path because + // GroupBasedRowLevelOperationScanPlanning needs explicit column declarations to narrow. + val rowAttrs: Seq[Attribute] = if (isNarrow) connectorDataAttrs else relation.output + + (readRelation, rowAttrs) + } + // this method assumes the assignments have been already aligned before + // + // Works for both the full-scan and narrow-scan CoW paths. In the narrow case, + // readRelation.output is already restricted by buildCoWReadSetup, so projecting + // all plan.output gives the correct narrow write schema. private def buildReplaceDataUpdateProjection( plan: LogicalPlan, assignments: Seq[Assignment], cond: Expression = TrueLiteral): LogicalPlan = { - // the plan output may include metadata columns at the end - // that's why the number of assignments may not match the number of plan output columns - val assignedValues = assignments.map(_.value) - val updatedValues = plan.output.zipWithIndex.map { case (attr, index) => - if (index < assignments.size) { - val assignedExpr = assignedValues(index) - val updatedValue = If(cond, assignedExpr, attr) - Alias(updatedValue, attr.name)() - } else { - assert(MetadataAttribute.isValid(attr.metadata)) + // Build a name-keyed map via AttributeMap (compares by exprId internally) so we can look + // up each plan column's assignment without relying on positional ordering. This is more + // robust than position-based indexing and works correctly for any plan output layout. + val assignmentMap = AttributeMap(assignments.collect { + case Assignment(key: Attribute, value) => key -> value + }) + + val updatedValues = plan.output.map { attr => + if (MetadataAttribute.isValid(attr.metadata)) { if (MetadataAttribute.isPreservedOnUpdate(attr)) { attr } else { val updatedValue = If(cond, Literal(null, attr.dataType), attr) Alias(updatedValue, attr.name)(explicitMetadata = Some(attr.metadata)) } + } else { + assignmentMap.get(attr) match { + case Some(assignedExpr) => + Alias(If(cond, assignedExpr, attr), attr.name)() + case None => + // Column is present in the scan but has no assignment -- pass through unchanged. + // In the narrow CoW path these are connector-declared columns not being updated. + attr + } } } @@ -154,29 +206,149 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { cond: Expression): WriteDelta = { val operation = operationTable.operation.asInstanceOf[SupportsDelta] + // Column-update support applies to the standard delta path and the delete+reinsert path. + // When representUpdateAsDeleteAndInsert is true, the REINSERT leg of the Expand already + // uses only assigned values, so the narrow effectiveRowAttrs applies correctly. + val supportsColumnUpdate = operation.supportsColumnUpdates() // resolve all needed attrs (e.g. 
row ID and any required metadata attrs) - val rowAttrs = relation.output val rowIdAttrs = resolveRowIdAttrs(relation, operation) val metadataAttrs = resolveRequiredMetadataAttrs(relation, operation) - // construct a read relation and include all required metadata columns + // Connector-declared data attrs used to determine pass-through columns in the write plan. + val connectorDataAttrs = if (supportsColumnUpdate) { + resolveRequiredDataAttrs(relation, operation) + } else Nil + + // MOR uses a full-schema scan; ColumnPruning narrows it via Project references. val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs, rowIdAttrs) + // Connector-required attrs that are NOT being assigned are added as pass-throughs in the + // plan so that ColumnPruning keeps them in the physical scan AND the connector receives + // their current values via DeltaWriter.update's row argument. + val assignedAttrs = if (supportsColumnUpdate) computeAssignedAttrs(assignments) + else relation.output + val connectorExtraAttrs: Seq[AttributeReference] = if (connectorDataAttrs.nonEmpty) { + val assignedAttrSet = AttributeSet(assignedAttrs) + connectorDataAttrs.filterNot(assignedAttrSet.contains) + } else Nil + // build a plan for updated records that match the condition val matchedRowsPlan = Filter(cond, readRelation) val rowDeltaPlan = if (operation.representUpdateAsDeleteAndInsert) { buildDeletesAndInserts(matchedRowsPlan, assignments, rowIdAttrs) + } else if (supportsColumnUpdate) { + buildColumnUpdateProjection( + matchedRowsPlan, assignments, rowIdAttrs, metadataAttrs, connectorExtraAttrs) } else { buildWriteDeltaUpdateProjection(matchedRowsPlan, assignments, rowIdAttrs) } + // Effective row write schema: + // - Narrow path (connectorDataAttrs declared): exactly connector-declared cols in declared + // order. The connector must declare ALL columns it wants to receive (including updated + // ones). This mirrors the metadata pattern and enables strict areCompatible validation. + // - Heuristic path (connectorDataAttrs empty): only the assigned (changed) columns. + // - Full path (no column-update support): full table output. + val effectiveRowAttrs = if (supportsColumnUpdate && connectorDataAttrs.nonEmpty) { + connectorDataAttrs + } else if (supportsColumnUpdate) { + assignedAttrs + } else { + relation.output + } + // build a plan to write the row delta to the table val writeRelation = relation.copy(table = operationTable) - val projections = buildWriteDeltaProjections(rowDeltaPlan, rowAttrs, rowIdAttrs, metadataAttrs) + val projections = buildWriteDeltaProjections( + rowDeltaPlan, effectiveRowAttrs, rowIdAttrs, metadataAttrs) WriteDelta(writeRelation, cond, rowDeltaPlan, relation, projections) } + // Builds the row delta projection for the column update path. + // + // The resulting Project references only: + // - assigned column values (new values being written) + // - connector pass-through values (connector declared but not assigned) + // - metadata columns (nulled or preserved) + // - row ID columns (for delta identification) + // - original row ID values (only when a row ID column is being reassigned) + // + // ColumnPruning observes exactly these references and narrows the physical scan accordingly. + // Connectors that need additional columns in the scan (e.g., partition columns for + // distribution) should declare them in requiredDataAttributes(). 
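+  // For example, for a delta table with row ID pk and no connector-declared data attributes,
+  //   UPDATE t SET salary = -1 WHERE pk = 1
+  // produces a Project of roughly [operation, salary = -1, <metadata columns>, pk], so
+  // ColumnPruning narrows the physical scan to pk plus the metadata columns.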
+ // + // Note: AlignUpdateAssignments guarantees all assignment keys are top-level + // AttributeReferences even for nested field updates (e.g., SET col1.field = 'x' becomes + // Assignment(col1: AttributeReference, CreateNamedStruct(...))), so isIdentityAssignment + // correctly identifies non-updating assignments. + private def buildColumnUpdateProjection( + plan: LogicalPlan, + assignments: Seq[Assignment], + rowIdAttrs: Seq[Attribute], + metadataAttrs: Seq[Attribute], + connectorExtraAttrs: Seq[AttributeReference] = Nil): LogicalPlan = { + + // only emit values for non-identity assignments (the narrow write schema) + val assignedValues = assignments.collect { + case Assignment(key: Attribute, value) if !isIdentityAssignment(key, value) => + Alias(value, key.name)() + } + + // Connector-required data attrs that are not being assigned are passed through as-is + // so that (a) ColumnPruning keeps them in the physical scan, and (b) the connector + // receives their current values via DeltaWriter.update's row argument. + val connectorExtraAttrSet = AttributeSet(connectorExtraAttrs) + val connectorPassThroughValues = plan.output.filter { a => + connectorExtraAttrSet.contains(a) && !MetadataAttribute.isValid(a.metadata) + } + + // pass through or null out metadata columns present in the scan + val metadataAttrSet = AttributeSet(metadataAttrs) + val metadataValues = plan.output.filter(metadataAttrSet.contains).map { attr => + if (MetadataAttribute.isPreservedOnUpdate(attr)) { + attr + } else { + Alias(Literal(null, attr.dataType), attr.name)(explicitMetadata = Some(attr.metadata)) + } + } + + // pass through row ID columns from the scan + val rowIdAttrSet = AttributeSet(rowIdAttrs) + val rowIdValues = plan.output.filter(rowIdAttrSet.contains) + + val originalRowIdValues = buildOriginalRowIdValues(rowIdAttrs, assignments) + val operationType = Alias(Literal(UPDATE_OPERATION), OPERATION_COLUMN)() + + Project( + Seq(operationType) ++ assignedValues ++ connectorPassThroughValues ++ + metadataValues ++ rowIdValues ++ originalRowIdValues, + plan) + } + + // Returns the table attributes that are genuinely updated (non-identity) in this UPDATE. + // Strips Alias/Cast wrappers introduced during assignment alignment before doing the + // AttributeSet membership check (which uses exprId equality internally). + private def computeAssignedAttrs(assignments: Seq[Assignment]): Seq[AttributeReference] = { + assignments.collect { + case Assignment(key: AttributeReference, value) if !isIdentityAssignment(key, value) => key + } + } + + private def isIdentityAssignment(key: Attribute, value: Expression): Boolean = { + stripAliasesAndCasts(value) match { + case attr: Attribute => AttributeSet(Seq(key)).contains(attr) + case _ => false + } + } + + // Recursively strips Alias and Cast wrappers introduced during assignment alignment. 
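+  // For example, stripAliasesAndCasts(Cast(Alias(id, "id")(), LongType)) returns the attribute
+  // id, so an assignment whose aligned value reduces to its own key still counts as identity.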
+ private def stripAliasesAndCasts(expr: Expression): Expression = expr match { + case Alias(child, _) => stripAliasesAndCasts(child) + case Cast(child, _, _, _) => stripAliasesAndCasts(child) + case other => other + } + // this method assumes the assignments have been already aligned before private def buildWriteDeltaUpdateProjection( plan: LogicalPlan, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index b857a360544e3..4d878442d3426 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -325,6 +325,14 @@ trait RowLevelWrite extends V2WriteCommand with SupportsSubquery { operation.requiredMetadataAttributes.toImmutableArraySeq, originalTable) } + + // Resolves the connector-declared data attributes against the original table. + // Symmetric with projectedMetadataAttrs; used in narrow-schema validation. + protected def projectedDataAttrs: Seq[Attribute] = { + V2ExpressionUtils.resolveRefs[AttributeReference]( + operation.requiredDataAttributes.toImmutableArraySeq, + originalTable) + } } /** @@ -378,7 +386,35 @@ case class ReplaceData( // validates row projection output is compatible with table attributes private def rowAttrsResolved: Boolean = { val inRowAttrs = DataTypeUtils.toAttributes(projections.rowProjection.schema) - table.skipSchemaResolution || areCompatible(inRowAttrs, table.output) + table.skipSchemaResolution || + areCompatible(inRowAttrs, table.output) || + dataAttrsResolved(inRowAttrs) + } + + // Validates the narrow-write-schema row projection output. + // + // When the connector declares specific data attributes via requiredDataAttributes(), the + // write schema must exactly match projectedDataAttrs (same columns, same order). This is + // symmetric with metadataAttrsResolved: the connector's declared attrs define the write schema. + // + // When requiredDataAttributes() is empty (heuristic path), the write schema contains only + // the assigned columns. We validate each one exists in the table with a compatible type. + private def dataAttrsResolved(inRowAttrs: Seq[Attribute]): Boolean = { + if (!operation.supportsColumnUpdates()) { return false } + val outDataAttrs = projectedDataAttrs + if (outDataAttrs.nonEmpty) { + areCompatible(inRowAttrs, outDataAttrs) + } else { + inRowAttrs.forall { inAttr => + table.output.exists { outAttr => + val inType = CharVarcharUtils.getRawType(inAttr.metadata).getOrElse(inAttr.dataType) + val outType = CharVarcharUtils.getRawType(outAttr.metadata).getOrElse(outAttr.dataType) + inAttr.name == outAttr.name && + DataType.equalsIgnoreCompatibleNullability(inType, outType) && + (outAttr.nullable || !inAttr.nullable) + } + } + } } // validates metadata projection output is compatible with metadata attributes @@ -468,7 +504,28 @@ case class WriteDelta( case Some(projection) => DataTypeUtils.toAttributes(projection.schema) case None => Nil } - table.skipSchemaResolution || areCompatible(inRowAttrs, outRowAttrs) + table.skipSchemaResolution || + areCompatible(inRowAttrs, outRowAttrs) || + dataAttrsResolved(inRowAttrs) + } + + // Validates the narrow-write-schema row projection. Symmetric with ReplaceData. 
+ private def dataAttrsResolved(inRowAttrs: Seq[Attribute]): Boolean = { + if (!operation.supportsColumnUpdates()) { return false } + val outDataAttrs = projectedDataAttrs + if (outDataAttrs.nonEmpty) { + areCompatible(inRowAttrs, outDataAttrs) + } else { + inRowAttrs.forall { inAttr => + table.output.exists { outAttr => + val inType = CharVarcharUtils.getRawType(inAttr.metadata).getOrElse(inAttr.dataType) + val outType = CharVarcharUtils.getRawType(outAttr.metadata).getOrElse(outAttr.dataType) + inAttr.name == outAttr.name && + DataType.equalsIgnoreCompatibleNullability(inType, outType) && + (outAttr.nullable || !inAttr.nullable) + } + } + } } // validates row ID projection output is compatible with row ID attributes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/RowLevelOperationInfoImpl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/RowLevelOperationInfoImpl.scala index 9d499cdef361b..a84e0230cd8d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/RowLevelOperationInfoImpl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/RowLevelOperationInfoImpl.scala @@ -17,9 +17,15 @@ package org.apache.spark.sql.connector.write +import org.apache.spark.sql.connector.expressions.NamedReference import org.apache.spark.sql.connector.write.RowLevelOperation.Command import org.apache.spark.sql.util.CaseInsensitiveStringMap private[sql] case class RowLevelOperationInfoImpl( command: Command, - options: CaseInsensitiveStringMap) extends RowLevelOperationInfo + options: CaseInsensitiveStringMap, + private val updatedCols: Seq[NamedReference] = Nil) + extends RowLevelOperationInfo { + + override def updatedColumns(): Array[NamedReference] = updatedCols.toArray +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala index 91e899bc1169e..2f540fa887e40 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala @@ -53,18 +53,49 @@ class InMemoryRowLevelOperationTable( private final val SPLIT_UPDATES = "split-updates" private final val NO_METADATA = "no-metadata" private final val noMetadata = properties.getOrDefault(NO_METADATA, "false") == "true" + private final val COLUMN_UPDATE = "column-update" + private final val COLUMN_UPDATE_REQ_ATTRS = "column-update-req-attrs" + // Selects PartitionBasedColumnUpdateOperation: CoW connector with supportsColumnUpdates=true + // and requiredDataAttributes=[pk,dep]. + private final val COLUMN_UPDATE_COW = "column-update-cow" + // Selects DeltaBasedColumnUpdateOperationFromInfo: connector that derives + // requiredDataAttributes() dynamically from RowLevelOperationInfo.updatedColumns(). + // Always adds "pk" for row lookup plus whatever Spark reports as updated. + private final val COLUMN_UPDATE_FROM_INFO = "column-update-from-info" + // Selects DeltaBasedColumnUpdateSplitOperation: delta connector with + // representUpdateAsDeleteAndInsert=true AND supportsColumnUpdates=true. + // Used to verify Point 7: the restriction on column updates for the delete+reinsert path + // has been lifted. 
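+  // A suite opts into this operation by setting the corresponding table property, e.g.
+  //   props.put("column-update-split", "true")
+  // in extraTableProps, mirroring how DeltaBasedColumnUpdateTableSuite enables "column-update".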
+ private final val COLUMN_UPDATE_SPLIT = "column-update-split" // used in row-level operation tests to verify replaced partitions var replacedPartitions: Seq[Seq[Any]] = Seq.empty // used in row-level operation tests to verify reported write schema var lastWriteInfo: LogicalWriteInfo = _ + // used in column-update tests to verify the scan projection was narrowed correctly + var lastScanSchema: StructType = _ + // used in column-update tests to verify that Spark passed the correct updated column list + // to the connector via RowLevelOperationInfo.updatedColumns() + var lastUpdatedColumns: Array[NamedReference] = Array.empty // used in row-level operation tests to verify passed records // (operation, id, metadata, row) var lastWriteLog: Seq[InternalRow] = Seq.empty override def newRowLevelOperationBuilder( info: RowLevelOperationInfo): RowLevelOperationBuilder = { - if (properties.getOrDefault(SUPPORTS_DELTAS, "false") == "true") { + lastUpdatedColumns = info.updatedColumns() + if (properties.getOrDefault(COLUMN_UPDATE, "false") == "true") { + () => new DeltaBasedColumnUpdateOperation(info.command) + } else if (properties.containsKey(COLUMN_UPDATE_REQ_ATTRS)) { + val reqCols = properties.get(COLUMN_UPDATE_REQ_ATTRS).split(",").map(_.trim) + () => new DeltaBasedColumnUpdateOperationWithReqAttrs(info.command, reqCols) + } else if (properties.getOrDefault(COLUMN_UPDATE_FROM_INFO, "false") == "true") { + () => new DeltaBasedColumnUpdateOperationFromInfo(info.command, info.updatedColumns().toSeq) + } else if (properties.getOrDefault(COLUMN_UPDATE_COW, "false") == "true") { + () => new PartitionBasedColumnUpdateOperation(info.command, info.updatedColumns().toSeq) + } else if (properties.getOrDefault(COLUMN_UPDATE_SPLIT, "false") == "true") { + () => new DeltaBasedColumnUpdateSplitOperation(info.command, info.updatedColumns().toSeq) + } else if (properties.getOrDefault(SUPPORTS_DELTAS, "false") == "true") { () => DeltaBasedOperation(info.command) } else { () => PartitionBasedOperation(info.command) @@ -198,6 +229,327 @@ class InMemoryRowLevelOperationTable( } } + // A delta-based operation that supports column-level updates: Spark sends only the + // assigned/changed columns in the row projection instead of the full row schema. + class DeltaBasedColumnUpdateOperation(command: Command) + extends DeltaBasedOperation(command) { + override def representUpdateAsDeleteAndInsert(): Boolean = false + override def supportsColumnUpdates(): Boolean = true + + // Override newScanBuilder to record the schema that Spark actually requests from the + // connector after column pruning, so tests can assert on scan narrowing. 
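+    // Tests can then assert on the narrowed scan, e.g.
+    //   assert(!table.lastScanSchema.fieldNames.contains("id"))
+    // as DeltaBasedColumnUpdateTableSuite does for literal assignments.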
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + new InMemoryScanBuilder(schema, options) { + override def build(): Scan = { + val scan = super.build() + lastScanSchema = scan.readSchema() + scan + } + } + } + + override def newWriteBuilder(info: LogicalWriteInfo): DeltaWriteBuilder = { + lastWriteInfo = info + new DeltaWriteBuilder { + override def build(): DeltaWrite = + new DeltaWrite with RequiresDistributionAndOrdering { + + override def requiredDistribution(): Distribution = { + Distributions.clustered(Array(PARTITION_COLUMN_REF)) + } + + override def requiredOrdering(): Array[SortOrder] = { + Array[SortOrder]( + LogicalExpressions.sort( + PARTITION_COLUMN_REF, + SortDirection.ASCENDING, + SortDirection.ASCENDING.defaultNullOrdering()) + ) + } + + override def toBatch: DeltaBatchWrite = + new RowLevelOperationBatchWrite with DeltaBatchWrite { + override def createBatchWriterFactory( + info: PhysicalWriteInfo): DeltaWriterFactory = { + new DeltaBufferedRowsWriterFactory(lastWriteInfo.schema()) + } + + // For column-update writes, rows contain only the assigned columns + // (narrow schema from LogicalWriteInfo). We expand each row to the full table + // schema by overlaying write-schema columns on the base row found by pk. + override def commit(messages: Array[WriterCommitMessage]): Unit = + dataMap.synchronized { + val newData = messages.map(_.asInstanceOf[BufferedRows]) + val writeSchema = lastWriteInfo.schema() + val writeFieldIdx = writeSchema.fieldNames.zipWithIndex.toMap + + val mergedData = newData.map { buf => + val merged = new BufferedRows(buf.key, schema) + val updateOpName = UTF8String.fromString(Update.toString) + buf.log.foreach { logRow => + val opName = logRow.getUTF8String(0) + if (opName == updateOpName) { + val pk = logRow.getInt(1) + val narrowRow = logRow.get(3, writeSchema).asInstanceOf[InternalRow] + val baseRow = dataMap.values.iterator.flatten + .flatMap(_.rows) + .find(r => r.getInt(schema.fieldIndex("pk")) == pk) + val fullRow = new GenericInternalRow(schema.length) + baseRow.foreach { base => + for (i <- schema.fields.indices) { + fullRow.update(i, base.get(i, schema(i).dataType)) + } + } + schema.fields.zipWithIndex.foreach { case (field, i) => + writeFieldIdx.get(field.name).foreach { j => + fullRow.update(i, narrowRow.get(j, field.dataType)) + } + } + merged.rows.append(fullRow) + } + } + merged + } + + withDeletes(newData) + withData(mergedData) + lastWriteLog = newData.flatMap(buffer => buffer.log).toIndexedSeq + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = {} + } + } + } + } + } + + // A variant of DeltaBasedColumnUpdateOperation that overrides requiredDataAttributes() + // to declare a fixed set of data columns the connector needs in the scan. This exercises + // the connector-driven scan-narrowing path (as opposed to the heuristic path). + class DeltaBasedColumnUpdateOperationWithReqAttrs(command: Command, reqCols: Array[String]) + extends DeltaBasedColumnUpdateOperation(command) { + override def requiredDataAttributes(): Array[NamedReference] = reqCols.map(FieldReference(_)) + } + + // A delta-based column-update connector that derives requiredDataAttributes() dynamically + // from RowLevelOperationInfo.updatedColumns(). + // + // This models the common connector pattern: + // 1. Spark tells the connector which columns are being updated via updatedColumns(). + // 2. The connector adds any extra columns it always needs (here: "pk" for row lookup). + // 3. 
The combined set is returned from requiredDataAttributes() so Spark narrows the scan. + // + // If "pk" is already in updatedColumns (the user is updating pk itself), it is not duplicated. + class DeltaBasedColumnUpdateOperationFromInfo( + command: Command, + updatedCols: Seq[NamedReference]) + extends DeltaBasedColumnUpdateOperation(command) { + + private val PK_REF: NamedReference = FieldReference("pk") + + override def requiredDataAttributes(): Array[NamedReference] = { + val updatedNames = updatedCols.map(_.describe()).toSet + if (updatedNames.contains("pk")) { + updatedCols.toArray + } else { + (Array(PK_REF) ++ updatedCols).toArray + } + } + } + + // A delta-based operation that combines representUpdateAsDeleteAndInsert=true with + // supportsColumnUpdates()=true. This verifies that the restriction which previously + // blocked column-level updates on the delete+reinsert path has been lifted. + // + // The connector declares "pk" plus any columns being updated (via updatedCols). + // The write schema = requiredDataAttributes() in declared order. + // The REINSERT leg receives the narrow write row; the DELETE leg uses row ID only. + class DeltaBasedColumnUpdateSplitOperation( + command: Command, + updatedCols: Seq[NamedReference] = Nil) + extends DeltaBasedColumnUpdateOperation(command) { + override def representUpdateAsDeleteAndInsert(): Boolean = true + + private val PK_REF: NamedReference = FieldReference("pk") + override def requiredDataAttributes(): Array[NamedReference] = { + val updatedNames = updatedCols.map(_.describe()).toSet + if (updatedNames.contains("pk")) updatedCols.toArray + else (Array(PK_REF) ++ updatedCols).toArray + } + + override def newWriteBuilder(info: LogicalWriteInfo): DeltaWriteBuilder = { + lastWriteInfo = info + new DeltaWriteBuilder { + override def build(): DeltaWrite = + new DeltaWrite with RequiresDistributionAndOrdering { + override def requiredDistribution(): Distribution = + Distributions.clustered(Array(PARTITION_COLUMN_REF)) + override def requiredOrdering(): Array[SortOrder] = Array[SortOrder]( + LogicalExpressions.sort( + PARTITION_COLUMN_REF, + SortDirection.ASCENDING, + SortDirection.ASCENDING.defaultNullOrdering())) + override def toBatch: DeltaBatchWrite = + new RowLevelOperationBatchWrite with DeltaBatchWrite { + override def createBatchWriterFactory( + info: PhysicalWriteInfo): DeltaWriterFactory = + new DeltaBufferedRowsWriterFactory(lastWriteInfo.schema()) + + // For delete+reinsert with narrow writes, the REINSERT row has only the + // connector-declared columns (requiredDataAttributes order). + // pk is the first field in the write schema (declared before updatedCols). + // Reconstruct the full row by overlaying the narrow row onto the original. 
+ override def commit(messages: Array[WriterCommitMessage]): Unit = + dataMap.synchronized { + val newData = messages.map(_.asInstanceOf[BufferedRows]) + val writeSchema = lastWriteInfo.schema() + val writeFieldIdx = writeSchema.fieldNames.zipWithIndex.toMap + val reinsertOpName = UTF8String.fromString(Reinsert.toString) + val pkIdx = writeFieldIdx("pk") + + val expandedData = newData.map { buf => + val expanded = new BufferedRows(buf.key, schema) + buf.log.foreach { logRow => + val opName = logRow.getUTF8String(0) + if (opName == reinsertOpName) { + val narrowRow = logRow.get(3, writeSchema).asInstanceOf[InternalRow] + val pk = narrowRow.getInt(pkIdx) + val baseRow = dataMap.values.iterator.flatten + .flatMap(_.rows) + .find(r => r.getInt(schema.fieldIndex("pk")) == pk) + val fullRow = new GenericInternalRow(schema.length) + baseRow.foreach { base => + for (i <- schema.fields.indices) { + fullRow.update(i, base.get(i, schema(i).dataType)) + } + } + schema.fields.zipWithIndex.foreach { case (field, i) => + writeFieldIdx.get(field.name).foreach { j => + fullRow.update(i, narrowRow.get(j, field.dataType)) + } + } + expanded.rows.append(fullRow) + } + } + expanded + } + + withDeletes(newData) + withData(expandedData) + lastWriteLog = newData.flatMap(buffer => buffer.log).toIndexedSeq + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = {} + } + } + } + } + } + + // A CoW operation that supports column-level updates. The connector declares it needs + // "pk" and "dep" for partition routing, plus any columns the user is updating (via + // updatedCols from RowLevelOperationInfo). supportsColumnUpdates()=true so Spark narrows + // the scan and write schema to exactly requiredDataAttributes(). + // The commit logic reconstructs full rows from the original scan data using pk as a key. + class PartitionBasedColumnUpdateOperation( + command: Command, + updatedCols: Seq[NamedReference] = Nil) extends RowLevelOperation { + var configuredScan: InMemoryBatchScan = _ + + override def command(): Command = command + + override def supportsColumnUpdates(): Boolean = true + + override def requiredDataAttributes(): Array[NamedReference] = { + // Always need pk (for row lookup) and dep (partition key). + // Also include any columns being updated so Spark sends their new values. 
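+      // e.g. for UPDATE ... SET salary = -1 WHERE pk = 1, this returns [pk, dep, salary].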
+ val base = Seq(FieldReference("pk"), FieldReference("dep")) + val baseNames = base.map(_.describe()).toSet + (base ++ updatedCols.filterNot(r => baseNames.contains(r.describe()))).toArray + } + + override def requiredMetadataAttributes(): Array[NamedReference] = + Array(PARTITION_COLUMN_REF, INDEX_COLUMN_REF) + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + new InMemoryScanBuilder(schema, options) { + override def build(): Scan = { + val scan = super.build() + configuredScan = scan.asInstanceOf[InMemoryBatchScan] + lastScanSchema = scan.readSchema() + scan + } + } + } + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { + lastWriteInfo = info + new WriteBuilder { + override def build(): Write = new Write with RequiresDistributionAndOrdering { + override def requiredDistribution: Distribution = + Distributions.clustered(Array(PARTITION_COLUMN_REF)) + + override def requiredOrdering: Array[SortOrder] = Array[SortOrder]( + LogicalExpressions.sort( + PARTITION_COLUMN_REF, + SortDirection.ASCENDING, + SortDirection.ASCENDING.defaultNullOrdering())) + + override def toBatch: BatchWrite = + PartitionBasedNarrowReplaceData(configuredScan, info.schema()) + + override def description: String = "InMemoryNarrowCoWWrite" + } + } + } + + override def description(): String = "InMemoryPartitionColumnUpdateOperation" + } + + // CoW write handler for narrow column-update writes. + // Receives rows with only the connector-declared + assigned columns. + // Reconstructs full rows by looking up the original row by pk and overlaying received columns. + private case class PartitionBasedNarrowReplaceData( + scan: InMemoryBatchScan, + writeSchema: StructType) extends RowLevelOperationBatchWrite { + + override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { + val newData = messages.map(_.asInstanceOf[BufferedRows]) + val readRows = scan.data.flatMap(_.asInstanceOf[BufferedRows].rows) + val readPartitions = readRows.map(r => getKey(r, schema)).distinct + dataMap --= readPartitions + replacedPartitions = readPartitions + + val writeFieldIdx = writeSchema.fieldNames.zipWithIndex.toMap + val pkIdxInWrite = writeFieldIdx("pk") + val pkIdxInFull = schema.fieldIndex("pk") + + val expandedData = newData.map { buf => + val expanded = new BufferedRows(buf.key, schema) + buf.rows.foreach { narrowRow => + val pk = narrowRow.getInt(pkIdxInWrite) + val origRow = readRows.find(r => r.getInt(pkIdxInFull) == pk) + val fullRow = new GenericInternalRow(schema.length) + origRow.foreach { base => + for (i <- schema.fields.indices) { + fullRow.update(i, base.get(i, schema(i).dataType)) + } + } + schema.fields.zipWithIndex.foreach { case (field, i) => + writeFieldIdx.get(field.name).foreach { j => + fullRow.update(i, narrowRow.get(j, field.dataType)) + } + } + expanded.rows.append(fullRow) + } + expanded + } + + withData(expandedData, schema) + lastWriteLog = newData.flatMap(buffer => buffer.log).toImmutableArraySeq + } + } + private object TestDeltaBatchWrite extends RowLevelOperationBatchWrite with DeltaBatchWrite{ override def createBatchWriterFactory(info: PhysicalWriteInfo): DeltaWriterFactory = { new DeltaBufferedRowsWriterFactory(CatalogV2Util.v2ColumnsToStructType(columns())) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala index 
41971e60f5737..9acfd55e2bfea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution.dynamicpruning -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, DynamicPruningExpression, Expression, InSubquery, ListQuery, PredicateHelper, V2ExpressionUtils} import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral import org.apache.spark.sql.catalyst.optimizer.RewritePredicateSubquery @@ -128,17 +127,13 @@ class RowLevelOperationRuntimeGroupFiltering(optimizeSubqueries: Rule[LogicalPla tableAttrs: Seq[Attribute], scanAttrs: Seq[Attribute]): AttributeMap[Attribute] = { - val attrMapping = tableAttrs.map { tableAttr => + // For column-level updates, the scan may be narrowed to exclude columns that the + // connector does not need. Skip table attributes that are absent from the scan + // instead of throwing -- they cannot appear in the condition if they were pruned. + val attrMapping = tableAttrs.flatMap { tableAttr => scanAttrs .find(scanAttr => conf.resolver(scanAttr.name, tableAttr.name)) .map(scanAttr => tableAttr -> scanAttr) - .getOrElse { - throw new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_3075", - messageParameters = Map( - "tableAttr" -> tableAttr.toString, - "scanAttrs" -> scanAttrs.mkString(","))) - } } AttributeMap(attrMapping) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedColumnUpdateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedColumnUpdateTableSuite.scala new file mode 100644 index 0000000000000..b22f2743a87ee --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedColumnUpdateTableSuite.scala @@ -0,0 +1,677 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +import org.apache.spark.sql.Row +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, TableInfo} +import org.apache.spark.sql.connector.expressions.LogicalExpressions.{identity, reference} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} + +/** + * Tests for UPDATE statements targeting connectors that return true from + * [[org.apache.spark.sql.connector.write.RowLevelOperation#supportsColumnUpdates]]. 
+ * + * When a connector supports column updates, Spark narrows the row projection + * (LogicalWriteInfo.schema()) to contain only the assigned/changed columns rather than + * the full table row. + */ +class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { + + override protected lazy val extraTableProps: java.util.Map[String, String] = { + val props = new java.util.HashMap[String, String]() + props.put("column-update", "true") + props + } + + // --- Schema narrowing: verify LogicalWriteInfo.schema() is narrow --- + + test("column-update: rowSchema contains only the single assigned column") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |{ "pk": 3, "id": 3, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET id = -1 WHERE pk = 1") + + // Only the assigned column (id) should appear in the row schema -- not pk or dep + checkLastWriteInfo( + expectedRowSchema = StructType(Seq( + StructField("id", IntegerType, nullable = false) + )), + expectedRowIdSchema = Some(StructType(Array(PK_FIELD))), + expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE)))) + } + + test("column-update: rowSchema contains multiple assigned columns") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET id = -1, dep = 'engineering' WHERE pk = 1") + + // Both assigned columns (id, dep) should appear -- but NOT pk (unassigned) + checkLastWriteInfo( + expectedRowSchema = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("dep", StringType, nullable = false) + )), + expectedRowIdSchema = Some(StructType(Array(PK_FIELD))), + expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE)))) + } + + test("column-update: rowSchema is empty for a full identity update") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET id = id, dep = dep WHERE pk = 1") + + checkLastWriteInfo( + expectedRowSchema = new StructType(), + expectedRowIdSchema = Some(StructType(Array(PK_FIELD))), + expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE)))) + } + + test("column-update: row filter condition is orthogonal to column narrowing") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |{ "pk": 3, "id": 3, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET dep = 'engineering' WHERE pk IN (1, 3)") + + checkLastWriteInfo( + expectedRowSchema = StructType(Seq( + StructField("dep", StringType, nullable = false) + )), + expectedRowIdSchema = Some(StructType(Array(PK_FIELD))), + expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE)))) + } + + test("column-update: update all rows (no WHERE clause)") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET salary = salary * 2") + + checkLastWriteInfo( + expectedRowSchema = StructType(Seq( + StructField("salary", IntegerType, nullable = true) + )), + 
expectedRowIdSchema = Some(StructType(Array(PK_FIELD))), + expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE)))) + } + + // --- Identity assignment filtering --- + + test("column-update: rowSchema excludes identity assignments in a mixed UPDATE") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |""".stripMargin) + + // id = id is identity -- should be excluded from rowSchema + // dep = 'engineering' is a real assignment -- should be included + sql(s"UPDATE $tableNameAsString SET id = id, dep = 'engineering' WHERE pk = 1") + + checkLastWriteInfo( + expectedRowSchema = StructType(Seq( + StructField("dep", StringType, nullable = false) + )), + expectedRowIdSchema = Some(StructType(Array(PK_FIELD))), + expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE)))) + } + + test("column-update: cross-column assignment is not treated as identity") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |""".stripMargin) + + // dep = dep is identity; id = -1 is a real assignment -- only id should appear + sql(s"UPDATE $tableNameAsString SET dep = dep, id = -1 WHERE pk = 1") + + checkLastWriteInfo( + expectedRowSchema = StructType(Seq( + StructField("id", IntegerType, nullable = false) + )), + expectedRowIdSchema = Some(StructType(Array(PK_FIELD))), + expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE)))) + } + + // --- updatedColumns in RowLevelOperationInfo --- + + test("column-update: updatedColumns contains non-identity assigned columns") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET id = -1, dep = 'eng' WHERE pk = 1") + + val updatedNames = table.lastUpdatedColumns.map(_.describe()).toSet + assert(updatedNames == Set("id", "dep"), + s"expected [id, dep] in updatedColumns but got: $updatedNames") + } + + test("column-update: updatedColumns excludes identity assignments") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |""".stripMargin) + + // dep = dep is identity; only id should appear in updatedColumns + sql(s"UPDATE $tableNameAsString SET id = -1, dep = dep WHERE pk = 1") + + val updatedNames = table.lastUpdatedColumns.map(_.describe()).toSet + assert(updatedNames == Set("id"), + s"expected only [id] in updatedColumns but got: $updatedNames") + } + + test("column-update: updatedColumns is empty for a full identity update") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET id = id, dep = dep WHERE pk = 1") + + assert(table.lastUpdatedColumns.isEmpty, + s"expected empty updatedColumns but got: ${table.lastUpdatedColumns.mkString(", ")}") + } + + test("column-update: updatedColumns is empty for DELETE (Javadoc contract)") { + // DELETE never has updated columns -- verify that the default empty array is passed + // through RowLevelOperationInfo even when a column-update connector handles the DELETE. + // Use a partition-column condition so the InMemory table can process the filter. 
+ createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |""".stripMargin) + + sql(s"DELETE FROM $tableNameAsString WHERE dep = 'hr'") + + assert(table.lastUpdatedColumns.isEmpty, + s"DELETE must pass empty updatedColumns but got: ${table.lastUpdatedColumns.mkString(", ")}") + } + + // --- Data correctness --- + + test("column-update: data correctness -- single column update") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |{ "pk": 3, "id": 3, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET id = -1 WHERE pk = 1") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString ORDER BY pk"), + Row(1, -1, "hr") :: Row(2, 2, "software") :: Row(3, 3, "hr") :: Nil) + } + + test("column-update: data correctness -- update all rows") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET salary = salary * 2") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString ORDER BY pk"), + Row(1, 200, "hr") :: Row(2, 400, "software") :: Row(3, 600, "hr") :: Nil) + } + + test("column-update: data correctness -- mixed identity and real assignments") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |{ "pk": 3, "id": 3, "dep": "hr" } + |""".stripMargin) + + // Only dep changes; id stays as-is even though id = id is in the SET list. + sql(s"UPDATE $tableNameAsString SET id = id, dep = 'engineering' WHERE pk = 1") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString ORDER BY pk"), + Row(1, 1, "engineering") :: Row(2, 2, "software") :: Row(3, 3, "hr") :: Nil) + } + + // --- Scan narrowing: verify the connector only receives the columns it needs --- + + test("column-update: scan excludes the assigned column when SET to a literal") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |""".stripMargin) + + // id is the target of a literal assignment -- its current value is not needed. + // pk is needed for the WHERE condition and as rowId. + sql(s"UPDATE $tableNameAsString SET id = -1 WHERE pk = 1") + + val scanSchema = table.lastScanSchema + assert(!scanSchema.fieldNames.contains("id"), s"id should be excluded from scan: $scanSchema") + assert(scanSchema.fieldNames.contains("pk"), s"pk must be in scan: $scanSchema") + } + + test("column-update: scan includes the assigned column when its current value is the RHS") { + createAndInitTable("pk INT NOT NULL, salary INT, bonus INT, dep STRING", + """{ "pk": 1, "salary": 100, "bonus": 10, "dep": "hr" } + |{ "pk": 2, "salary": 200, "bonus": 20, "dep": "software" } + |""".stripMargin) + + // salary appears on the RHS (salary * 2) so it must be scanned. + // bonus is not referenced anywhere -- excluded. 
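+    // (Illustrative note: salary survives pruning because the RHS expression still references
+    // the scan's salary attribute, which the scan-narrowing pass -- V2ScanRelationPushDown on
+    // the heuristic path described below -- observes when deciding which columns to keep.)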
+ sql(s"UPDATE $tableNameAsString SET salary = salary * 2") + + val scanSchema = table.lastScanSchema + assert(scanSchema.fieldNames.contains("salary"), s"salary must be in scan: $scanSchema") + assert(!scanSchema.fieldNames.contains("bonus"), s"bonus should be excluded: $scanSchema") + } + + test("column-update: scan excludes non-referenced columns for literal assignment") { + createAndInitTable("pk INT NOT NULL, id INT, salary INT, dep STRING", + """{ "pk": 1, "id": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "id": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + // dep is a literal assignment; id and salary are not referenced -- only pk needed. + sql(s"UPDATE $tableNameAsString SET dep = 'engineering' WHERE pk = 1") + + val scanSchema = table.lastScanSchema + assert(scanSchema.fieldNames.contains("pk"), s"pk must be in scan: $scanSchema") + assert(!scanSchema.fieldNames.contains("id"), s"id should be excluded: $scanSchema") + assert(!scanSchema.fieldNames.contains("salary"), s"salary should be excluded: $scanSchema") + } + + test("column-update: scan includes condition columns even when not assigned") { + createAndInitTable("pk INT NOT NULL, salary INT, bonus INT, dep STRING", + """{ "pk": 1, "salary": 100, "bonus": 10, "dep": "hr" } + |{ "pk": 2, "salary": 200, "bonus": 20, "dep": "software" } + |""".stripMargin) + + // dep appears in the WHERE clause -- must be scanned even though it is not assigned. + // bonus is neither assigned nor in the condition -- excluded. + // salary is set to a literal -- current value not needed. + sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'") + + val scanSchema = table.lastScanSchema + assert(scanSchema.fieldNames.contains("dep"), + s"dep must be in scan (WHERE clause): $scanSchema") + assert(!scanSchema.fieldNames.contains("bonus"), s"bonus should be excluded: $scanSchema") + assert(!scanSchema.fieldNames.contains("salary"), + s"salary should be excluded (literal assignment): $scanSchema") + } + + // --------------------------------------------------------------------------- + // Connector-driven scan narrowing via requiredDataAttributes() + // --------------------------------------------------------------------------- + + // Creates a table backed by DeltaBasedColumnUpdateOperationWithReqAttrs, which overrides + // requiredDataAttributes() to return the given comma-separated column names. + private def createAndInitTableWithReqAttrs( + reqAttrs: String, + schemaString: String, + jsonData: String): Unit = { + val props = new java.util.HashMap[String, String]() + props.put("column-update-req-attrs", reqAttrs) + val columns = CatalogV2Util.structTypeToV2Columns(StructType.fromDDL(schemaString)) + val transforms = Array[Transform](identity(reference(Seq("dep")))) + val tableInfo = new TableInfo.Builder() + .withColumns(columns) + .withPartitions(transforms) + .withProperties(props) + .build() + catalog.createTable(ident, tableInfo) + append(schemaString, jsonData) + } + + test("column-update: requiredDataAttributes forces connector-declared column into scan") { + // Connector declares it always needs "dep". + // SQL assigns "id" (literal) with condition on "pk". + // Connector-driven scan = {pk, dep} (dep from connector declaration; pk from condition). + // id is NOT in scan: literal assignment + not declared by connector. 
+ createAndInitTableWithReqAttrs("dep", "pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET id = -1 WHERE pk = 1") + + val scanSchema = table.lastScanSchema + assert(scanSchema.fieldNames.contains("dep"), + s"dep must be in scan (connector required): $scanSchema") + assert(scanSchema.fieldNames.contains("pk"), s"pk must be in scan: $scanSchema") + assert(!scanSchema.fieldNames.contains("id"), + s"id should be excluded (literal assignment, not declared): $scanSchema") + } + + test("column-update: requiredDataAttributes - data correctness") { + // Connector declares "dep,id" so it receives both the new id value and dep for routing. + // The write schema is exactly requiredDataAttributes = {dep, id} (declared order). + createAndInitTableWithReqAttrs("dep,id", "pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |{ "pk": 3, "id": 3, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET id = -1 WHERE pk = 1") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString ORDER BY pk"), + Row(1, -1, "hr") :: Row(2, 2, "software") :: Row(3, 3, "hr") :: Nil) + } + + test("column-update: empty requiredDataAttributes falls back to heuristic") { + // "column-update" uses DeltaBasedColumnUpdateOperation whose requiredDataAttributes() + // returns the default empty array. + // With the optimizer-driven approach for MOR, the scan is narrowed by V2ScanRelationPushDown + // which observes what columns the write plan actually references. + // SET id = -1 (literal assignment): id is not referenced from the scan, so it is pruned. + // dep is the partitioning column; since it is not declared in requiredDataAttributes() + // and is not referenced by the WHERE condition (pk = 1), it may be pruned from the scan. + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET id = -1 WHERE pk = 1") + + val scanSchema = table.lastScanSchema + assert(!scanSchema.fieldNames.contains("id"), + s"id must NOT be in scan (literal assignment, no scan reference): $scanSchema") + assert(scanSchema.fieldNames.contains("pk"), s"pk must be in scan (condition): $scanSchema") + } + + // --------------------------------------------------------------------------- + // Connector uses RowLevelOperationInfo.updatedColumns() to derive its own + // requiredDataAttributes() dynamically. + // DeltaBasedColumnUpdateOperationFromInfo always adds "pk" (for row lookup) to + // whatever Spark reports as updated columns. 
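+  //
+  // An override following that pattern might look roughly like the sketch below (illustrative
+  // only; `reference` is LogicalExpressions.reference and `info` is the RowLevelOperationInfo
+  // handed to the connector when the operation is built):
+  //
+  //   override def requiredDataAttributes(): Array[NamedReference] =
+  //     (reference(Seq("pk")) +: info.updatedColumns().toSeq).distinct.toArray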
+ // --------------------------------------------------------------------------- + + private def createAndInitTableFromInfo(schemaString: String, jsonData: String): Unit = { + val props = new java.util.HashMap[String, String]() + props.put("column-update-from-info", "true") + val columns = CatalogV2Util.structTypeToV2Columns(StructType.fromDDL(schemaString)) + val transforms = Array[Transform](identity(reference(Seq("dep")))) + val tableInfo = new TableInfo.Builder() + .withColumns(columns) + .withPartitions(transforms) + .withProperties(props) + .build() + catalog.createTable(ident, tableInfo) + append(schemaString, jsonData) + } + + test("column-update from-info: connector adds pk to updatedColumns for requiredDataAttributes") { + // Connector receives updatedColumns=[salary], adds pk for row lookup. + // requiredDataAttributes() = [pk, salary]. + // + // salary = -1 is a LITERAL assignment: the write plan references Literal(-1) not the + // scan's salary column. Since salary is in assignedAttrs, it is not a connectorExtraAttr + // pass-through either. V2ScanRelationPushDown therefore does not see salary referenced + // and prunes it from the scan. + // + // The scan contains: pk (connector pass-through), dep (partitioning + WHERE condition). + // The scan excludes: salary (literal assignment), id and bonus (not declared, not in cond). + createAndInitTableFromInfo("pk INT NOT NULL, salary INT, id INT, bonus INT, dep STRING", + """{ "pk": 1, "salary": 100, "id": 10, "bonus": 5, "dep": "hr" } + |{ "pk": 2, "salary": 200, "id": 20, "bonus": 6, "dep": "software" } + |{ "pk": 3, "salary": 300, "id": 30, "bonus": 7, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'") + + val scanSchema = table.lastScanSchema + assert(scanSchema.fieldNames.contains("pk"), + s"pk must be in scan (connector pass-through via connectorExtraAttrs): $scanSchema") + assert(scanSchema.fieldNames.contains("dep"), + s"dep must be in scan (partitioning + WHERE): $scanSchema") + assert(!scanSchema.fieldNames.contains("id"), + s"id must be excluded (not declared, not assigned, not in condition): $scanSchema") + assert(!scanSchema.fieldNames.contains("bonus"), + s"bonus must be excluded (not declared, not assigned, not in condition): $scanSchema") + } + + test("column-update from-info: write schema is updatedColumns + pk pass-through") { + // requiredDataAttributes = [pk, salary] (pk always added; salary because it's assigned). + // Write schema = requiredDataAttributes in declared order = {pk, salary}. 
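+    // Concretely (roughly): the delta write for the matching row carries rowId = {pk: 1},
+    // the usual metadata columns, and row = {pk: 1, salary: -1}; id and dep never travel
+    // through the write.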
+ createAndInitTableFromInfo("pk INT NOT NULL, salary INT, id INT, dep STRING", + """{ "pk": 1, "salary": 100, "id": 10, "dep": "hr" } + |{ "pk": 2, "salary": 200, "id": 20, "dep": "software" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE pk = 1") + + val writeSchema = table.lastWriteInfo.schema() + assert(writeSchema.fieldNames.contains("salary"), + s"salary must be in write schema (assigned): $writeSchema") + assert(writeSchema.fieldNames.contains("pk"), + s"pk must be in write schema " + + s"(connector pass-through via requiredDataAttributes): $writeSchema") + assert(!writeSchema.fieldNames.contains("id"), + s"id must not be in write schema: $writeSchema") + assert(!writeSchema.fieldNames.contains("dep"), + s"dep must not be in write schema (partitioning, not a data column to write): $writeSchema") + } + + test("column-update from-info: pk already in updatedColumns is not duplicated") { + // When the user updates pk itself, updatedColumns=[pk, salary]. + // Connector sees pk already present -> requiredDataAttributes=[pk, salary] (no dup). + createAndInitTableFromInfo("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET pk = pk + 10, salary = -1 WHERE dep = 'hr'") + + val writeSchema = table.lastWriteInfo.schema() + val pkCount = writeSchema.fieldNames.count(_ == "pk") + assert(pkCount == 1, s"pk must appear exactly once in write schema: $writeSchema") + } + + test("column-update from-info: data correctness") { + createAndInitTableFromInfo("pk INT NOT NULL, salary INT, id INT, dep STRING", + """{ "pk": 1, "salary": 100, "id": 10, "dep": "hr" } + |{ "pk": 2, "salary": 200, "id": 20, "dep": "software" } + |{ "pk": 3, "salary": 300, "id": 30, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'") + + // salary updated for hr rows; id preserved (not in write schema, connector uses pk lookup) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString ORDER BY pk"), + Row(1, -1, 10, "hr") :: + Row(2, 200, 20, "software") :: + Row(3, -1, 30, "hr") :: Nil) + } + + // --------------------------------------------------------------------------- + // CoW connector with supportsColumnUpdates() on RowLevelOperation. + // PartitionBasedColumnUpdateOperation declares requiredDataAttributes() = [pk, dep] and + // supportsColumnUpdates() = true. Spark narrows the scan to connector-declared + assigned + // columns; bonus is excluded. The connector reconstructs full rows via pk lookup. + // --------------------------------------------------------------------------- + + private def createAndInitTableCoW(schemaString: String, jsonData: String): Unit = { + val props = new java.util.HashMap[String, String]() + props.put("column-update-cow", "true") + val columns = CatalogV2Util.structTypeToV2Columns(StructType.fromDDL(schemaString)) + val transforms = Array[Transform](identity(reference(Seq("dep")))) + val tableInfo = new TableInfo.Builder() + .withColumns(columns) + .withPartitions(transforms) + .withProperties(props) + .build() + catalog.createTable(ident, tableInfo) + append(schemaString, jsonData) + } + + test("column-update CoW: scan excludes columns not declared and not assigned") { + // Connector declares [pk, dep]. SET salary = -1. + // Narrow scan = pk (declared) + dep (declared + condition + partitioning) + // + salary (assigned LHS). 
bonus is neither declared nor assigned -> excluded. + createAndInitTableCoW("pk INT NOT NULL, salary INT, bonus INT, dep STRING", + """{ "pk": 1, "salary": 100, "bonus": 10, "dep": "hr" } + |{ "pk": 2, "salary": 200, "bonus": 20, "dep": "software" } + |{ "pk": 3, "salary": 300, "bonus": 30, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'") + + val scanSchema = table.lastScanSchema + assert(scanSchema.fieldNames.contains("pk"), s"pk must be in scan: $scanSchema") + assert(scanSchema.fieldNames.contains("dep"), s"dep must be in scan: $scanSchema") + assert(scanSchema.fieldNames.contains("salary"), + s"salary must be in scan (assigned LHS): $scanSchema") + assert(!scanSchema.fieldNames.contains("bonus"), s"bonus must be excluded: $scanSchema") + } + + test("column-update CoW: write schema contains only declared + assigned columns") { + createAndInitTableCoW("pk INT NOT NULL, salary INT, bonus INT, dep STRING", + """{ "pk": 1, "salary": 100, "bonus": 10, "dep": "hr" } + |{ "pk": 2, "salary": 200, "bonus": 20, "dep": "software" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'") + + val writeSchema = table.lastWriteInfo.schema() + assert(writeSchema.fieldNames.contains("pk"), s"pk must be in write schema: $writeSchema") + assert(writeSchema.fieldNames.contains("dep"), s"dep must be in write schema: $writeSchema") + assert(writeSchema.fieldNames.contains("salary"), + s"salary must be in write schema: $writeSchema") + assert(!writeSchema.fieldNames.contains("bonus"), + s"bonus must not be in write schema: $writeSchema") + } + + test("column-update CoW: data correctness -- bonus preserved, salary updated") { + // bonus is not in the write schema; the connector must preserve it from the original row. + createAndInitTableCoW("pk INT NOT NULL, salary INT, bonus INT, dep STRING", + """{ "pk": 1, "salary": 100, "bonus": 10, "dep": "hr" } + |{ "pk": 2, "salary": 200, "bonus": 20, "dep": "software" } + |{ "pk": 3, "salary": 300, "bonus": 30, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString ORDER BY pk"), + Row(1, -1, 10, "hr") :: + Row(2, 200, 20, "software") :: + Row(3, -1, 30, "hr") :: Nil) + } + + test("column-update CoW: narrow scan + subquery WHERE condition") { + // Exercises buildReplaceDataWithUnionPlan + narrow scan + the flatMap change in + // RowLevelOperationRuntimeGroupFiltering.buildTableToScanAttrMap. + // The subquery forces the UNION path (updated rows + remaining rows). + // bonus is not declared and not assigned, must be excluded from scan and write + // but the subquery-based filter must still work correctly with the narrow scan. 
+ createAndInitTableCoW("pk INT NOT NULL, salary INT, bonus INT, dep STRING", + """{ "pk": 1, "salary": 100, "bonus": 10, "dep": "hr" } + |{ "pk": 2, "salary": 200, "bonus": 20, "dep": "software" } + |{ "pk": 3, "salary": 300, "bonus": 30, "dep": "hr" } + |""".stripMargin) + + import testImplicits._ + val subqueryDF = Seq("hr").toDF() + subqueryDF.createOrReplaceTempView("target_deps") + + sql( + s"""UPDATE $tableNameAsString + |SET salary = -1 + |WHERE dep IN (SELECT * FROM target_deps) + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString ORDER BY pk"), + Row(1, -1, 10, "hr") :: + Row(2, 200, 20, "software") :: + Row(3, -1, 30, "hr") :: Nil) + } + + // --------------------------------------------------------------------------- + // Delta connector with representUpdateAsDeleteAndInsert=true AND supportsColumnUpdates=true. + // + // Point 7: The restriction that blocked column-level updates on the delete+reinsert path + // has been removed. The REINSERT leg of the Expand uses only assigned values (the narrow + // write schema from effectiveRowAttrs), and the DELETE leg uses row ID only. + // --------------------------------------------------------------------------- + + private def createAndInitTableSplit(schemaString: String, jsonData: String): Unit = { + val props = new java.util.HashMap[String, String]() + props.put("column-update-split", "true") + val columns = CatalogV2Util.structTypeToV2Columns(StructType.fromDDL(schemaString)) + val transforms = Array[Transform](identity(reference(Seq("dep")))) + val tableInfo = new TableInfo.Builder() + .withColumns(columns) + .withPartitions(transforms) + .withProperties(props) + .build() + catalog.createTable(ident, tableInfo) + append(schemaString, jsonData) + } + + test("column-update split: write schema is narrow (assigned + pk pass-through)") { + // representUpdateAsDeleteAndInsert=true + supportsColumnUpdates=true. + // requiredDataAttributes() = [pk, id] (pk always declared; id because it's being updated). + // The write schema = requiredDataAttributes() in declared order = {pk, id}. + // dep is NOT in the write schema (not declared, not assigned). + createAndInitTableSplit("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET id = -1 WHERE pk = 1") + + // Write schema is exactly requiredDataAttributes = {pk, id} in declared order. + checkLastWriteInfo( + expectedRowSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("id", IntegerType, nullable = false) + )), + expectedRowIdSchema = Some(StructType(Array(PK_FIELD))), + expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE)))) + } + + test("column-update split: data correctness") { + // representUpdateAsDeleteAndInsert=true + supportsColumnUpdates=true. + // The connector receives narrow REINSERT rows and must reconstruct full rows. 
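+    // For example, for pk = 1 the REINSERT row carries only {pk: 1, id: -1}; the untouched
+    // dep value ("hr") has to be recovered from the stored row, presumably via a pk lookup
+    // as in the CoW connector above.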
+    createAndInitTableSplit("pk INT NOT NULL, id INT, dep STRING",
+      """{ "pk": 1, "id": 1, "dep": "hr" }
+        |{ "pk": 2, "id": 2, "dep": "software" }
+        |{ "pk": 3, "id": 3, "dep": "hr" }
+        |""".stripMargin)
+
+    sql(s"UPDATE $tableNameAsString SET id = -1 WHERE dep = 'hr'")
+
+    checkAnswer(
+      sql(s"SELECT * FROM $tableNameAsString ORDER BY pk"),
+      Row(1, -1, "hr") :: Row(2, 2, "software") :: Row(3, -1, "hr") :: Nil)
+  }
+}
+
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala
index c2db54f8f724b..e828c3671041d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala
@@ -23,6 +23,74 @@ abstract class DeltaBasedUpdateTableSuiteBase extends UpdateTableSuiteBase {
 
   override protected def deltaUpdate: Boolean = true
 
+  // ---------------------------------------------------------------------------
+  // RowLevelOperationInfo.updatedColumns() -- Spark informs the connector which
+  // columns are genuinely being updated (non-identity assignments only).
+  // ---------------------------------------------------------------------------
+
+  test("updatedColumns: single non-identity assignment") {
+    createAndInitTable("pk INT NOT NULL, id INT, dep STRING",
+      """{ "pk": 1, "id": 1, "dep": "hr" }
+        |""".stripMargin)
+
+    sql(s"UPDATE $tableNameAsString SET id = -1 WHERE pk = 1")
+
+    checkLastUpdatedColumns("id")
+  }
+
+  test("updatedColumns: multiple non-identity assignments") {
+    createAndInitTable("pk INT NOT NULL, id INT, dep STRING",
+      """{ "pk": 1, "id": 1, "dep": "hr" }
+        |""".stripMargin)
+
+    sql(s"UPDATE $tableNameAsString SET id = -1, dep = 'eng' WHERE pk = 1")
+
+    checkLastUpdatedColumns("id", "dep")
+  }
+
+  test("updatedColumns: identity assignments are excluded") {
+    createAndInitTable("pk INT NOT NULL, id INT, dep STRING",
+      """{ "pk": 1, "id": 1, "dep": "hr" }
+        |""".stripMargin)
+
+    // dep = dep is an identity assignment and must NOT appear in updatedColumns
+    sql(s"UPDATE $tableNameAsString SET id = -1, dep = dep WHERE pk = 1")
+
+    checkLastUpdatedColumns("id")
+  }
+
+  test("updatedColumns: empty when all assignments are identity") {
+    createAndInitTable("pk INT NOT NULL, id INT, dep STRING",
+      """{ "pk": 1, "id": 1, "dep": "hr" }
+        |""".stripMargin)
+
+    sql(s"UPDATE $tableNameAsString SET id = id, dep = dep WHERE pk = 1")
+
+    checkLastUpdatedColumns() // expects empty
+  }
+
+  test("updatedColumns: no WHERE clause still reports assigned columns") {
+    createAndInitTable("pk INT NOT NULL, id INT, dep STRING",
+      """{ "pk": 1, "id": 1, "dep": "hr" }
+        |{ "pk": 2, "id": 2, "dep": "software" }
+        |""".stripMargin)
+
+    sql(s"UPDATE $tableNameAsString SET dep = 'eng'")
+
+    checkLastUpdatedColumns("dep")
+  }
+
+  test("updatedColumns: cross-column assignment is not treated as identity") {
+    createAndInitTable("pk INT NOT NULL, id INT, dep STRING",
+      """{ "pk": 1, "id": 1, "dep": "hr" }
+        |""".stripMargin)
+
+    // SET id = pk assigns a different column's value to id -- not identity
+    sql(s"UPDATE $tableNameAsString SET id = pk, dep = dep WHERE pk = 1")
+
+    checkLastUpdatedColumns("id")
+  }
+
   test("nullable row ID attrs") {
     createAndInitTable("pk INT, salary INT, dep STRING",
       """{ "pk": 1, "salary": 300, "dep": 'hr' }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala
b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala index 8c51fb17b2cf4..d6b4771539153 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala @@ -238,6 +238,18 @@ abstract class RowLevelOperationSuiteBase assert(actualMetadataSchema == expectedMetadataSchema, "metadata schema must match") } + /** + * Asserts that the column names in RowLevelOperationInfo.updatedColumns() received by the + * last operation match exactly the expected set. Order is ignored. + */ + protected def checkLastUpdatedColumns(expectedNames: String*): Unit = { + val actual = table.lastUpdatedColumns.map(_.describe()).toSet + val expected = expectedNames.toSet + assert(actual == expected, + s"updatedColumns mismatch: expected ${expected.mkString("[", ", ", "]")} " + + s"but got ${actual.mkString("[", ", ", "]")}") + } + protected def checkLastWriteLog(expectedEntries: WriteLogEntry*): Unit = { val entryType = new StructType() .add(StructField("operation", StringType))