@@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.write;

import org.apache.spark.annotation.Evolving;

/**
* Provides an informational summary of a write produced by an INSERT operation.
*
* @since 4.2.0
*/
@Evolving
public interface InsertSummary extends WriteSummary {

/**
* Returns the number of output rows, or -1 if not available.
*/
long numOutputRows();
Contributor:
Would calling this numInsertedRows be more consistent with the rest of the summaries?

}
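
For reference, a minimal, hypothetical Scala sketch of how a connector might consume this summary through the summary-aware commit overload exercised by the in-memory test tables below; LoggingBatchWrite and delegate are illustrative names and not part of the change.

import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, InsertSummary, PhysicalWriteInfo, WriterCommitMessage, WriteSummary}

// Hypothetical connector-side wrapper that logs the row count reported via the summary.
class LoggingBatchWrite(delegate: BatchWrite) extends BatchWrite {

  override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory =
    delegate.createBatchWriterFactory(info)

  override def commit(messages: Array[WriterCommitMessage]): Unit =
    delegate.commit(messages)

  // Summary-aware overload: inspect the summary before delegating the actual commit.
  override def commit(messages: Array[WriterCommitMessage], summary: WriteSummary): Unit = {
    summary match {
      case insert: InsertSummary if insert.numOutputRows() >= 0 =>
        println(s"INSERT committed ${insert.numOutputRows()} rows")
      case _ =>
        () // not an insert summary, or the row count is unavailable (-1)
    }
    delegate.commit(messages)
  }

  override def abort(messages: Array[WriterCommitMessage]): Unit =
    delegate.abort(messages)
}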
@@ -73,7 +73,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
}
val project = Project(projectList, joinPlan)

AppendData.byPosition(r, project)
AppendData.byPosition(r, project, rowLevelCommand = Some(MERGE))

case _ =>
m
@@ -114,7 +114,7 @@
output = generateExpandOutput(r.output, outputs),
joinPlan)

AppendData.byPosition(r, mergeRows)
AppendData.byPosition(r, mergeRows, rowLevelCommand = Some(MERGE))

case _ =>
m
@@ -157,7 +157,11 @@ case class AppendData(
isByName: Boolean,
withSchemaEvolution: Boolean,
write: Option[Write] = None,
analyzedQuery: Option[LogicalPlan] = None) extends V2WriteCommand with TransactionalWrite {
analyzedQuery: Option[LogicalPlan] = None,
rowLevelCommand: Option[RowLevelOperation.Command] = None)
Contributor:
How do you feel about adding a new node like InsertOnlyMerge and InsertOnlyMergeExec instead of making AppendData, which is heavily used, aware of the row-level operation?

I vibecoded it locally and it is fairly easy.

Contributor:
I don't feel strongly here. More like 60/40 to have a new node, to be honest.

extends V2WriteCommand
with TransactionalWrite {

override val writePrivileges: Set[TableWritePrivilege] = Set(TableWritePrivilege.INSERT)
override def withNewQuery(newQuery: LogicalPlan): AppendData = copy(query = newQuery)
override def withNewTable(newTable: NamedRelation): AppendData = copy(table = newTable)
@@ -184,13 +188,15 @@ object AppendData {
table: NamedRelation,
query: LogicalPlan,
writeOptions: Map[String, String] = Map.empty,
withSchemaEvolution: Boolean = false): AppendData = {
withSchemaEvolution: Boolean = false,
rowLevelCommand: Option[RowLevelOperation.Command] = None): AppendData = {
new AppendData(
table,
query,
writeOptions,
isByName = false,
withSchemaEvolution)
withSchemaEvolution,
rowLevelCommand = rowLevelCommand)
}
}

@@ -0,0 +1,24 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.write

/**
* Implementation of [[InsertSummary]] that provides an INSERT operation summary.
*/
private[sql] case class InsertSummaryImpl(numOutputRows: Long) extends InsertSummary {
}
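
A small sketch, under stated assumptions, of how the driver side could hand this summary to a connector at commit time. Only InsertSummaryImpl and the summary-aware commit overload come from this PR; the object, method, and rowCount parameter are assumed for illustration, and the code would need to live under org.apache.spark.sql to access the private[sql] class.

import org.apache.spark.sql.connector.write.{BatchWrite, InsertSummaryImpl, WriterCommitMessage, WriteSummary}

object InsertSummaryExample {
  // Hypothetical helper: attach an insert row count to a commit.
  // rowCount would come from the write's output-row metric (assumption).
  def commitWithInsertSummary(
      write: BatchWrite,
      messages: Array[WriterCommitMessage],
      rowCount: Long): Unit = {
    val summary: WriteSummary = InsertSummaryImpl(numOutputRows = rowCount)
    write.commit(messages, summary)
  }
}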
@@ -788,30 +788,43 @@ abstract class InMemoryBaseTable(
}

override def abort(messages: Array[WriterCommitMessage]): Unit = {}

protected def doCommit(messages: Array[WriterCommitMessage]): Unit

override final def commit(messages: Array[WriterCommitMessage]): Unit = {
doCommit(messages)
commits += Commit(Instant.now().toEpochMilli)
}

override final def commit(
messages: Array[WriterCommitMessage],
summary: WriteSummary): Unit = {
doCommit(messages)
commits += Commit(Instant.now().toEpochMilli, writeSummary = Some(summary))
}
}

class Append(val info: LogicalWriteInfo) extends TestBatchWrite {

override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
override protected def doCommit(
messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
withData(messages.map(_.asInstanceOf[BufferedRows]))
commits += Commit(Instant.now().toEpochMilli)
}
}

class DynamicOverwrite(val info: LogicalWriteInfo) extends TestBatchWrite {
override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
override protected def doCommit(
messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
val newData = messages.map(_.asInstanceOf[BufferedRows])
dataMap --= newData.flatMap(_.rows.map(getKey))
withData(newData)
commits += Commit(Instant.now().toEpochMilli)
}
}

class TruncateAndAppend(val info: LogicalWriteInfo) extends TestBatchWrite {
override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
override protected def doCommit(
messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
dataMap.clear()
withData(messages.map(_.asInstanceOf[BufferedRows]))
commits += Commit(Instant.now().toEpochMilli)
}
}

@@ -17,7 +17,6 @@

package org.apache.spark.sql.connector.catalog

import java.time.Instant
import java.util

import org.apache.spark.sql.catalyst.InternalRow
@@ -26,7 +25,7 @@ import org.apache.spark.sql.connector.catalog.constraints.Constraint
import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
import org.apache.spark.sql.connector.expressions.{FieldReference, LogicalExpressions, NamedReference, SortDirection, SortOrder, Transform}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder}
import org.apache.spark.sql.connector.write.{BatchWrite, DeltaBatchWrite, DeltaWrite, DeltaWriteBuilder, DeltaWriter, DeltaWriterFactory, LogicalWriteInfo, PhysicalWriteInfo, RequiresDistributionAndOrdering, RowLevelOperation, RowLevelOperationBuilder, RowLevelOperationInfo, SupportsDelta, Write, WriteBuilder, WriterCommitMessage, WriteSummary}
import org.apache.spark.sql.connector.write.{BatchWrite, DeltaBatchWrite, DeltaWrite, DeltaWriteBuilder, DeltaWriter, DeltaWriterFactory, LogicalWriteInfo, PhysicalWriteInfo, RequiresDistributionAndOrdering, RowLevelOperation, RowLevelOperationBuilder, RowLevelOperationInfo, SupportsDelta, Write, WriteBuilder, WriterCommitMessage}
import org.apache.spark.sql.connector.write.RowLevelOperation.Command
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -143,18 +142,11 @@ class InMemoryRowLevelOperationTable private (
override def description(): String = "InMemoryPartitionReplaceOperation"
}

abstract class RowLevelOperationBatchWrite extends TestBatchWrite {

override def commit(messages: Array[WriterCommitMessage], metrics: WriteSummary): Unit = {
commit(messages)
commits += Commit(Instant.now().toEpochMilli, Some(metrics))
}
}

private case class PartitionBasedReplaceData(scan: InMemoryBatchScan)
extends RowLevelOperationBatchWrite {
extends TestBatchWrite {

override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
override protected def doCommit(
messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
val newData = messages.map(_.asInstanceOf[BufferedRows])
val readRows = scan.data.flatMap(_.asInstanceOf[BufferedRows].rows)
val readPartitions = readRows.map(r => getKey(r, schema)).distinct
@@ -216,12 +208,12 @@ class InMemoryRowLevelOperationTable private (
}
}

private object TestDeltaBatchWrite extends RowLevelOperationBatchWrite with DeltaBatchWrite{
private object TestDeltaBatchWrite extends TestBatchWrite with DeltaBatchWrite {
override def createBatchWriterFactory(info: PhysicalWriteInfo): DeltaWriterFactory = {
new DeltaBufferedRowsWriterFactory(CatalogV2Util.v2ColumnsToStructType(columns()))
}

override def commit(messages: Array[WriterCommitMessage]): Unit = {
override protected def doCommit(messages: Array[WriterCommitMessage]): Unit = {
val newData = messages.map(_.asInstanceOf[BufferedRows])
withDeletes(newData)
withData(newData, columns())
@@ -205,7 +205,8 @@ class InMemoryTable(

private class Overwrite(filters: Array[Filter]) extends TestBatchWrite {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
override protected def doCommit(
messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
val deleteKeys = InMemoryTable.filtersToKeys(
dataMap.keys, partCols.map(_.toSeq.quoted).toImmutableArraySeq, filters)
dataMap --= deleteKeys
@@ -140,7 +140,8 @@ class InMemoryTableWithV2Filter(

private class Overwrite(predicates: Array[Predicate]) extends TestBatchWrite {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
override protected def doCommit(
messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
val deleteKeys = InMemoryTableWithV2Filter.filtersToKeys(
dataMap.keys, partCols.map(_.toSeq.quoted).toImmutableArraySeq, predicates)
dataMap --= deleteKeys
@@ -320,7 +320,7 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]
i.copy(table = DDLUtils.readHiveTable(tableMeta))

case append @ AppendData(
ExtractV2Table(V1Table(table: CatalogTable)), _, _, _, _, _, _) if !append.isByName =>
ExtractV2Table(V1Table(table: CatalogTable)), _, _, _, _, _, _, _) if !append.isByName =>
InsertIntoStatement(UnresolvedCatalogRelation(table),
table.partitionColumnNames.map(name => name -> None).toMap,
Seq.empty, append.query, false, append.isByName)
@@ -446,7 +446,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
}

case AppendData(r @ ExtractV2Table(v1: SupportsWrite), _, _,
_, _, Some(write), analyzedQuery) if v1.supports(TableCapability.V1_BATCH_WRITE) =>
_, _, Some(write), analyzedQuery, _) if v1.supports(TableCapability.V1_BATCH_WRITE) =>
write match {
case v1Write: V1Write =>
assert(analyzedQuery.isDefined)
@@ -456,8 +456,8 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
v1, v2Write.getClass.getName, classOf[V1Write].getName)
}

case AppendData(r: DataSourceV2Relation, query, _, _, _, Some(write), _) =>
AppendDataExec(planLater(query), refreshCache(r), write, r.name) :: Nil
case AppendData(r: DataSourceV2Relation, query, _, _, _, Some(write), _, rowLevelCommand) =>
AppendDataExec(planLater(query), refreshCache(r), write, r.name, None, rowLevelCommand) :: Nil

case OverwriteByExpression(r @ ExtractV2Table(v1: SupportsWrite), _, _,
_, _, _, Some(write), analyzedQuery) if v1.supports(TableCapability.V1_BATCH_WRITE) =>
@@ -46,7 +46,8 @@ object TableCapabilityCheck extends (LogicalPlan => Unit) {

// TODO: check STREAMING_WRITE capability. It's not doable now because we don't have a
// a logical plan for streaming write.
case AppendData(r: DataSourceV2Relation, _, _, _, _, _, _) if !supportsBatchWrite(r.table) =>
case AppendData(r: DataSourceV2Relation, _, _, _, _, _, _, _)
if !supportsBatchWrite(r.table) =>
throw QueryCompilationErrors.unsupportedAppendInBatchModeError(r.name)

case OverwritePartitionsDynamic(r: DataSourceV2Relation, _, _, _, _, _)
@@ -44,7 +44,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
import DataSourceV2Implicits._

override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
case a @ AppendData(r: DataSourceV2Relation, query, options, _, _, None, _) =>
case a @ AppendData(r: DataSourceV2Relation, query, options, _, _, None, _, _) =>
val writeOptions = mergeOptions(options, r.options.asCaseSensitiveMap.asScala.toMap)
val writeBuilder = newWriteBuilder(r.table, writeOptions, query.schema)
val write = writeBuilder.build()