From 57ddd40aed33d3d73ab8503716be91625effb876 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Mon, 24 Nov 2025 21:13:33 -0800 Subject: [PATCH 1/2] [SPARK-54496][SQL] Fix Merge Into Schema Evolution for Dataframe API --- .../ResolveMergeIntoSchemaEvolution.scala | 15 +- .../catalyst/plans/logical/v2Commands.scala | 30 +- .../connector/MergeIntoTableSuiteBase.scala | 1052 ++++++++++++++++- 3 files changed, 1081 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala index ea0883f7928f..bbb8e7852b2c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.catalyst.analysis +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.types.DataTypeUtils -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsRowLevelOperations, TableCatalog, TableChange} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, TableCatalog, TableChange} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation @@ -42,15 +45,19 @@ object ResolveMergeIntoSchemaEvolution extends Rule[LogicalPlan] { if (changes.isEmpty) { m } else { - m transformUpWithNewOutput { - case r @ DataSourceV2Relation(_: SupportsRowLevelOperations, _, _, _, _, _) => + val finalAttrMapping = ArrayBuffer.empty[(Attribute, Attribute)] + val newTarget = m.targetTable.transform { + case r: DataSourceV2Relation => val referencedSourceSchema = MergeIntoTable.sourceSchemaForSchemaEvolution(m) val newTarget = performSchemaEvolution(r, referencedSourceSchema, changes) val oldTargetOutput = m.targetTable.output val newTargetOutput = newTarget.output val attributeMapping = oldTargetOutput.zip(newTargetOutput) - newTarget -> attributeMapping + finalAttrMapping ++= attributeMapping + newTarget } + val res = m.copy(targetTable = newTarget) + res.rewriteAttrs(AttributeMap(finalAttrMapping.toSeq)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 3f9e8da21d28..72274ee9bf17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -916,19 +916,29 @@ case class MergeIntoTable( false } else { val actions = matchedActions ++ notMatchedActions - val assignments = actions.collect { - case a: UpdateAction => a.assignments - case a: InsertAction => a.assignments - }.flatten - val sourcePaths = DataTypeUtils.extractAllFieldPaths(sourceTable.schema) - assignments.forall { assignment => - assignment.resolved || - (assignment.value.resolved && sourcePaths.exists { - path => MergeIntoTable.isEqual(assignment, path) - }) + val hasStarActions = actions.exists { + case _: UpdateStarAction => true + case _: 
InsertStarAction => true
+        case _ => false
+      }
+      if (hasStarActions) {
+        // star actions carry no assignments until the analyzer expands them,
+        // so they must be resolved before this check is meaningful
+        false
+      } else {
+        val assignments = actions.collect {
+          case a: UpdateAction => a.assignments
+          case a: InsertAction => a.assignments
+        }.flatten
+        val sourcePaths = DataTypeUtils.extractAllFieldPaths(sourceTable.schema)
+        assignments.forall { assignment =>
+          assignment.resolved ||
+            (assignment.value.resolved && sourcePaths.exists {
+              path => MergeIntoTable.isEqual(assignment, path)
+            })
         }
       }
     }
+  }

   private lazy val sourceSchemaForEvolution: StructType =
     MergeIntoTable.sourceSchemaForSchemaEvolution(this)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
index b7a8ff374b84..9d9d8e48268e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
@@ -21,13 +21,14 @@ import org.apache.spark.SparkRuntimeException
 import org.apache.spark.sql.{AnalysisException, Row}
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, In, Not}
 import org.apache.spark.sql.catalyst.optimizer.BuildLeft
-import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, ColumnDefaultValue, InMemoryTable, TableInfo}
+import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, ColumnDefaultValue, Identifier, InMemoryTable, TableInfo}
 import org.apache.spark.sql.connector.expressions.{GeneralScalarExpression, LiteralValue}
 import org.apache.spark.sql.connector.write.MergeSummary
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.datasources.v2.MergeRowsExec
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, CartesianProductExec}
+import org.apache.spark.sql.functions.{array, col, lit, map, struct, substring}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{ArrayType, BooleanType, IntegerType, LongType, MapType, StringType, StructField, StructType}

@@ -4411,7 +4412,8 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
     }
   }

-  test("Merge schema evolution should evolve referencing new column assigned to something else") {
+  test("Merge schema evolution should not evolve when referencing new column " +
+    "assigned to something else") {
     Seq(true, false).foreach { withSchemaEvolution =>
       withTempView("source") {
         createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
@@ -5233,6 +5235,1052 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
     sql(s"DROP TABLE IF EXISTS $tableNameAsString")
   }

+  test("merge with schema evolution using dataframe API: add new column and set all") {
+    val sourceTable = "cat.ns1.source_table"
+    withTable("source") {
+      createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+        """{ "pk": 1, "salary": 100, "dep": "hr" }
+          |{ "pk": 2, "salary": 200, "dep": "software" }
+          |""".stripMargin)
+
+      val sourceIdent = Identifier.of(Array("ns1"), "source_table")
+      val columns = Array(
+        Column.create("pk", IntegerType, false),
+        Column.create("salary", IntegerType),
+        Column.create("dep", StringType),
+        Column.create("new_col", IntegerType))
+      val tableInfo = new TableInfo.Builder()
+        .withColumns(columns)
+        .withProperties(extraTableProps)
+        .build()
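+      // Editor's sketch (self-contained, illustrative only): the v2Commands
+      // change above defers needSchemaEvolution until SET * / INSERT * are
+      // expanded, because star actions carry no assignments before
+      // resolution. The guard in isolation:
+      import org.apache.spark.sql.catalyst.plans.logical.{InsertStarAction, MergeAction, UpdateStarAction}
+      def hasStarActions(actions: Seq[MergeAction]): Boolean = actions.exists {
+        case _: UpdateStarAction | _: InsertStarAction => true
+        case _ => false
+      }
+      // Only after the analyzer rewrites * into concrete assignments can the
+      // rule see that a new source column such as new_col is referenced.
+      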
catalog.createTable(sourceIdent, tableInfo) + + sql(s"INSERT INTO $sourceTable VALUES (1, 101, 'support', 1)," + + s"(3, 301, 'support', 3), (4, 401, 'finance', 4)") + + spark.table(sourceTable) + .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk")) + .withSchemaEvolution() + .whenMatched() + .updateAll() + .whenNotMatched() + .insertAll() + .merge() + + // validate merge results with evolved schema + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 101, "support", 1), + Row(2, 200, "software", null), + Row(3, 301, "support", 3), + Row(4, 401, "finance", 4))) + } + } + + test("merge schema evolution new column with set explicit column using dataframe API") { + Seq((true, true), (false, true), (true, false)).foreach { + case (withSchemaEvolution, schemaEvolutionEnabled) => + val sourceTable = "cat.ns1.source_table" + withTable(sourceTable) { + sql(s"CREATE TABLE $tableNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + + val targetData = Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"), + Row(3, 300, "hr"), + Row(4, 400, "marketing"), + Row(5, 500, "executive") + ) + val targetSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("salary", IntegerType), + StructField("dep", StringType) + )) + spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) + .writeTo(tableNameAsString).append() + + if (!schemaEvolutionEnabled) { + sql(s"""ALTER TABLE $tableNameAsString SET TBLPROPERTIES + | ('auto-schema-evolution' = 'false')""".stripMargin) + } + + val sourceIdent = Identifier.of(Array("ns1"), "source_table") + val columns = Array( + Column.create("pk", IntegerType, false), + Column.create("salary", IntegerType), + Column.create("dep", StringType), + Column.create("active", BooleanType)) + val tableInfo = new TableInfo.Builder() + .withColumns(columns) + .withProperties(extraTableProps) + .build() + catalog.createTable(sourceIdent, tableInfo) + + sql(s"INSERT INTO $sourceTable VALUES (4, 150, 'dummy', true)," + + s"(5, 250, 'dummy', true), (6, 350, 'dummy', false)") + + val mergeBuilder = spark.table(sourceTable) + .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk")) + .whenMatched() + .update(Map("dep" -> lit("software"), "active" -> col("source_table.active"))) + .whenNotMatched() + .insert(Map("pk" -> col("source_table.pk"), "salary" -> lit(0), + "dep" -> col("source_table.dep"), "active" -> col("source_table.active"))) + + if (withSchemaEvolution && schemaEvolutionEnabled) { + mergeBuilder.withSchemaEvolution().merge() + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr", null), + Row(2, 200, "software", null), + Row(3, 300, "hr", null), + Row(4, 400, "software", true), + Row(5, 500, "software", true), + Row(6, 0, "dummy", false))) + } else { + val e = intercept[org.apache.spark.sql.AnalysisException] { + mergeBuilder.merge() + } + assert(e.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + assert(e.getMessage.contains("A column, variable, or function parameter with name " + + "`active` cannot be resolved")) + } + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + + test("merge schema evolution add column with nested struct and set explicit columns " + + "using dataframe API") { + Seq((true, true), (false, true), (true, false)).foreach { + case (withSchemaEvolution, schemaEvolutionEnabled) => + val sourceTable = "cat.ns1.source_table" + withTable(sourceTable) { + sql( + s"""CREATE TABLE 
$tableNameAsString ( + |pk INT NOT NULL, + |s STRUCT, m: MAP>>, + |dep STRING)""".stripMargin) + + val targetData = Seq( + Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr") + ) + val targetSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", ArrayType(IntegerType)), + StructField("m", MapType(StringType, StringType)) + ))) + ))), + StructField("dep", StringType) + )) + spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) + .writeTo(tableNameAsString).append() + + if (!schemaEvolutionEnabled) { + sql(s"""ALTER TABLE $tableNameAsString SET TBLPROPERTIES + | ('auto-schema-evolution' = 'false')""".stripMargin) + } + + val sourceIdent = Identifier.of(Array("ns1"), "source_table") + val columns = Array( + Column.create("pk", IntegerType, false), + Column.create("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", ArrayType(IntegerType)), + StructField("m", MapType(StringType, StringType)), + StructField("c3", BooleanType) // new column + ))) + ))), + Column.create("dep", StringType)) + val tableInfo = new TableInfo.Builder() + .withColumns(columns) + .withProperties(extraTableProps) + .build() + catalog.createTable(sourceIdent, tableInfo) + + val data = Seq( + Row(1, Row(10, Row(Array(3, 4), Map("c" -> "d"), false)), "sales"), + Row(2, Row(20, Row(Array(4, 5), Map("e" -> "f"), true)), "engineering") + ) + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", ArrayType(IntegerType)), + StructField("m", MapType(StringType, StringType)), + StructField("c3", BooleanType) + ))) + ))), + StructField("dep", StringType) + )) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source_temp") + + sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp") + + val mergeBuilder = spark.table(sourceTable) + .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk")) + .whenMatched() + .update(Map( + "s.c1" -> lit(-1), + "s.c2.m" -> map(lit("k"), lit("v")), + "s.c2.a" -> array(lit(-1)), + "s.c2.c3" -> col("source_table.s.c2.c3"))) + .whenNotMatched() + .insert(Map( + "pk" -> col("source_table.pk"), + "s" -> struct( + col("source_table.s.c1").as("c1"), + struct( + col("source_table.s.c2.a").as("a"), + map(lit("g"), lit("h")).as("m"), + lit(true).as("c3") + ).as("c2") + ), + "dep" -> col("source_table.dep"))) + + if (withSchemaEvolution && schemaEvolutionEnabled) { + mergeBuilder.withSchemaEvolution().merge() + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(1, Row(-1, Row(Seq(-1), Map("k" -> "v"), false)), "hr"), + Row(2, Row(20, Row(Seq(4, 5), Map("g" -> "h"), true)), "engineering"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + mergeBuilder.merge() + } + assert(exception.errorClass.get == "FIELD_NOT_FOUND") + assert(exception.getMessage.contains("No such struct field `c3` in `a`, `m`. 
")) + } + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + + test("merge schema evolution add column with nested struct and set all columns " + + "using dataframe API") { + Seq((true, true), (false, true), (true, false)).foreach { + case (withSchemaEvolution, schemaEvolutionEnabled) => + val sourceTable = "cat.ns1.source_table" + withTable(sourceTable) { + sql( + s"""CREATE TABLE $tableNameAsString ( + |pk INT NOT NULL, + |s STRUCT, m: MAP>>, + |dep STRING)""".stripMargin) + + val targetData = Seq( + Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr") + ) + val targetSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", ArrayType(IntegerType)), + StructField("m", MapType(StringType, StringType)) + ))) + ))), + StructField("dep", StringType) + )) + spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) + .writeTo(tableNameAsString).append() + + if (!schemaEvolutionEnabled) { + sql(s"""ALTER TABLE $tableNameAsString SET TBLPROPERTIES + | ('auto-schema-evolution' = 'false')""".stripMargin) + } + + val sourceIdent = Identifier.of(Array("ns1"), "source_table") + val columns = Array( + Column.create("pk", IntegerType, false), + Column.create("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", ArrayType(IntegerType)), + StructField("m", MapType(StringType, StringType)), + StructField("c3", BooleanType) // new column + ))) + ))), + Column.create("dep", StringType)) + val tableInfo = new TableInfo.Builder() + .withColumns(columns) + .withProperties(extraTableProps) + .build() + catalog.createTable(sourceIdent, tableInfo) + + val data = Seq( + Row(1, Row(10, Row(Array(3, 4), Map("c" -> "d"), false)), "sales"), + Row(2, Row(20, Row(Array(4, 5), Map("e" -> "f"), true)), "engineering") + ) + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", ArrayType(IntegerType)), + StructField("m", MapType(StringType, StringType)), + StructField("c3", BooleanType) + ))) + ))), + StructField("dep", StringType) + )) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source_temp") + + sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp") + + val mergeBuilder = spark.table(sourceTable) + .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk")) + .whenMatched() + .updateAll() + .whenNotMatched() + .insertAll() + + if (withSchemaEvolution && schemaEvolutionEnabled) { + mergeBuilder.withSchemaEvolution().merge() + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(1, Row(10, Row(Seq(3, 4), Map("c" -> "d"), false)), "sales"), + Row(2, Row(20, Row(Seq(4, 5), Map("e" -> "f"), true)), "engineering"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + mergeBuilder.merge() + } + assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") + assert(exception.getMessage.contains( + "Cannot write extra fields `c3` to the struct `s`.`c2`")) + } + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + + test("merge schema evolution replace column with nested struct and " + + "set explicit columns using dataframe API") { + Seq((true, true), (false, true), (true, false)).foreach 
{ + case (withSchemaEvolution, schemaEvolutionEnabled) => + val sourceTable = "cat.ns1.source_table" + withTable(sourceTable) { + sql( + s"""CREATE TABLE $tableNameAsString ( + |pk INT NOT NULL, + |s STRUCT, m: MAP>>, + |dep STRING)""".stripMargin) + + val targetData = Seq( + Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr") + ) + val targetSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", ArrayType(IntegerType)), + StructField("m", MapType(StringType, StringType)) + ))) + ))), + StructField("dep", StringType) + )) + spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) + .writeTo(tableNameAsString).append() + + if (!schemaEvolutionEnabled) { + sql(s"""ALTER TABLE $tableNameAsString SET TBLPROPERTIES + | ('auto-schema-evolution' = 'false')""".stripMargin) + } + + val sourceIdent = Identifier.of(Array("ns1"), "source_table") + val columns = Array( + Column.create("pk", IntegerType, false), + Column.create("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + // removed column 'a' + StructField("m", MapType(StringType, StringType)), + StructField("c3", BooleanType) // new column + ))) + ))), + Column.create("dep", StringType)) + val tableInfo = new TableInfo.Builder() + .withColumns(columns) + .withProperties(extraTableProps) + .build() + catalog.createTable(sourceIdent, tableInfo) + + val data = Seq( + Row(1, Row(10, Row(Map("c" -> "d"), false)), "sales"), + Row(2, Row(20, Row(Map("e" -> "f"), true)), "engineering") + ) + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("m", MapType(StringType, StringType)), + StructField("c3", BooleanType) + ))) + ))), + StructField("dep", StringType) + )) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source_temp") + + sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp") + + val mergeBuilder = spark.table(sourceTable) + .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk")) + .whenMatched() + .update(Map( + "s.c1" -> lit(-1), + "s.c2.m" -> map(lit("k"), lit("v")), + "s.c2.a" -> array(lit(-1)), + "s.c2.c3" -> col("source_table.s.c2.c3"))) + .whenNotMatched() + .insert(Map( + "pk" -> col("source_table.pk"), + "s" -> struct( + col("source_table.s.c1").as("c1"), + struct( + array(lit(-2)).as("a"), + map(lit("g"), lit("h")).as("m"), + lit(true).as("c3") + ).as("c2") + ), + "dep" -> col("source_table.dep"))) + + if (withSchemaEvolution && schemaEvolutionEnabled) { + mergeBuilder.withSchemaEvolution().merge() + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(1, Row(-1, Row(Seq(-1), Map("k" -> "v"), false)), "hr"), + Row(2, Row(20, Row(Seq(-2), Map("g" -> "h"), true)), "engineering"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + mergeBuilder.merge() + } + assert(exception.errorClass.get == "FIELD_NOT_FOUND") + assert(exception.getMessage.contains("No such struct field `c3` in `a`, `m`. 
")) + } + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + + test("merge schema evolution replace column with nested struct and set all columns " + + "using dataframe API") { + Seq((true, true), (false, true), (true, false)).foreach { + case (withSchemaEvolution, schemaEvolutionEnabled) => + Seq(true, false).foreach { updateByFields => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD.key -> + updateByFields.toString) { + val sourceTable = "cat.ns1.source_table" + withTable(sourceTable) { + // Create target table + sql( + s"""CREATE TABLE $tableNameAsString ( + |pk INT NOT NULL, + |s STRUCT, m: MAP>>, + |dep STRING) + |PARTITIONED BY (dep) + |""".stripMargin) + + if (!schemaEvolutionEnabled) { + sql(s"""ALTER TABLE $tableNameAsString SET TBLPROPERTIES + | ('auto-schema-evolution' = 'false')""".stripMargin) + } + + // Insert data into target table + val tableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", ArrayType(IntegerType)), + StructField("m", MapType(StringType, StringType)) + ))) + ))), + StructField("dep", StringType) + )) + val targetData = Seq( + Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr") + ) + spark.createDataFrame(spark.sparkContext.parallelize(targetData), tableSchema) + .coalesce(1).writeTo(tableNameAsString).append() + + // Create source table + val sourceIdent = Identifier.of(Array("ns1"), "source_table") + val columns = Array( + Column.create("pk", IntegerType, false), + Column.create("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + // missing column 'a' + StructField("m", MapType(StringType, StringType)), + StructField("c3", BooleanType) // new column + ))) + ))), + Column.create("dep", StringType)) + val tableInfo = new TableInfo.Builder() + .withColumns(columns) + .withProperties(extraTableProps) + .build() + catalog.createTable(sourceIdent, tableInfo) + + val sourceData = Seq( + Row(1, Row(10, Row(Map("c" -> "d"), false)), "sales"), + Row(2, Row(20, Row(Map("e" -> "f"), true)), "engineering") + ) + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("m", MapType(StringType, StringType)), + StructField("c3", BooleanType) + ))) + ))), + StructField("dep", StringType) + )) + spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema) + .createOrReplaceTempView("source_temp") + + sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp") + + val mergeBuilder = spark.table(sourceTable) + .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk")) + .whenMatched() + .updateAll() + .whenNotMatched() + .insertAll() + + if (withSchemaEvolution && schemaEvolutionEnabled) { + mergeBuilder.withSchemaEvolution().merge() + if (updateByFields) { + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, Row(10, Row(Seq(1, 2), Map("c" -> "d"), false)), "sales"), + Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering"))) + } else { + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "sales"), + Row(2, Row(20, Row(null, Map("e" -> "f"), true)), + "engineering"))) + } + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + mergeBuilder.merge() 
+ } + assert(exception.errorClass.get == + "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") + assert(exception.getMessage.contains( + "Cannot write extra fields `c3` to the struct `s`.`c2`")) + } + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + } + } + + test("merge schema evolution replace column with nested struct and " + + "update top level struct using dataframe API") { + Seq((true, true), (false, true), (true, false)).foreach { + case (withSchemaEvolution, schemaEvolutionEnabled) => + Seq(true, false).foreach { updateByFields => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD.key -> + updateByFields.toString) { + val sourceTable = "cat.ns1.source_table" + withTable(sourceTable) { + // Create target table + sql( + s"""CREATE TABLE $tableNameAsString ( + |pk INT NOT NULL, + |s STRUCT, m: MAP>>, + |dep STRING) + |PARTITIONED BY (dep) + |""".stripMargin) + + if (!schemaEvolutionEnabled) { + sql(s"""ALTER TABLE $tableNameAsString SET TBLPROPERTIES + | ('auto-schema-evolution' = 'false')""".stripMargin) + } + + // Insert data into target table + val tableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", ArrayType(IntegerType)), + StructField("m", MapType(StringType, StringType)) + ))) + ))), + StructField("dep", StringType) + )) + val targetData = Seq( + Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr") + ) + spark.createDataFrame(spark.sparkContext.parallelize(targetData), tableSchema) + .coalesce(1).writeTo(tableNameAsString).append() + + // Create source table + val sourceIdent = Identifier.of(Array("ns1"), "source_table") + val columns = Array( + Column.create("pk", IntegerType, false), + Column.create("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + // missing column 'a' + StructField("m", MapType(StringType, StringType)), + StructField("c3", BooleanType) // new column + ))) + ))), + Column.create("dep", StringType)) + val tableInfo = new TableInfo.Builder() + .withColumns(columns) + .withProperties(extraTableProps) + .build() + catalog.createTable(sourceIdent, tableInfo) + + val sourceData = Seq( + Row(1, Row(10, Row(Map("c" -> "d"), false)), "sales"), + Row(2, Row(20, Row(Map("e" -> "f"), true)), "engineering") + ) + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("m", MapType(StringType, StringType)), + StructField("c3", BooleanType) + ))) + ))), + StructField("dep", StringType) + )) + spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema) + .createOrReplaceTempView("source_temp") + + sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp") + + val mergeBuilder = spark.table(sourceTable) + .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk")) + .whenMatched() + .update(Map("s" -> col("source_table.s"))) + .whenNotMatched() + .insertAll() + + if (withSchemaEvolution && schemaEvolutionEnabled) { + mergeBuilder.withSchemaEvolution().merge() + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "hr"), + Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + mergeBuilder.merge() + } + 
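+                  // Editor's note: the asserts below verify the error class
+                  // and a message substring by hand. A sketch of the same
+                  // check with Spark's checkError helper (assuming the suite
+                  // mixes in the error-checking trait; the parameter names
+                  // here are guesses, not taken from the patch):
+                  // checkError(
+                  //   exception = exception,
+                  //   condition = "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS",
+                  //   parameters = Map("colName" -> "`s`.`c2`", "extraFields" -> "`c3`"))
+                  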
assert(exception.errorClass.get == + "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") + assert(exception.getMessage.contains( + "Cannot write extra fields `c3` to the struct `s`.`c2`")) + } + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + } + } + + test("merge schema evolution should not evolve referencing new column " + + "via transform using dataframe API") { + Seq((true, true), (false, true), (true, false)).foreach { + case (withSchemaEvolution, schemaEvolutionEnabled) => + val sourceTable = "cat.ns1.source_table" + withTable(sourceTable) { + sql(s"CREATE TABLE $tableNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + + val targetData = Seq( + Row(1, 100, "hr"), + Row(2, 200, "software") + ) + val targetSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("salary", IntegerType), + StructField("dep", StringType) + )) + spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) + .writeTo(tableNameAsString).append() + + if (!schemaEvolutionEnabled) { + sql(s"""ALTER TABLE $tableNameAsString SET TBLPROPERTIES + | ('auto-schema-evolution' = 'false')""".stripMargin) + } + + val sourceIdent = Identifier.of(Array("ns1"), "source_table") + val columns = Array( + Column.create("pk", IntegerType, false), + Column.create("salary", IntegerType), + Column.create("dep", StringType), + Column.create("extra", StringType)) + val tableInfo = new TableInfo.Builder() + .withColumns(columns) + .withProperties(extraTableProps) + .build() + catalog.createTable(sourceIdent, tableInfo) + + sql(s"INSERT INTO $sourceTable VALUES (2, 150, 'dummy', 'blah')," + + s"(3, 250, 'dummy', 'blah')") + + val e = intercept[org.apache.spark.sql.AnalysisException] { + spark.table(sourceTable) + .mergeInto(tableNameAsString, + $"source_table.pk" === col(tableNameAsString + ".pk")) + .withSchemaEvolution() + .whenMatched() + .update(Map("extra" -> substring(col("source_table.extra"), 1, 2))) + .merge() + } + assert(e.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + assert(e.getMessage.contains( + "A column, variable, or function parameter with name " + + "`extra` cannot be resolved")) + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + + test("merge into with source missing fields in top-level struct using dataframe API") { + val sourceTable = "cat.ns1.source_table" + withTable(sourceTable) { + // Target table has struct with 3 fields at top level + sql( + s"""CREATE TABLE $tableNameAsString ( + |pk INT NOT NULL, + |s STRUCT, + |dep STRING)""".stripMargin) + + val targetData = Seq( + Row(0, Row(1, "a", true), "sales") + ) + val targetSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StringType), + StructField("c3", BooleanType) + ))), + StructField("dep", StringType) + )) + spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) + .writeTo(tableNameAsString).append() + + // Create source table with struct having only 2 fields (c1, c2) - missing c3 + val sourceIdent = Identifier.of(Array("ns1"), "source_table") + val columns = Array( + Column.create("pk", IntegerType, false), + Column.create("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StringType)))), // missing c3 field + Column.create("dep", StringType)) + val tableInfo = new TableInfo.Builder() + .withColumns(columns) + .withProperties(extraTableProps) + .build() + catalog.createTable(sourceIdent, tableInfo) + + val data = 
Seq(
+        Row(1, Row(10, "b"), "hr"),
+        Row(2, Row(20, "c"), "engineering")
+      )
+      val sourceTableSchema = StructType(Seq(
+        StructField("pk", IntegerType, nullable = false),
+        StructField("s", StructType(Seq(
+          StructField("c1", IntegerType),
+          StructField("c2", StringType)))),
+        StructField("dep", StringType)))
+      spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema)
+        .createOrReplaceTempView("source_temp")
+
+      sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp")
+
+      spark.table(sourceTable)
+        .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk"))
+        .whenMatched()
+        .updateAll()
+        .whenNotMatched()
+        .insertAll()
+        .merge()
+
+      // Missing field c3 should be filled with NULL
+      checkAnswer(
+        sql(s"SELECT * FROM $tableNameAsString"),
+        Seq(
+          Row(0, Row(1, "a", true), "sales"),
+          Row(1, Row(10, "b", null), "hr"),
+          Row(2, Row(20, "c", null), "engineering")))
+
+      sql(s"DROP TABLE $tableNameAsString")
+    }
+  }
+
+  test("merge with null struct with missing nested field using dataframe API") {
+    Seq(true, false).foreach { updateByFields =>
+      Seq(true, false).foreach { coerceNestedTypes =>
+        withSQLConf(
+          SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD.key ->
+            updateByFields.toString,
+          SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
+            coerceNestedTypes.toString) {
+          val sourceTable = "cat.ns1.source_table"
+          withTable(sourceTable) {
+            // Target table has nested struct with fields c1 and c2
+            sql(
+              s"""CREATE TABLE $tableNameAsString (
+                 |pk INT NOT NULL,
+                 |s STRUCT<c1: INT, c2: STRUCT<a: INT, b: STRING>>,
+                 |dep STRING)""".stripMargin)
+
+            val targetData = Seq(
+              Row(0, Row(1, Row(10, "x")), "sales"),
+              Row(1, Row(2, Row(20, "y")), "hr")
+            )
+            val targetSchema = StructType(Seq(
+              StructField("pk", IntegerType, nullable = false),
+              StructField("s", StructType(Seq(
+                StructField("c1", IntegerType),
+                StructField("c2", StructType(Seq(
+                  StructField("a", IntegerType),
+                  StructField("b", StringType)
+                )))
+              ))),
+              StructField("dep", StringType)
+            ))
+            spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema)
+              .writeTo(tableNameAsString).append()
+
+            // Create source table with missing nested field 'b'
+            val sourceIdent = Identifier.of(Array("ns1"), "source_table")
+            val columns = Array(
+              Column.create("pk", IntegerType, false),
+              Column.create("s", StructType(Seq(
+                StructField("c1", IntegerType),
+                StructField("c2", StructType(Seq(
+                  StructField("a", IntegerType)
+                  // missing field 'b'
+                )))
+              ))),
+              Column.create("dep", StringType))
+            val tableInfo = new TableInfo.Builder()
+              .withColumns(columns)
+              .withProperties(extraTableProps)
+              .build()
+            catalog.createTable(sourceIdent, tableInfo)
+
+            // Source table has null for the nested struct
+            val data = Seq(
+              Row(1, null, "engineering"),
+              Row(2, null, "finance")
+            )
+            val sourceTableSchema = StructType(Seq(
+              StructField("pk", IntegerType),
+              StructField("s", StructType(Seq(
+                StructField("c1", IntegerType),
+                StructField("c2", StructType(Seq(
+                  StructField("a", IntegerType)
+                )))
+              ))),
+              StructField("dep", StringType)
+            ))
+            spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema)
+              .createOrReplaceTempView("source_temp")
+
+            sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp")
+
+            if (coerceNestedTypes) {
+              spark.table(sourceTable)
+                .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk"))
+                .whenMatched()
+                .updateAll()
+                .whenNotMatched()
+                .insertAll()
+                .merge()
+
+              checkAnswer(
+                sql(s"SELECT * FROM $tableNameAsString"),
+                Seq(
+                  
Row(0, Row(1, Row(10, "x")), "sales"), + Row(1, null, "engineering"), + Row(2, null, "finance"))) + } else { + // Without coercion, the merge should fail due to missing field + val exception = intercept[org.apache.spark.sql.AnalysisException] { + spark.table(sourceTable) + .mergeInto(tableNameAsString, + $"source_table.pk" === col(tableNameAsString + ".pk")) + .whenMatched() + .updateAll() + .whenNotMatched() + .insertAll() + .merge() + } + assert(exception.errorClass.get == + "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") + assert(exception.getMessage.contains( + "Cannot write incompatible data for the table ``: " + + "Cannot find data for the output column `s`.`c2`.`b`.")) + } + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + } + } + + test("merge null struct with schema evolution - " + + "source with missing and extra nested fields using dataframe API") { + Seq(true, false).foreach { updateByFields => + Seq(true, false).foreach { withSchemaEvolution => + Seq(true, false).foreach { coerceNestedTypes => + withSQLConf( + SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD.key -> + updateByFields.toString, + SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> + coerceNestedTypes.toString) { + val sourceTable = "cat.ns1.source_table" + withTable(sourceTable) { + // Target table has nested struct with fields c1 and c2 + sql( + s"""CREATE TABLE $tableNameAsString ( + |pk INT NOT NULL, + |s STRUCT>, + |dep STRING)""".stripMargin) + + val targetData = Seq( + Row(0, Row(1, Row(10, "x")), "sales"), + Row(1, Row(2, Row(20, "y")), "hr") + ) + val targetSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", IntegerType), + StructField("b", StringType) + ))) + ))), + StructField("dep", StringType) + )) + spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) + .writeTo(tableNameAsString).append() + + // Create source table with missing field 'b' and extra field 'c' in nested struct + val sourceIdent = Identifier.of(Array("ns1"), "source_table") + val columns = Array( + Column.create("pk", IntegerType, false), + Column.create("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", IntegerType), + // missing field 'b' + StructField("c", StringType) // extra field 'c' + ))) + ))), + Column.create("dep", StringType)) + val tableInfo = new TableInfo.Builder() + .withColumns(columns) + .withProperties(extraTableProps) + .build() + catalog.createTable(sourceIdent, tableInfo) + + // Source data has null for the nested struct + val data = Seq( + Row(1, null, "engineering"), + Row(2, null, "finance") + ) + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", IntegerType), + StructField("c", StringType) + ))) + ))), + StructField("dep", StringType) + )) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source_temp") + + sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp") + + val mergeBuilder = spark.table(sourceTable) + .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk")) + .whenMatched() + .updateAll() + .whenNotMatched() + .insertAll() + + if (coerceNestedTypes) { + if (withSchemaEvolution) { + // extra nested field is added + 
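+                    // Editor's sketch (plain StructType manipulation for
+                    // illustration; this is not the analyzer code path):
+                    // evolution keeps the target-only field `b` and appends
+                    // the source-only field `c`.
+                    val targetNested = StructType(Seq(
+                      StructField("a", IntegerType), StructField("b", StringType)))
+                    val sourceNested = StructType(Seq(
+                      StructField("a", IntegerType), StructField("c", StringType)))
+                    val evolvedNested = StructType(targetNested.fields ++
+                      sourceNested.fields.filterNot(f =>
+                        targetNested.fieldNames.contains(f.name)))
+                    assert(evolvedNested.fieldNames.toSeq == Seq("a", "b", "c"))
+                    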
mergeBuilder.withSchemaEvolution().merge() + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(0, Row(1, Row(10, "x", null)), "sales"), + Row(1, null, "engineering"), + Row(2, null, "finance"))) + } else { + // extra nested field is not added + val exception = intercept[org.apache.spark.sql.AnalysisException] { + mergeBuilder.merge() + } + assert(exception.errorClass.get == + "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") + assert(exception.getMessage.contains( + "Cannot write incompatible data for the table ``: " + + "Cannot write extra fields `c` to the struct `s`.`c2`")) + } + } else { + // Without source struct coercion, the merge should fail + val exception = intercept[org.apache.spark.sql.AnalysisException] { + mergeBuilder.merge() + } + assert(exception.errorClass.get == + "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") + assert(exception.getMessage.contains( + "Cannot write incompatible data for the table ``: " + + "Cannot find data for the output column `s`.`c2`.`b`.")) + } + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + } + } + } + test("Merge schema evolution should error on non-existent column in UPDATE and INSERT") { withTable(tableNameAsString) { withTempView("source") { From b240faf3dcc430c67e1fe5d99c41ae7ed406a11d Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Tue, 25 Nov 2025 12:00:35 -0800 Subject: [PATCH 2/2] Rebase --- .../connector/MergeIntoTableSuiteBase.scala | 1175 ++++++++--------- 1 file changed, 561 insertions(+), 614 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala index 9d9d8e48268e..680fa63e0929 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala @@ -5236,74 +5236,91 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase } test("merge with schema evolution using dataframe API: add new column and set all") { - val sourceTable = "cat.ns1.source_table" - withTable("source") { - createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", - """{ "pk": 1, "salary": 100, "dep": "hr" } - |{ "pk": 2, "salary": 200, "dep": "software" } - |""".stripMargin) + Seq(true, false).foreach { withSchemaEvolution => + val sourceTable = "cat.ns1.source_table" + withTable(sourceTable) { + sql(s"CREATE TABLE $tableNameAsString (pk INT NOT NULL, salary INT, dep STRING)") - val sourceIdent = Identifier.of(Array("ns1"), "source_table") - val columns = Array( - Column.create("pk", IntegerType, false), - Column.create("salary", IntegerType), - Column.create("dep", StringType), - Column.create("new_col", IntegerType)) - val tableInfo = new TableInfo.Builder() - .withColumns(columns) - .withProperties(extraTableProps) - .build() - catalog.createTable(sourceIdent, tableInfo) + val targetData = Seq( + Row(1, 100, "hr"), + Row(2, 200, "software") + ) + val targetSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("salary", IntegerType), + StructField("dep", StringType) + )) + spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) + .writeTo(tableNameAsString).append() + + val sourceIdent = Identifier.of(Array("ns1"), "source_table") + val columns = Array( + Column.create("pk", IntegerType, false), + Column.create("salary", IntegerType), + Column.create("dep", StringType), + 
Column.create("new_col", IntegerType)) + val tableInfo = new TableInfo.Builder() + .withColumns(columns) + .withProperties(extraTableProps) + .build() + catalog.createTable(sourceIdent, tableInfo) - sql(s"INSERT INTO $sourceTable VALUES (1, 101, 'support', 1)," + - s"(3, 301, 'support', 3), (4, 401, 'finance', 4)") + sql(s"INSERT INTO $sourceTable VALUES (1, 101, 'support', 1)," + + s"(3, 301, 'support', 3), (4, 401, 'finance', 4)") - spark.table(sourceTable) - .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk")) - .withSchemaEvolution() - .whenMatched() - .updateAll() - .whenNotMatched() - .insertAll() - .merge() + val mergeBuilder = spark.table(sourceTable) + .mergeInto(tableNameAsString, + $"source_table.pk" === col(tableNameAsString + ".pk")) + .whenMatched() + .updateAll() + .whenNotMatched() + .insertAll() - // validate merge results with evolved schema - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(1, 101, "support", 1), - Row(2, 200, "software", null), - Row(3, 301, "support", 3), - Row(4, 401, "finance", 4))) + if (withSchemaEvolution) { + mergeBuilder.withSchemaEvolution().merge() + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 101, "support", 1), + Row(2, 200, "software", null), + Row(3, 301, "support", 3), + Row(4, 401, "finance", 4))) + } else { + mergeBuilder.merge() + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 101, "support"), + Row(2, 200, "software"), + Row(3, 301, "support"), + Row(4, 401, "finance"))) + } + + sql(s"DROP TABLE $tableNameAsString") + } } } test("merge schema evolution new column with set explicit column using dataframe API") { - Seq((true, true), (false, true), (true, false)).foreach { - case (withSchemaEvolution, schemaEvolutionEnabled) => - val sourceTable = "cat.ns1.source_table" - withTable(sourceTable) { - sql(s"CREATE TABLE $tableNameAsString (pk INT NOT NULL, salary INT, dep STRING)") - - val targetData = Seq( - Row(1, 100, "hr"), - Row(2, 200, "software"), - Row(3, 300, "hr"), - Row(4, 400, "marketing"), - Row(5, 500, "executive") - ) - val targetSchema = StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("salary", IntegerType), - StructField("dep", StringType) - )) - spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) - .writeTo(tableNameAsString).append() + Seq(true, false).foreach { withSchemaEvolution => + val sourceTable = "cat.ns1.source_table" + withTable(sourceTable) { + sql(s"CREATE TABLE $tableNameAsString (pk INT NOT NULL, salary INT, dep STRING)") - if (!schemaEvolutionEnabled) { - sql(s"""ALTER TABLE $tableNameAsString SET TBLPROPERTIES - | ('auto-schema-evolution' = 'false')""".stripMargin) - } + val targetData = Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"), + Row(3, 300, "hr"), + Row(4, 400, "marketing"), + Row(5, 500, "executive") + ) + val targetSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("salary", IntegerType), + StructField("dep", StringType) + )) + spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) + .writeTo(tableNameAsString).append() val sourceIdent = Identifier.of(Array("ns1"), "source_table") val columns = Array( @@ -5320,72 +5337,66 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase sql(s"INSERT INTO $sourceTable VALUES (4, 150, 'dummy', true)," + s"(5, 250, 'dummy', true), (6, 350, 'dummy', false)") - val mergeBuilder = spark.table(sourceTable) 
- .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk")) - .whenMatched() - .update(Map("dep" -> lit("software"), "active" -> col("source_table.active"))) - .whenNotMatched() - .insert(Map("pk" -> col("source_table.pk"), "salary" -> lit(0), - "dep" -> col("source_table.dep"), "active" -> col("source_table.active"))) + val mergeBuilder = spark.table(sourceTable) + .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk")) + .whenMatched() + .update(Map("dep" -> lit("software"), "active" -> col("source_table.active"))) + .whenNotMatched() + .insert(Map("pk" -> col("source_table.pk"), "salary" -> lit(0), + "dep" -> col("source_table.dep"), "active" -> col("source_table.active"))) - if (withSchemaEvolution && schemaEvolutionEnabled) { - mergeBuilder.withSchemaEvolution().merge() - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(1, 100, "hr", null), - Row(2, 200, "software", null), - Row(3, 300, "hr", null), - Row(4, 400, "software", true), - Row(5, 500, "software", true), - Row(6, 0, "dummy", false))) - } else { - val e = intercept[org.apache.spark.sql.AnalysisException] { - mergeBuilder.merge() - } - assert(e.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION") - assert(e.getMessage.contains("A column, variable, or function parameter with name " + - "`active` cannot be resolved")) + if (withSchemaEvolution) { + mergeBuilder.withSchemaEvolution().merge() + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr", null), + Row(2, 200, "software", null), + Row(3, 300, "hr", null), + Row(4, 400, "software", true), + Row(5, 500, "software", true), + Row(6, 0, "dummy", false))) + } else { + val e = intercept[org.apache.spark.sql.AnalysisException] { + mergeBuilder.merge() } - - sql(s"DROP TABLE $tableNameAsString") + assert(e.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + assert(e.getMessage.contains("A column, variable, or function parameter with name " + + "`active` cannot be resolved")) } + + sql(s"DROP TABLE $tableNameAsString") + } } } test("merge schema evolution add column with nested struct and set explicit columns " + "using dataframe API") { - Seq((true, true), (false, true), (true, false)).foreach { - case (withSchemaEvolution, schemaEvolutionEnabled) => - val sourceTable = "cat.ns1.source_table" - withTable(sourceTable) { - sql( - s"""CREATE TABLE $tableNameAsString ( - |pk INT NOT NULL, - |s STRUCT, m: MAP>>, - |dep STRING)""".stripMargin) - - val targetData = Seq( - Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr") - ) - val targetSchema = StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StructType(Seq( - StructField("a", ArrayType(IntegerType)), - StructField("m", MapType(StringType, StringType)) - ))) - ))), - StructField("dep", StringType) - )) - spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) - .writeTo(tableNameAsString).append() + Seq(true, false).foreach { withSchemaEvolution => + val sourceTable = "cat.ns1.source_table" + withTable(sourceTable) { + sql( + s"""CREATE TABLE $tableNameAsString ( + |pk INT NOT NULL, + |s STRUCT, m: MAP>>, + |dep STRING)""".stripMargin) - if (!schemaEvolutionEnabled) { - sql(s"""ALTER TABLE $tableNameAsString SET TBLPROPERTIES - | ('auto-schema-evolution' = 'false')""".stripMargin) - } + val targetData = Seq( + Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr") + ) + val targetSchema = 
StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", ArrayType(IntegerType)), + StructField("m", MapType(StringType, StringType)) + ))) + ))), + StructField("dep", StringType) + )) + spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) + .writeTo(tableNameAsString).append() val sourceIdent = Identifier.of(Array("ns1"), "source_table") val columns = Array( @@ -5426,79 +5437,167 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp") - val mergeBuilder = spark.table(sourceTable) - .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk")) - .whenMatched() - .update(Map( - "s.c1" -> lit(-1), - "s.c2.m" -> map(lit("k"), lit("v")), - "s.c2.a" -> array(lit(-1)), - "s.c2.c3" -> col("source_table.s.c2.c3"))) - .whenNotMatched() - .insert(Map( - "pk" -> col("source_table.pk"), - "s" -> struct( - col("source_table.s.c1").as("c1"), - struct( - col("source_table.s.c2.a").as("a"), - map(lit("g"), lit("h")).as("m"), - lit(true).as("c3") - ).as("c2") - ), - "dep" -> col("source_table.dep"))) + val mergeBuilder = spark.table(sourceTable) + .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk")) + .whenMatched() + .update(Map( + "s.c1" -> lit(-1), + "s.c2.m" -> map(lit("k"), lit("v")), + "s.c2.a" -> array(lit(-1)), + "s.c2.c3" -> col("source_table.s.c2.c3"))) + .whenNotMatched() + .insert(Map( + "pk" -> col("source_table.pk"), + "s" -> struct( + col("source_table.s.c1").as("c1"), + struct( + col("source_table.s.c2.a").as("a"), + map(lit("g"), lit("h")).as("m"), + lit(true).as("c3") + ).as("c2") + ), + "dep" -> col("source_table.dep"))) - if (withSchemaEvolution && schemaEvolutionEnabled) { - mergeBuilder.withSchemaEvolution().merge() - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq(Row(1, Row(-1, Row(Seq(-1), Map("k" -> "v"), false)), "hr"), - Row(2, Row(20, Row(Seq(4, 5), Map("g" -> "h"), true)), "engineering"))) - } else { - val exception = intercept[org.apache.spark.sql.AnalysisException] { - mergeBuilder.merge() - } - assert(exception.errorClass.get == "FIELD_NOT_FOUND") - assert(exception.getMessage.contains("No such struct field `c3` in `a`, `m`. ")) + if (withSchemaEvolution) { + mergeBuilder.withSchemaEvolution().merge() + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(1, Row(-1, Row(Seq(-1), Map("k" -> "v"), false)), "hr"), + Row(2, Row(20, Row(Seq(4, 5), Map("g" -> "h"), true)), "engineering"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + mergeBuilder.merge() } - - sql(s"DROP TABLE $tableNameAsString") + assert(exception.errorClass.get == "FIELD_NOT_FOUND") + assert(exception.getMessage.contains("No such struct field `c3` in `a`, `m`. 
")) } + + sql(s"DROP TABLE $tableNameAsString") + } } } test("merge schema evolution add column with nested struct and set all columns " + "using dataframe API") { - Seq((true, true), (false, true), (true, false)).foreach { - case (withSchemaEvolution, schemaEvolutionEnabled) => - val sourceTable = "cat.ns1.source_table" - withTable(sourceTable) { - sql( - s"""CREATE TABLE $tableNameAsString ( - |pk INT NOT NULL, - |s STRUCT, m: MAP>>, - |dep STRING)""".stripMargin) + Seq(true, false).foreach { withSchemaEvolution => + val sourceTable = "cat.ns1.source_table" + withTable(sourceTable) { + sql( + s"""CREATE TABLE $tableNameAsString ( + |pk INT NOT NULL, + |s STRUCT, m: MAP>>, + |dep STRING)""".stripMargin) - val targetData = Seq( - Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr") + val targetData = Seq( + Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr") + ) + val targetSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", ArrayType(IntegerType)), + StructField("m", MapType(StringType, StringType)) + ))) + ))), + StructField("dep", StringType) + )) + spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) + .writeTo(tableNameAsString).append() + + val sourceIdent = Identifier.of(Array("ns1"), "source_table") + val columns = Array( + Column.create("pk", IntegerType, false), + Column.create("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", ArrayType(IntegerType)), + StructField("m", MapType(StringType, StringType)), + StructField("c3", BooleanType) // new column + ))) + ))), + Column.create("dep", StringType)) + val tableInfo = new TableInfo.Builder() + .withColumns(columns) + .withProperties(extraTableProps) + .build() + catalog.createTable(sourceIdent, tableInfo) + + val data = Seq( + Row(1, Row(10, Row(Array(3, 4), Map("c" -> "d"), false)), "sales"), + Row(2, Row(20, Row(Array(4, 5), Map("e" -> "f"), true)), "engineering") ) - val targetSchema = StructType(Seq( + val sourceTableSchema = StructType(Seq( StructField("pk", IntegerType, nullable = false), StructField("s", StructType(Seq( StructField("c1", IntegerType), StructField("c2", StructType(Seq( StructField("a", ArrayType(IntegerType)), - StructField("m", MapType(StringType, StringType)) + StructField("m", MapType(StringType, StringType)), + StructField("c3", BooleanType) ))) ))), StructField("dep", StringType) )) - spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) - .writeTo(tableNameAsString).append() + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source_temp") - if (!schemaEvolutionEnabled) { - sql(s"""ALTER TABLE $tableNameAsString SET TBLPROPERTIES - | ('auto-schema-evolution' = 'false')""".stripMargin) + sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp") + + val mergeBuilder = spark.table(sourceTable) + .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk")) + .whenMatched() + .updateAll() + .whenNotMatched() + .insertAll() + + if (withSchemaEvolution) { + mergeBuilder.withSchemaEvolution().merge() + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(1, Row(10, Row(Seq(3, 4), Map("c" -> "d"), false)), "sales"), + Row(2, Row(20, Row(Seq(4, 5), Map("e" -> "f"), true)), "engineering"))) + } else { + val exception = 
intercept[org.apache.spark.sql.AnalysisException] { + mergeBuilder.merge() } + assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") + assert(exception.getMessage.contains( + "Cannot write extra fields `c3` to the struct `s`.`c2`")) + } + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + + test("merge schema evolution replace column with nested struct and " + + "set explicit columns using dataframe API") { + Seq(true, false).foreach { withSchemaEvolution => + val sourceTable = "cat.ns1.source_table" + withTable(sourceTable) { + sql( + s"""CREATE TABLE $tableNameAsString ( + |pk INT NOT NULL, + |s STRUCT, m: MAP>>, + |dep STRING)""".stripMargin) + + val targetData = Seq( + Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr") + ) + val targetSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", ArrayType(IntegerType)), + StructField("m", MapType(StringType, StringType)) + ))) + ))), + StructField("dep", StringType) + )) + spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) + .writeTo(tableNameAsString).append() val sourceIdent = Identifier.of(Array("ns1"), "source_table") val columns = Array( @@ -5506,7 +5605,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase Column.create("s", StructType(Seq( StructField("c1", IntegerType), StructField("c2", StructType(Seq( - StructField("a", ArrayType(IntegerType)), + // removed column 'a' StructField("m", MapType(StringType, StringType)), StructField("c3", BooleanType) // new column ))) @@ -5519,15 +5618,14 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase catalog.createTable(sourceIdent, tableInfo) val data = Seq( - Row(1, Row(10, Row(Array(3, 4), Map("c" -> "d"), false)), "sales"), - Row(2, Row(20, Row(Array(4, 5), Map("e" -> "f"), true)), "engineering") + Row(1, Row(10, Row(Map("c" -> "d"), false)), "sales"), + Row(2, Row(20, Row(Map("e" -> "f"), true)), "engineering") ) val sourceTableSchema = StructType(Seq( StructField("pk", IntegerType, nullable = false), StructField("s", StructType(Seq( StructField("c1", IntegerType), StructField("c2", StructType(Seq( - StructField("a", ArrayType(IntegerType)), StructField("m", MapType(StringType, StringType)), StructField("c3", BooleanType) ))) @@ -5539,49 +5637,60 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp") - val mergeBuilder = spark.table(sourceTable) - .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk")) - .whenMatched() - .updateAll() - .whenNotMatched() - .insertAll() + val mergeBuilder = spark.table(sourceTable) + .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk")) + .whenMatched() + .update(Map( + "s.c1" -> lit(-1), + "s.c2.m" -> map(lit("k"), lit("v")), + "s.c2.a" -> array(lit(-1)), + "s.c2.c3" -> col("source_table.s.c2.c3"))) + .whenNotMatched() + .insert(Map( + "pk" -> col("source_table.pk"), + "s" -> struct( + col("source_table.s.c1").as("c1"), + struct( + array(lit(-2)).as("a"), + map(lit("g"), lit("h")).as("m"), + lit(true).as("c3") + ).as("c2") + ), + "dep" -> col("source_table.dep"))) - if (withSchemaEvolution && schemaEvolutionEnabled) { - mergeBuilder.withSchemaEvolution().merge() - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq(Row(1, Row(10, Row(Seq(3, 4), 
Map("c" -> "d"), false)), "sales"), - Row(2, Row(20, Row(Seq(4, 5), Map("e" -> "f"), true)), "engineering"))) - } else { - val exception = intercept[org.apache.spark.sql.AnalysisException] { - mergeBuilder.merge() - } - assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") - assert(exception.getMessage.contains( - "Cannot write extra fields `c3` to the struct `s`.`c2`")) + if (withSchemaEvolution) { + mergeBuilder.withSchemaEvolution().merge() + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(1, Row(-1, Row(Seq(-1), Map("k" -> "v"), false)), "hr"), + Row(2, Row(20, Row(Seq(-2), Map("g" -> "h"), true)), "engineering"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + mergeBuilder.merge() } - - sql(s"DROP TABLE $tableNameAsString") + assert(exception.errorClass.get == "FIELD_NOT_FOUND") + assert(exception.getMessage.contains("No such struct field `c3` in `a`, `m`. ")) } + + sql(s"DROP TABLE $tableNameAsString") + } } } - test("merge schema evolution replace column with nested struct and " + - "set explicit columns using dataframe API") { - Seq((true, true), (false, true), (true, false)).foreach { - case (withSchemaEvolution, schemaEvolutionEnabled) => - val sourceTable = "cat.ns1.source_table" - withTable(sourceTable) { - sql( - s"""CREATE TABLE $tableNameAsString ( - |pk INT NOT NULL, - |s STRUCT, m: MAP>>, - |dep STRING)""".stripMargin) + test("merge schema evolution replace column with nested struct and set all columns " + + "using dataframe API") { + Seq(true, false).foreach { withSchemaEvolution => + val sourceTable = "cat.ns1.source_table" + withTable(sourceTable) { + sql( + s"""CREATE TABLE $tableNameAsString ( + |pk INT NOT NULL, + |s STRUCT, m: MAP>>, + |dep STRING) + |PARTITIONED BY (dep) + |""".stripMargin) - val targetData = Seq( - Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr") - ) - val targetSchema = StructType(Seq( + val tableSchema = StructType(Seq( StructField("pk", IntegerType, nullable = false), StructField("s", StructType(Seq( StructField("c1", IntegerType), @@ -5592,13 +5701,11 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase ))), StructField("dep", StringType) )) - spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) - .writeTo(tableNameAsString).append() - - if (!schemaEvolutionEnabled) { - sql(s"""ALTER TABLE $tableNameAsString SET TBLPROPERTIES - | ('auto-schema-evolution' = 'false')""".stripMargin) - } + val targetData = Seq( + Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr") + ) + spark.createDataFrame(spark.sparkContext.parallelize(targetData), tableSchema) + .coalesce(1).writeTo(tableNameAsString).append() val sourceIdent = Identifier.of(Array("ns1"), "source_table") val columns = Array( @@ -5606,7 +5713,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase Column.create("s", StructType(Seq( StructField("c1", IntegerType), StructField("c2", StructType(Seq( - // removed column 'a' + // missing column 'a' StructField("m", MapType(StringType, StringType)), StructField("c3", BooleanType) // new column ))) @@ -5618,7 +5725,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase .build() catalog.createTable(sourceIdent, tableInfo) - val data = Seq( + val sourceData = Seq( Row(1, Row(10, Row(Map("c" -> "d"), false)), "sales"), Row(2, Row(20, Row(Map("e" -> "f"), true)), "engineering") ) @@ -5633,306 +5740,156 @@ abstract class MergeIntoTableSuiteBase extends 
@@ -5633,306 +5740,156 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
         ))),
         StructField("dep", StringType)
       ))
-      spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema)
+      spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema)
         .createOrReplaceTempView("source_temp")

       sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp")

-      val mergeBuilder = spark.table(sourceTable)
-        .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk"))
-        .whenMatched()
-        .update(Map(
-          "s.c1" -> lit(-1),
-          "s.c2.m" -> map(lit("k"), lit("v")),
-          "s.c2.a" -> array(lit(-1)),
-          "s.c2.c3" -> col("source_table.s.c2.c3")))
-        .whenNotMatched()
-        .insert(Map(
-          "pk" -> col("source_table.pk"),
-          "s" -> struct(
-            col("source_table.s.c1").as("c1"),
-            struct(
-              array(lit(-2)).as("a"),
-              map(lit("g"), lit("h")).as("m"),
-              lit(true).as("c3")
-            ).as("c2")
-          ),
-          "dep" -> col("source_table.dep")))
+      val mergeBuilder = spark.table(sourceTable)
+        .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk"))
+        .whenMatched()
+        .updateAll()
+        .whenNotMatched()
+        .insertAll()

-      if (withSchemaEvolution && schemaEvolutionEnabled) {
-        mergeBuilder.withSchemaEvolution().merge()
-        checkAnswer(
-          sql(s"SELECT * FROM $tableNameAsString"),
-          Seq(Row(1, Row(-1, Row(Seq(-1), Map("k" -> "v"), false)), "hr"),
-            Row(2, Row(20, Row(Seq(-2), Map("g" -> "h"), true)), "engineering")))
-      } else {
-        val exception = intercept[org.apache.spark.sql.AnalysisException] {
-          mergeBuilder.merge()
-        }
-        assert(exception.errorClass.get == "FIELD_NOT_FOUND")
-        assert(exception.getMessage.contains("No such struct field `c3` in `a`, `m`. "))
+      if (withSchemaEvolution) {
+        mergeBuilder.withSchemaEvolution().merge()
+        checkAnswer(
+          sql(s"SELECT * FROM $tableNameAsString"),
+          Seq(
+            Row(1, Row(10, Row(Seq(1, 2), Map("c" -> "d"), false)), "sales"),
+            Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering")))
+      } else {
+        val exception = intercept[org.apache.spark.sql.AnalysisException] {
+          mergeBuilder.merge()
         }
-
-      sql(s"DROP TABLE $tableNameAsString")
+        assert(exception.errorClass.get ==
+          "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS")
+        assert(exception.getMessage.contains(
+          "Cannot write extra fields `c3` to the struct `s`.`c2`"))
       }
+
+      sql(s"DROP TABLE $tableNameAsString")
+    }
   }
 }

-  test("merge schema evolution replace column with nested struct and set all columns " +
-    "using dataframe API") {
-    Seq((true, true), (false, true), (true, false)).foreach {
-      case (withSchemaEvolution, schemaEvolutionEnabled) =>
-        Seq(true, false).foreach { updateByFields =>
-          withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD.key ->
-            updateByFields.toString) {
-            val sourceTable = "cat.ns1.source_table"
-            withTable(sourceTable) {
-              // Create target table
-              sql(
-                s"""CREATE TABLE $tableNameAsString (
-                   |pk INT NOT NULL,
-                   |s STRUCT<c1: INT, c2: STRUCT<a: ARRAY<INT>, m: MAP<STRING, STRING>>>,
-                   |dep STRING)
-                   |PARTITIONED BY (dep)
-                   |""".stripMargin)
-
-              if (!schemaEvolutionEnabled) {
-                sql(s"""ALTER TABLE $tableNameAsString SET TBLPROPERTIES
-                   | ('auto-schema-evolution' = 'false')""".stripMargin)
-              }
-
-              // Insert data into target table
-              val tableSchema = StructType(Seq(
-                StructField("pk", IntegerType, nullable = false),
-                StructField("s", StructType(Seq(
-                  StructField("c1", IntegerType),
-                  StructField("c2", StructType(Seq(
-                    StructField("a", ArrayType(IntegerType)),
-                    StructField("m", MapType(StringType, StringType))
-                  )))
-                ))),
-                StructField("dep", StringType)
-              ))
-              val targetData = Seq(
-                Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr")
-              )
-              spark.createDataFrame(spark.sparkContext.parallelize(targetData), tableSchema)
-                .coalesce(1).writeTo(tableNameAsString).append()
-
-              // Create source table
-              val sourceIdent = Identifier.of(Array("ns1"), "source_table")
-              val columns = Array(
-                Column.create("pk", IntegerType, false),
-                Column.create("s", StructType(Seq(
-                  StructField("c1", IntegerType),
-                  StructField("c2", StructType(Seq(
-                    // missing column 'a'
-                    StructField("m", MapType(StringType, StringType)),
-                    StructField("c3", BooleanType) // new column
-                  )))
-                ))),
-                Column.create("dep", StringType))
-              val tableInfo = new TableInfo.Builder()
-                .withColumns(columns)
-                .withProperties(extraTableProps)
-                .build()
-              catalog.createTable(sourceIdent, tableInfo)
-              val sourceData = Seq(
-                Row(1, Row(10, Row(Map("c" -> "d"), false)), "sales"),
-                Row(2, Row(20, Row(Map("e" -> "f"), true)), "engineering")
-              )
-              val sourceTableSchema = StructType(Seq(
-                StructField("pk", IntegerType, nullable = false),
-                StructField("s", StructType(Seq(
-                  StructField("c1", IntegerType),
-                  StructField("c2", StructType(Seq(
-                    StructField("m", MapType(StringType, StringType)),
-                    StructField("c3", BooleanType)
-                  )))
-                ))),
-                StructField("dep", StringType)
-              ))
-              spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema)
-                .createOrReplaceTempView("source_temp")
-
-              sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp")
-
-              val mergeBuilder = spark.table(sourceTable)
-                .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk"))
-                .whenMatched()
-                .updateAll()
-                .whenNotMatched()
-                .insertAll()
-
-              if (withSchemaEvolution && schemaEvolutionEnabled) {
-                mergeBuilder.withSchemaEvolution().merge()
-                if (updateByFields) {
-                  checkAnswer(
-                    sql(s"SELECT * FROM $tableNameAsString"),
-                    Seq(
-                      Row(1, Row(10, Row(Seq(1, 2), Map("c" -> "d"), false)), "sales"),
-                      Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering")))
-                } else {
-                  checkAnswer(
-                    sql(s"SELECT * FROM $tableNameAsString"),
-                    Seq(
-                      Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "sales"),
-                      Row(2, Row(20, Row(null, Map("e" -> "f"), true)),
-                        "engineering")))
-                }
-              } else {
-                val exception = intercept[org.apache.spark.sql.AnalysisException] {
-                  mergeBuilder.merge()
-                }
-                assert(exception.errorClass.get ==
-                  "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS")
-                assert(exception.getMessage.contains(
-                  "Cannot write extra fields `c3` to the struct `s`.`c2`"))
-              }
-
-              sql(s"DROP TABLE $tableNameAsString")
-            }
-          }
-        }
-    }
-  }

   test("merge schema evolution replace column with nested struct and " +
     "update top level struct using dataframe API") {
-    Seq((true, true), (false, true), (true, false)).foreach {
-      case (withSchemaEvolution, schemaEvolutionEnabled) =>
-        Seq(true, false).foreach { updateByFields =>
-          withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD.key ->
-            updateByFields.toString) {
-            val sourceTable = "cat.ns1.source_table"
-            withTable(sourceTable) {
-              // Create target table
-              sql(
-                s"""CREATE TABLE $tableNameAsString (
-                   |pk INT NOT NULL,
-                   |s STRUCT<c1: INT, c2: STRUCT<a: ARRAY<INT>, m: MAP<STRING, STRING>>>,
-                   |dep STRING)
-                   |PARTITIONED BY (dep)
-                   |""".stripMargin)
-
-              if (!schemaEvolutionEnabled) {
-                sql(s"""ALTER TABLE $tableNameAsString SET TBLPROPERTIES
-                   | ('auto-schema-evolution' = 'false')""".stripMargin)
-              }
-
-              // Insert data into target table
-              val tableSchema = StructType(Seq(
-                StructField("pk", IntegerType, nullable = false),
-                StructField("s", StructType(Seq(
-                  StructField("c1", IntegerType),
-                  StructField("c2", StructType(Seq(
-                    StructField("a", ArrayType(IntegerType)),
-                    StructField("m", MapType(StringType, StringType))
-                  )))
-                ))),
-                StructField("dep", StringType)
-              ))
-              val targetData = Seq(
-                Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr")
-              )
-              spark.createDataFrame(spark.sparkContext.parallelize(targetData), tableSchema)
-                .coalesce(1).writeTo(tableNameAsString).append()
+    Seq(true, false).foreach { withSchemaEvolution =>
+      val sourceTable = "cat.ns1.source_table"
+      withTable(sourceTable) {
+        sql(
+          s"""CREATE TABLE $tableNameAsString (
+             |pk INT NOT NULL,
+             |s STRUCT<c1: INT, c2: STRUCT<a: ARRAY<INT>, m: MAP<STRING, STRING>>>,
+             |dep STRING)
+             |PARTITIONED BY (dep)
+             |""".stripMargin)

-              // Create source table
-              val sourceIdent = Identifier.of(Array("ns1"), "source_table")
-              val columns = Array(
-                Column.create("pk", IntegerType, false),
-                Column.create("s", StructType(Seq(
-                  StructField("c1", IntegerType),
-                  StructField("c2", StructType(Seq(
-                    // missing column 'a'
-                    StructField("m", MapType(StringType, StringType)),
-                    StructField("c3", BooleanType) // new column
-                  )))
-                ))),
-                Column.create("dep", StringType))
-              val tableInfo = new TableInfo.Builder()
-                .withColumns(columns)
-                .withProperties(extraTableProps)
-                .build()
-              catalog.createTable(sourceIdent, tableInfo)
+        val tableSchema = StructType(Seq(
+          StructField("pk", IntegerType, nullable = false),
+          StructField("s", StructType(Seq(
+            StructField("c1", IntegerType),
+            StructField("c2", StructType(Seq(
+              StructField("a", ArrayType(IntegerType)),
+              StructField("m", MapType(StringType, StringType))
+            )))
+          ))),
+          StructField("dep", StringType)
+        ))
+        val targetData = Seq(
+          Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr")
+        )
+        spark.createDataFrame(spark.sparkContext.parallelize(targetData), tableSchema)
+          .coalesce(1).writeTo(tableNameAsString).append()

-              val sourceData = Seq(
-                Row(1, Row(10, Row(Map("c" -> "d"), false)), "sales"),
-                Row(2, Row(20, Row(Map("e" -> "f"), true)), "engineering")
-              )
-              val sourceTableSchema = StructType(Seq(
-                StructField("pk", IntegerType, nullable = false),
-                StructField("s", StructType(Seq(
-                  StructField("c1", IntegerType),
-                  StructField("c2", StructType(Seq(
-                    StructField("m", MapType(StringType, StringType)),
-                    StructField("c3", BooleanType)
-                  )))
-                ))),
-                StructField("dep", StringType)
-              ))
+        // Create source table
+        val sourceIdent = Identifier.of(Array("ns1"), "source_table")
+        val columns = Array(
+          Column.create("pk", IntegerType, false),
+          Column.create("s", StructType(Seq(
+            StructField("c1", IntegerType),
+            StructField("c2", StructType(Seq(
+              // missing column 'a'
+              StructField("m", MapType(StringType, StringType)),
+              StructField("c3", BooleanType) // new column
+            )))
+          ))),
+          Column.create("dep", StringType))
+        val tableInfo = new TableInfo.Builder()
+          .withColumns(columns)
+          .withProperties(extraTableProps)
+          .build()
+        catalog.createTable(sourceIdent, tableInfo)
+
+        val sourceData = Seq(
+          Row(1, Row(10, Row(Map("c" -> "d"), false)), "sales"),
+          Row(2, Row(20, Row(Map("e" -> "f"), true)), "engineering")
+        )
+        val sourceTableSchema = StructType(Seq(
+          StructField("pk", IntegerType, nullable = false),
+          StructField("s", StructType(Seq(
+            StructField("c1", IntegerType),
+            StructField("c2", StructType(Seq(
+              StructField("m", MapType(StringType, StringType)),
+              StructField("c3", BooleanType)
+            )))
+          ))),
+          StructField("dep", StringType)
+        ))
+        spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema)
+          .createOrReplaceTempView("source_temp")

-              spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema)
-                .createOrReplaceTempView("source_temp")
-
-              sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp")
+        sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp")

-              val mergeBuilder = spark.table(sourceTable)
-                .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk"))
-                .whenMatched()
-                .update(Map("s" -> col("source_table.s")))
-                .whenNotMatched()
-                .insertAll()
+        val mergeBuilder = spark.table(sourceTable)
+          .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk"))
+          .whenMatched()
+          .update(Map("s" -> col("source_table.s")))
+          .whenNotMatched()
+          .insertAll()

-              if (withSchemaEvolution && schemaEvolutionEnabled) {
-                mergeBuilder.withSchemaEvolution().merge()
-                checkAnswer(
-                  sql(s"SELECT * FROM $tableNameAsString"),
-                  Seq(
-                    Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "hr"),
-                    Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering")))
-              } else {
-                val exception = intercept[org.apache.spark.sql.AnalysisException] {
-                  mergeBuilder.merge()
-                }
-                assert(exception.errorClass.get ==
-                  "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS")
-                assert(exception.getMessage.contains(
-                  "Cannot write extra fields `c3` to the struct `s`.`c2`"))
-              }
-
-              sql(s"DROP TABLE $tableNameAsString")
-            }
-          }
-        }
-    }
-  }
+        if (withSchemaEvolution) {
+          mergeBuilder.withSchemaEvolution().merge()
+          checkAnswer(
+            sql(s"SELECT * FROM $tableNameAsString"),
+            Seq(
+              Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "hr"),
+              Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering")))
+        } else {
+          val exception = intercept[org.apache.spark.sql.AnalysisException] {
+            mergeBuilder.merge()
+          }
+          assert(exception.errorClass.get ==
+            "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS")
+          assert(exception.getMessage.contains(
+            "Cannot write extra fields `c3` to the struct `s`.`c2`"))
+        }
+
+        sql(s"DROP TABLE $tableNameAsString")
+      }
+    }
+  }

   test("merge schema evolution should not evolve referencing new column " +
     "via transform using dataframe API") {
-    Seq((true, true), (false, true), (true, false)).foreach {
-      case (withSchemaEvolution, schemaEvolutionEnabled) =>
-        val sourceTable = "cat.ns1.source_table"
-        withTable(sourceTable) {
-          sql(s"CREATE TABLE $tableNameAsString (pk INT NOT NULL, salary INT, dep STRING)")
-
-          val targetData = Seq(
-            Row(1, 100, "hr"),
-            Row(2, 200, "software")
-          )
-          val targetSchema = StructType(Seq(
-            StructField("pk", IntegerType, nullable = false),
-            StructField("salary", IntegerType),
-            StructField("dep", StringType)
-          ))
-          spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema)
-            .writeTo(tableNameAsString).append()
-
-          if (!schemaEvolutionEnabled) {
-            sql(s"""ALTER TABLE $tableNameAsString SET TBLPROPERTIES
-               | ('auto-schema-evolution' = 'false')""".stripMargin)
-          }
+    Seq(true, false).foreach { withSchemaEvolution =>
+      val sourceTable = "cat.ns1.source_table"
+      withTable(sourceTable) {
+        sql(s"CREATE TABLE $tableNameAsString (pk INT NOT NULL, salary INT, dep STRING)")
+
+        val targetData = Seq(
+          Row(1, 100, "hr"),
+          Row(2, 200, "software")
+        )
+        val targetSchema = StructType(Seq(
+          StructField("pk", IntegerType, nullable = false),
+          StructField("salary", IntegerType),
+          StructField("dep", StringType)
+        ))
+        spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema)
+          .writeTo(tableNameAsString).append()

         val sourceIdent = Identifier.of(Array("ns1"), "source_table")
         val columns = Array(
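Context for the hunk that follows (a sketch with hypothetical table names, not part of the patch): a transformed reference to a source-only column is not a direct column reference, so evolution never adds the column and analysis fails even when evolution is requested:

    // Fails with UNRESOLVED_COLUMN.WITH_SUGGESTION: substring(...) is a
    // transform of `extra`, which exists only in the source table, so the
    // target never gains an `extra` column to assign to.
    spark.table("cat.ns1.source_table")
      .mergeInto("cat.ns1.target", $"source_table.pk" === $"target.pk")
      .withSchemaEvolution()
      .whenMatched()
      .update(Map("extra" -> substring(col("source_table.extra"), 1, 2)))
      .merge()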
@@ -5946,25 +5903,32 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
           .build()
       catalog.createTable(sourceIdent, tableInfo)

-        sql(s"INSERT INTO $sourceTable VALUES (2, 150, 'dummy', 'blah')," +
-          s"(3, 250, 'dummy', 'blah')")
+      sql(s"INSERT INTO $sourceTable VALUES (2, 150, 'dummy', 'blah')," +
+        s"(3, 250, 'dummy', 'blah')")

-        val e = intercept[org.apache.spark.sql.AnalysisException] {
-          spark.table(sourceTable)
-            .mergeInto(tableNameAsString,
-              $"source_table.pk" === col(tableNameAsString + ".pk"))
-            .withSchemaEvolution()
-            .whenMatched()
-            .update(Map("extra" -> substring(col("source_table.extra"), 1, 2)))
-            .merge()
+      val e = intercept[org.apache.spark.sql.AnalysisException] {
+        val builder = spark.table(sourceTable)
+          .mergeInto(tableNameAsString,
+            $"source_table.pk" === col(tableNameAsString + ".pk"))
+
+        val builderWithEvolution = if (withSchemaEvolution) {
+          builder.withSchemaEvolution()
+        } else {
+          builder
         }
-        assert(e.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION")
-        assert(e.getMessage.contains(
-          "A column, variable, or function parameter with name " +
-          "`extra` cannot be resolved"))

-        sql(s"DROP TABLE $tableNameAsString")
+        builderWithEvolution
+          .whenMatched()
+          .update(Map("extra" -> substring(col("source_table.extra"), 1, 2)))
+          .merge()
       }
+      assert(e.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION")
+      assert(e.getMessage.contains(
+        "A column, variable, or function parameter with name " +
+        "`extra` cannot be resolved"))
+
+      sql(s"DROP TABLE $tableNameAsString")
+    }
   }
 }

@@ -6043,114 +6007,102 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
 }

   test("merge with null struct with missing nested field using dataframe API") {
-    Seq(true, false).foreach { updateByFields =>
-      Seq(true, false).foreach { coerceNestedTypes =>
-        withSQLConf(
-          SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD.key ->
-            updateByFields.toString,
-          SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
-            coerceNestedTypes.toString) {
-          val sourceTable = "cat.ns1.source_table"
-          withTable(sourceTable) {
-            // Target table has nested struct with fields c1 and c2
-            sql(
-              s"""CREATE TABLE $tableNameAsString (
-                 |pk INT NOT NULL,
-                 |s STRUCT<c1: INT, c2: STRUCT<a: INT, b: STRING>>,
-                 |dep STRING)""".stripMargin)
-
-            val targetData = Seq(
-              Row(0, Row(1, Row(10, "x")), "sales"),
-              Row(1, Row(2, Row(20, "y")), "hr")
-            )
-            val targetSchema = StructType(Seq(
-              StructField("pk", IntegerType, nullable = false),
-              StructField("s", StructType(Seq(
-                StructField("c1", IntegerType),
-                StructField("c2", StructType(Seq(
-                  StructField("a", IntegerType),
-                  StructField("b", StringType)
-                )))
-              ))),
-              StructField("dep", StringType)
-            ))
-            spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema)
-              .writeTo(tableNameAsString).append()
-
-            // Create source table with missing nested field 'b'
-            val sourceIdent = Identifier.of(Array("ns1"), "source_table")
-            val columns = Array(
-              Column.create("pk", IntegerType, false),
-              Column.create("s", StructType(Seq(
-                StructField("c1", IntegerType),
-                StructField("c2", StructType(Seq(
-                  StructField("a", IntegerType)
-                  // missing field 'b'
-                )))
-              ))),
-              Column.create("dep", StringType))
-            val tableInfo = new TableInfo.Builder()
-              .withColumns(columns)
-              .withProperties(extraTableProps)
-              .build()
-            catalog.createTable(sourceIdent, tableInfo)
-
-            // Source table has null for the nested struct
-            val data = Seq(
-              Row(1, null, "engineering"),
-              Row(2, null, "finance")
-            )
-            val sourceTableSchema = StructType(Seq(
-              StructField("pk", IntegerType),
-              StructField("s", StructType(Seq(
-                StructField("c1", IntegerType),
-                StructField("c2", StructType(Seq(
-                  StructField("a", IntegerType)
-                )))
-              ))),
-              StructField("dep", StringType)
-            ))
-            spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema)
-              .createOrReplaceTempView("source_temp")
+    Seq(true, false).foreach { coerceNestedTypes =>
+      withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
+        coerceNestedTypes.toString) {
+        val sourceTable = "cat.ns1.source_table"
+        withTable(sourceTable) {
+          // Target table has nested struct with fields c1 and c2
+          sql(
+            s"""CREATE TABLE $tableNameAsString (
+               |pk INT NOT NULL,
+               |s STRUCT<c1: INT, c2: STRUCT<a: INT, b: STRING>>,
+               |dep STRING)""".stripMargin)

-            sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp")
+          val targetData = Seq(
+            Row(0, Row(1, Row(10, "x")), "sales"),
+            Row(1, Row(2, Row(20, "y")), "hr")
+          )
+          val targetSchema = StructType(Seq(
+            StructField("pk", IntegerType, nullable = false),
+            StructField("s", StructType(Seq(
+              StructField("c1", IntegerType),
+              StructField("c2", StructType(Seq(
+                StructField("a", IntegerType),
+                StructField("b", StringType)
+              )))
+            ))),
+            StructField("dep", StringType)
+          ))
+          spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema)
+            .writeTo(tableNameAsString).append()

-            if (coerceNestedTypes) {
-              spark.table(sourceTable)
-                .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk"))
-                .whenMatched()
-                .updateAll()
-                .whenNotMatched()
-                .insertAll()
-                .merge()
+          // Create source table with missing nested field 'b'
+          val sourceIdent = Identifier.of(Array("ns1"), "source_table")
+          val columns = Array(
+            Column.create("pk", IntegerType, false),
+            Column.create("s", StructType(Seq(
+              StructField("c1", IntegerType),
+              StructField("c2", StructType(Seq(
+                StructField("a", IntegerType)
+                // missing field 'b'
+              )))
+            ))),
+            Column.create("dep", StringType))
+          val tableInfo = new TableInfo.Builder()
+            .withColumns(columns)
+            .withProperties(extraTableProps)
+            .build()
+          catalog.createTable(sourceIdent, tableInfo)

-              checkAnswer(
-                sql(s"SELECT * FROM $tableNameAsString"),
-                Seq(
-                  Row(0, Row(1, Row(10, "x")), "sales"),
-                  Row(1, null, "engineering"),
-                  Row(2, null, "finance")))
-            } else {
-              // Without coercion, the merge should fail due to missing field
-              val exception = intercept[org.apache.spark.sql.AnalysisException] {
-                spark.table(sourceTable)
-                  .mergeInto(tableNameAsString,
-                    $"source_table.pk" === col(tableNameAsString + ".pk"))
-                  .whenMatched()
-                  .updateAll()
-                  .whenNotMatched()
-                  .insertAll()
-                  .merge()
-              }
-              assert(exception.errorClass.get ==
-                "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
-              assert(exception.getMessage.contains(
-                "Cannot write incompatible data for the table ``: " +
-                "Cannot find data for the output column `s`.`c2`.`b`."))
-            }
+          // Source table has null for the nested struct
+          val data = Seq(
+            Row(1, null, "engineering"),
+            Row(2, null, "finance")
+          )
+          val sourceTableSchema = StructType(Seq(
+            StructField("pk", IntegerType),
+            StructField("s", StructType(Seq(
+              StructField("c1", IntegerType),
+              StructField("c2", StructType(Seq(
+                StructField("a", IntegerType)
+              )))
+            ))),
+            StructField("dep", StringType)
+          ))
+          spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema)
+            .createOrReplaceTempView("source_temp")
+
+          sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp")
+          val mergeBuilder = spark.table(sourceTable)
+            .mergeInto(tableNameAsString,
+              $"source_table.pk" === col(tableNameAsString + ".pk"))
+            .whenMatched()
+            .updateAll()
+            .whenNotMatched()
+            .insertAll()

-            sql(s"DROP TABLE $tableNameAsString")
$tableNameAsString") + if (coerceNestedTypes) { + mergeBuilder.merge() + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(0, Row(1, Row(10, "x")), "sales"), + Row(1, null, "engineering"), + Row(2, null, "finance"))) + } else { + // Without coercion, the merge should fail due to missing field + val exception = intercept[org.apache.spark.sql.AnalysisException] { + mergeBuilder.merge() + } + assert(exception.errorClass.get == + "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") + assert(exception.getMessage.contains( + "Cannot write incompatible data for the table ``: " + + "Cannot find data for the output column `s`.`c2`.`b`.")) } + + sql(s"DROP TABLE $tableNameAsString") } } } @@ -6158,13 +6110,9 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase test("merge null struct with schema evolution - " + "source with missing and extra nested fields using dataframe API") { - Seq(true, false).foreach { updateByFields => Seq(true, false).foreach { withSchemaEvolution => Seq(true, false).foreach { coerceNestedTypes => - withSQLConf( - SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD.key -> - updateByFields.toString, - SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> coerceNestedTypes.toString) { val sourceTable = "cat.ns1.source_table" withTable(sourceTable) { @@ -6275,7 +6223,6 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase sql(s"DROP TABLE $tableNameAsString") } - } } } }