Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ object RewriteDeleteFromTable extends RewriteRowLevelCommand {

// build a plan to replace read groups in the table
val writeRelation = relation.copy(table = operationTable)
val query = addOperationColumn(WRITE_WITH_METADATA_OPERATION, remainingRowsPlan)
val query = addOperationColumn(COPY_OPERATION, remainingRowsPlan)
val projections = buildReplaceDataProjections(query, relation.output, metadataAttrs)
val groupFilterCond = if (groupFilterEnabled) Some(cond) else None
ReplaceData(writeRelation, cond, query, relation, projections, groupFilterCond)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftAnti, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, DeleteAction, Filter, HintInfo, InsertAction, Join, JoinHint, LogicalPlan, MergeAction, MergeIntoTable, MergeRows, NO_BROADCAST_AND_REPLICATION, Project, ReplaceData, UpdateAction, WriteDelta}
import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Copy, Delete, Discard, Insert, Instruction, Keep, ROW_ID, Split, Update}
import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{OPERATION_COLUMN, WRITE_OPERATION, WRITE_WITH_METADATA_OPERATION}
import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{COPY_OPERATION, INSERT_OPERATION, OPERATION_COLUMN, UPDATE_OPERATION}
import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations
import org.apache.spark.sql.connector.write.{RowLevelOperationTable, SupportsDelta}
import org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE
Expand Down Expand Up @@ -202,7 +202,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
// that's why an extra unconditional instruction that would produce the original row is added
// as the last MATCHED and NOT MATCHED BY SOURCE instruction
// this logic is specific to data sources that replace groups of data
val carryoverRowsOutput = Literal(WRITE_WITH_METADATA_OPERATION) +: targetTable.output
val carryoverRowsOutput = Literal(COPY_OPERATION) +: targetTable.output
val keepCarryoverRowsInstruction = Keep(Copy, TrueLiteral, carryoverRowsOutput)

val matchedInstructions = matchedActions.map { action =>
Expand Down Expand Up @@ -439,7 +439,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
case UpdateAction(cond, assignments, _) =>
val rowValues = assignments.map(_.value)
val metadataValues = nullifyMetadataOnUpdate(metadataAttrs)
val output = Seq(Literal(WRITE_WITH_METADATA_OPERATION)) ++ rowValues ++ metadataValues
val output = Seq(Literal(UPDATE_OPERATION)) ++ rowValues ++ metadataValues
Keep(Update, cond.getOrElse(TrueLiteral), output)

case DeleteAction(cond) =>
Expand All @@ -448,7 +448,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
case InsertAction(cond, assignments) =>
val rowValues = assignments.map(_.value)
val metadataValues = metadataAttrs.map(attr => Literal(null, attr.dataType))
val output = Seq(Literal(WRITE_OPERATION)) ++ rowValues ++ metadataValues
val output = Seq(Literal(INSERT_OPERATION)) ++ rowValues ++ metadataValues
Keep(Insert, cond.getOrElse(TrueLiteral), output)

case other =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ import scala.collection.mutable

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.ProjectingInternalRow
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, Expression, ExprId, Literal, MetadataAttribute, NamedExpression, V2ExpressionUtils}
import org.apache.spark.sql.catalyst.plans.logical.{Assignment, Expand, LogicalPlan, MergeRows, Project}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, Expression, ExprId, If, Literal, MetadataAttribute, NamedExpression, V2ExpressionUtils}
import org.apache.spark.sql.catalyst.plans.logical.{Assignment, Expand, LogicalPlan, MergeRows, Project, Union}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.{ReplaceDataProjections, WriteDeltaProjections}
import org.apache.spark.sql.catalyst.util.RowDeltaUtils._
Expand All @@ -38,11 +38,11 @@ import org.apache.spark.util.ArrayImplicits._

trait RewriteRowLevelCommand extends Rule[LogicalPlan] {

private final val DELTA_OPERATIONS_WITH_ROW =
Set(UPDATE_OPERATION, REINSERT_OPERATION, INSERT_OPERATION)
private final val DELTA_OPERATIONS_WITH_METADATA =
Set(DELETE_OPERATION, UPDATE_OPERATION, REINSERT_OPERATION)
private final val DELTA_OPERATIONS_WITH_ROW_ID =
private final val OPERATIONS_WITH_ROW =
Set(UPDATE_OPERATION, REINSERT_OPERATION, INSERT_OPERATION, COPY_OPERATION)
private final val OPERATIONS_WITH_METADATA =
Set(DELETE_OPERATION, UPDATE_OPERATION, REINSERT_OPERATION, COPY_OPERATION)
private final val OPERATIONS_WITH_ROW_ID =
Set(DELETE_OPERATION, UPDATE_OPERATION)

protected def groupFilterEnabled: Boolean = conf.runtimeRowLevelOperationGroupFilterEnabled
Expand Down Expand Up @@ -191,11 +191,11 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] {
metadataAttrs: Seq[Attribute]): ReplaceDataProjections = {
val outputs = extractOutputs(plan)

val outputsWithRow = filterOutputs(outputs, Set(WRITE_WITH_METADATA_OPERATION, WRITE_OPERATION))
val outputsWithRow = filterOutputs(outputs, OPERATIONS_WITH_ROW)
val rowProjection = newLazyProjection(plan, outputsWithRow, rowAttrs)

val metadataProjection = if (metadataAttrs.nonEmpty) {
val outputsWithMetadata = filterOutputs(outputs, Set(WRITE_WITH_METADATA_OPERATION))
val outputsWithMetadata = filterOutputs(outputs, OPERATIONS_WITH_METADATA)
Some(newLazyProjection(plan, outputsWithMetadata, metadataAttrs))
} else {
None
Expand All @@ -212,17 +212,17 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] {
val outputs = extractOutputs(plan)

val rowProjection = if (rowAttrs.nonEmpty) {
val outputsWithRow = filterOutputs(outputs, DELTA_OPERATIONS_WITH_ROW)
val outputsWithRow = filterOutputs(outputs, OPERATIONS_WITH_ROW)
Some(newLazyProjection(plan, outputsWithRow, rowAttrs))
} else {
None
}

val outputsWithRowId = filterOutputs(outputs, DELTA_OPERATIONS_WITH_ROW_ID)
val outputsWithRowId = filterOutputs(outputs, OPERATIONS_WITH_ROW_ID)
val rowIdProjection = newLazyRowIdProjection(plan, outputsWithRowId, rowIdAttrs)

val metadataProjection = if (metadataAttrs.nonEmpty) {
val outputsWithMetadata = filterOutputs(outputs, DELTA_OPERATIONS_WITH_METADATA)
val outputsWithMetadata = filterOutputs(outputs, OPERATIONS_WITH_METADATA)
Some(newLazyProjection(plan, outputsWithMetadata, metadataAttrs))
} else {
None
Expand All @@ -236,18 +236,21 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] {
case p: Project => Seq(p.projectList)
case e: Expand => e.projections
case m: MergeRows => m.outputs
case u: Union => u.children.flatMap(extractOutputs)
case _ => throw SparkException.internalError("Can't extract outputs from plan: " + plan)
}
}

/**
 * Keeps only the outputs whose leading expression (the operation column) resolves to one of
 * the given operations.
 *
 * @param outputs candidate projection lists; the first expression of each carries the operation
 * @param operations operation codes to keep
 * @return the subset of `outputs` whose operation is contained in `operations`
 */
private def filterOutputs(
    outputs: Seq[Seq[Expression]],
    operations: Set[Int]): Seq[Seq[Expression]] = {
  // the operation may be a plain literal, an aliased expression, or a conditional that picks
  // between two operations; a conditional matches only if both of its branches match
  // (the false branch is inspected only when the true branch already matched)
  def matches(expr: Expression): Boolean = expr match {
    case Literal(operation: Integer, _) => operations.contains(operation)
    case Alias(child, _) => matches(child)
    case If(_, trueValue, falseValue) => matches(trueValue) && matches(falseValue)
    // any other shape means the operation can't be determined statically
    case other => throw SparkException.internalError("Can't determine operation: " + other)
  }
  outputs.filter(output => matches(output.head))
}

private def newLazyProjection(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,10 @@ object RewriteUpdateTable extends RewriteRowLevelCommand {
val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs)

// build a plan with updated and copied over records
val updatedAndRemainingRowsPlan = buildReplaceDataUpdateProjection(
readRelation, assignments, cond)
val query = buildReplaceDataUpdateProjection(readRelation, assignments, cond)

// build a plan to replace read groups in the table
val writeRelation = relation.copy(table = operationTable)
val query = addOperationColumn(WRITE_WITH_METADATA_OPERATION, updatedAndRemainingRowsPlan)
val projections = buildReplaceDataProjections(query, relation.output, metadataAttrs)
val groupFilterCond = if (groupFilterEnabled) Some(cond) else None
ReplaceData(writeRelation, cond, query, relation, projections, groupFilterCond)
Expand Down Expand Up @@ -105,14 +103,14 @@ object RewriteUpdateTable extends RewriteRowLevelCommand {

// build a plan that contains unmatched rows in matched groups that must be copied over
val remainingRowFilter = Not(EqualNullSafe(cond, Literal.TrueLiteral))
val remainingRowsPlan = Filter(remainingRowFilter, readRelation)
val remainingRowsPlan = addOperationColumn(COPY_OPERATION,
Filter(remainingRowFilter, readRelation))

// the new state is a union of updated and copied over records
val updatedAndRemainingRowsPlan = Union(updatedRowsPlan, remainingRowsPlan)
val query = Union(updatedRowsPlan, remainingRowsPlan)

// build a plan to replace read groups in the table
val writeRelation = relation.copy(table = operationTable)
val query = addOperationColumn(WRITE_WITH_METADATA_OPERATION, updatedAndRemainingRowsPlan)
val projections = buildReplaceDataProjections(query, relation.output, metadataAttrs)
val groupFilterCond = if (groupFilterEnabled) Some(cond) else None
ReplaceData(writeRelation, cond, query, relation, projections, groupFilterCond)
Expand Down Expand Up @@ -143,7 +141,9 @@ object RewriteUpdateTable extends RewriteRowLevelCommand {
}
}

Project(updatedValues, plan)
val writeOp = If(cond, Literal(UPDATE_OPERATION), Literal(COPY_OPERATION))
val operationCol = Alias(writeOp, OPERATION_COLUMN)()
Project(operationCol +: updatedValues, plan)
}

// build a rewrite plan for sources that support row deltas
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ object RowDeltaUtils {
final val UPDATE_OPERATION: Int = 2
final val INSERT_OPERATION: Int = 3
final val REINSERT_OPERATION: Int = 4
final val WRITE_OPERATION: Int = 5
final val WRITE_WITH_METADATA_OPERATION: Int = 6
final val COPY_OPERATION: Int = 5
final val ORIGINAL_ROW_ID_VALUE_PREFIX: String = "__original_row_id_"
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ class InMemoryRowLevelOperationTable(
private final val INDEX_COLUMN_REF = FieldReference(IndexColumn.name)
private final val SUPPORTS_DELTAS = "supports-deltas"
private final val SPLIT_UPDATES = "split-updates"
private final val NO_METADATA = "no-metadata"
private final val noMetadata = properties.getOrDefault(NO_METADATA, "false") == "true"

// used in row-level operation tests to verify replaced partitions
var replacedPartitions: Seq[Seq[Any]] = Seq.empty
Expand All @@ -73,7 +75,11 @@ class InMemoryRowLevelOperationTable(
var configuredScan: InMemoryBatchScan = _

// Requests no metadata attributes when the "no-metadata" table property is "true";
// otherwise requests PARTITION_COLUMN_REF and INDEX_COLUMN_REF.
override def requiredMetadataAttributes(): Array[NamedReference] = {
  if (noMetadata) {
    Array.empty
  } else {
    Array(PARTITION_COLUMN_REF, INDEX_COLUMN_REF)
  }
}

override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
Expand All @@ -89,22 +95,29 @@ class InMemoryRowLevelOperationTable(
override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
lastWriteInfo = info
new WriteBuilder {
override def build(): Write = new Write with RequiresDistributionAndOrdering {
override def requiredDistribution: Distribution = {
Distributions.clustered(Array(PARTITION_COLUMN_REF))
override def build(): Write = if (noMetadata) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: Is it intentional that the noMetadata path bypasses RequiresDistributionAndOrdering? This exercises a different physical plan (no shuffle/sort). If the goal is just to test the no-metadata code path, consider keeping the distribution/ordering requirements so these tests cover the same physical plan shape as the existing suites.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's intentional. Without metadata, we don't have PARTITION_COLUMN_REF. Without it, we can't guarantee partitioning and can't know which column to use for distribution / ordering. This still exercises the code paths for different writing tasks, so it should be enough.

new Write {
override def toBatch: BatchWrite = PartitionBasedReplaceData(configuredScan)
override def description: String = "InMemoryWrite"
}

override def requiredOrdering: Array[SortOrder] = {
Array[SortOrder](
LogicalExpressions.sort(
PARTITION_COLUMN_REF,
SortDirection.ASCENDING,
SortDirection.ASCENDING.defaultNullOrdering()))
} else {
new Write with RequiresDistributionAndOrdering {
override def requiredDistribution: Distribution = {
Distributions.clustered(Array(PARTITION_COLUMN_REF))
}

override def requiredOrdering: Array[SortOrder] = {
Array[SortOrder](
LogicalExpressions.sort(
PARTITION_COLUMN_REF,
SortDirection.ASCENDING,
SortDirection.ASCENDING.defaultNullOrdering()))
}

override def toBatch: BatchWrite = PartitionBasedReplaceData(configuredScan)

override def description: String = "InMemoryWrite"
}

override def toBatch: BatchWrite = PartitionBasedReplaceData(configuredScan)

override def description: String = "InMemoryWrite"
}
}
}
Expand Down Expand Up @@ -138,7 +151,11 @@ class InMemoryRowLevelOperationTable(
private final val PK_COLUMN_REF = FieldReference("pk")

// Requests no metadata attributes when the "no-metadata" table property is "true";
// otherwise requests PARTITION_COLUMN_REF and INDEX_COLUMN_REF.
override def requiredMetadataAttributes(): Array[NamedReference] = {
  if (noMetadata) {
    Array.empty
  } else {
    Array(PARTITION_COLUMN_REF, INDEX_COLUMN_REF)
  }
}

// Rows are identified by PK_COLUMN_REF alone, i.e. the "pk" field.
override def rowId(): Array[NamedReference] = Array(PK_COLUMN_REF)
Expand All @@ -150,22 +167,28 @@ class InMemoryRowLevelOperationTable(
override def newWriteBuilder(info: LogicalWriteInfo): DeltaWriteBuilder = {
lastWriteInfo = info
new DeltaWriteBuilder {
override def build(): DeltaWrite = new DeltaWrite with RequiresDistributionAndOrdering {

override def requiredDistribution(): Distribution = {
Distributions.clustered(Array(PARTITION_COLUMN_REF))
override def build(): DeltaWrite = if (noMetadata) {
new DeltaWrite {
override def toBatch: DeltaBatchWrite = TestDeltaBatchWrite
}

override def requiredOrdering(): Array[SortOrder] = {
Array[SortOrder](
LogicalExpressions.sort(
PARTITION_COLUMN_REF,
SortDirection.ASCENDING,
SortDirection.ASCENDING.defaultNullOrdering())
)
} else {
new DeltaWrite with RequiresDistributionAndOrdering {

override def requiredDistribution(): Distribution = {
Distributions.clustered(Array(PARTITION_COLUMN_REF))
}

override def requiredOrdering(): Array[SortOrder] = {
Array[SortOrder](
LogicalExpressions.sort(
PARTITION_COLUMN_REF,
SortDirection.ASCENDING,
SortDirection.ASCENDING.defaultNullOrdering())
)
}

override def toBatch: DeltaBatchWrite = TestDeltaBatchWrite
}

override def toBatch: DeltaBatchWrite = TestDeltaBatchWrite
}
}
}
Expand Down Expand Up @@ -208,21 +231,24 @@ private class DeltaBufferWriter(schema: StructType) extends BufferWriter(schema)
override def delete(meta: InternalRow, id: InternalRow): Unit = {
val pk = id.getInt(0)
buffer.deletes += pk
val logEntry = new GenericInternalRow(Array[Any](DELETE, pk, meta.copy(), null))
val metaCopy = if (meta != null) meta.copy() else null
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This null guard is needed because DeltaWritingSparkTask passes null for metadata when requiredMetadataAttributes() is empty. However, the DeltaWriter API methods (delete(meta, id), update(meta, id, row), reinsert(meta, row)) don't document that meta can be null. Third-party connectors could hit the same NPE. Consider adding Javadoc on those API methods to clarify the contract.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this falls outside the scope of this PR.

val logEntry = new GenericInternalRow(Array[Any](DELETE, pk, metaCopy, null))
buffer.log += logEntry
}

// Records an update: marks the row's pk as deleted, buffers a copy of the new row,
// and appends an UPDATE entry (pk, metadata, new row) to the log.
override def update(meta: InternalRow, id: InternalRow, row: InternalRow): Unit = {
  val pk = id.getInt(0)
  buffer.deletes += pk
  buffer.rows.append(row.copy())
  // meta may be null (e.g. when the table requests no metadata attributes),
  // so copy it only when present
  val metaCopy = if (meta != null) meta.copy() else null
  val logEntry = new GenericInternalRow(Array[Any](UPDATE, pk, metaCopy, row.copy()))
  buffer.log += logEntry
}

// Records a reinsert: buffers a copy of the row and appends a REINSERT entry
// (no pk, metadata, row) to the log.
override def reinsert(meta: InternalRow, row: InternalRow): Unit = {
  buffer.rows.append(row.copy())
  // meta may be null (e.g. when the table requests no metadata attributes),
  // so copy it only when present
  val metaCopy = if (meta != null) meta.copy() else null
  val logEntry = new GenericInternalRow(Array[Any](REINSERT, null, metaCopy, row.copy()))
  buffer.log += logEntry
}

Expand Down
Loading