diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala index 185dc5ec54f6..c9ed5b86dbde 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala @@ -21,8 +21,7 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.TableOutputResolver.DefaultValueFillMode.{NONE, RECURSE} -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, CreateNamedStruct, Expression, GetStructField, If, IsNull, Literal} -import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull +import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateNamedStruct, Expression, GetStructField, IsNull, Literal} import org.apache.spark.sql.catalyst.plans.logical.Assignment import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.CharVarcharUtils @@ -182,31 +181,11 @@ object AssignmentUtils extends SQLConfHelper with CastSupport { } else if (exactAssignments.isEmpty && fieldAssignments.isEmpty) { TableOutputResolver.checkNullability(colExpr, col, conf, colPath) } else if (exactAssignments.nonEmpty) { - if (updateStar) { - val value = exactAssignments.head.value - col.dataType match { - case structType: StructType => - // Expand assignments to leaf fields - val structAssignment = - applyNestedFieldAssignments(col, colExpr, value, addError, colPath, - coerceNestedTypes) - - // Wrap with null check for missing source fields - fixNullExpansion(col, value, structType, structAssignment, - colPath, addError) - case _ => - // For non-struct types, resolve directly - val coerceMode = if (coerceNestedTypes) RECURSE else NONE - TableOutputResolver.resolveUpdate("", value, col, conf, addError, colPath, - coerceMode) - } - } else { - val value = exactAssignments.head.value - val coerceMode = if (coerceNestedTypes) RECURSE else NONE - val resolvedValue = TableOutputResolver.resolveUpdate("", value, col, conf, addError, - colPath, coerceMode) - resolvedValue - } + val value = exactAssignments.head.value + val coerceMode = if (coerceNestedTypes) RECURSE else NONE + val resolvedValue = TableOutputResolver.resolveUpdate("", value, col, conf, addError, + colPath, coerceMode) + resolvedValue } else { applyFieldAssignments(col, colExpr, fieldAssignments, addError, colPath, coerceNestedTypes) } @@ -240,63 +219,6 @@ object AssignmentUtils extends SQLConfHelper with CastSupport { } } - private def applyNestedFieldAssignments( - col: Attribute, - colExpr: Expression, - value: Expression, - addError: String => Unit, - colPath: Seq[String], - coerceNestedTyptes: Boolean): Expression = { - - col.dataType match { - case structType: StructType => - val fieldAttrs = DataTypeUtils.toAttributes(structType) - - val updatedFieldExprs = fieldAttrs.zipWithIndex.map { case (fieldAttr, ordinal) => - val fieldPath = colPath :+ fieldAttr.name - val targetFieldExpr = GetStructField(colExpr, ordinal, Some(fieldAttr.name)) - - // Try to find a corresponding field in the source value by name - val sourceFieldValue: Expression = value.dataType match { - case valueStructType: StructType => - valueStructType.fields.find(f => conf.resolver(f.name, fieldAttr.name)) match { - case Some(matchingField) => - // Found matching field in source, extract it - val 
fieldIndex = valueStructType.fieldIndex(matchingField.name) - GetStructField(value, fieldIndex, Some(matchingField.name)) - case None => - // Field doesn't exist in source, use target's current value with null check - TableOutputResolver.checkNullability(targetFieldExpr, fieldAttr, conf, fieldPath) - } - case _ => - // Value is not a struct, cannot extract field - addError(s"Cannot assign non-struct value to struct field '${fieldPath.quoted}'") - Literal(null, fieldAttr.dataType) - } - - // Recurse or resolve based on field type - fieldAttr.dataType match { - case nestedStructType: StructType => - // Field is a struct, recurse - applyNestedFieldAssignments(fieldAttr, targetFieldExpr, sourceFieldValue, - addError, fieldPath, coerceNestedTyptes) - case _ => - // Field is not a struct, resolve with TableOutputResolver - val coerceMode = if (coerceNestedTyptes) RECURSE else NONE - TableOutputResolver.resolveUpdate("", sourceFieldValue, fieldAttr, conf, addError, - fieldPath, coerceMode) - } - } - toNamedStruct(structType, updatedFieldExprs) - - case otherType => - addError( - "Updating nested fields is only supported for StructType but " + - s"'${colPath.quoted}' is of type $otherType") - colExpr - } - } - private def toNamedStruct(structType: StructType, fieldExprs: Seq[Expression]): Expression = { val namedStructExprs = structType.fields.zip(fieldExprs).flatMap { case (field, expr) => Seq(Literal(field.name), expr) @@ -350,55 +272,6 @@ object AssignmentUtils extends SQLConfHelper with CastSupport { IsNull(currentExpr) } - /** - * As UPDATE SET * can assign struct fields individually (preserving existing fields), - * this will lead to null expansion, ie, a struct is created where all fields are null. - * Wraps a struct assignment with null checks for the source and missing source fields. - * Return null if all are null. - * - * @param col the target column attribute - * @param value the source value expression - * @param structType the target struct type - * @param structAssignment the struct assignment result to wrap - * @param colPath the column path for error reporting - * @param addError error reporting function - * @return the wrapped expression with null checks - */ - private def fixNullExpansion( - col: Attribute, - value: Expression, - structType: StructType, - structAssignment: Expression, - colPath: Seq[String], - addError: String => Unit): Expression = { - // As StoreAssignmentPolicy.LEGACY is not allowed in DSv2, always add null check for - // non-nullable column - if (!col.nullable) { - AssertNotNull(value) - } else { - // Check if source struct is null - val valueIsNull = IsNull(value) - - // Check if missing source paths (paths in target but not in source) are not null - // These will be null for the case of UPDATE SET * and - val missingSourcePaths = getMissingSourcePaths(structType, value.dataType, colPath, addError) - val condition = if (missingSourcePaths.nonEmpty) { - // Check if all target attributes at missing source paths are null - val missingFieldNullChecks = missingSourcePaths.map { path => - createNullCheckForFieldPath(col, path) - } - // Combine all null checks with AND - val allMissingFieldsNull = missingFieldNullChecks.reduce[Expression]((a, b) => And(a, b)) - And(valueIsNull, allMissingFieldsNull) - } else { - valueIsNull - } - - // Return: If (condition) THEN NULL ELSE structAssignment - If(condition, Literal(null, structAssignment.dataType), structAssignment) - } - } - /** * Checks whether assignments are aligned and compatible with table columns. 
* diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 8b50abbe4052..85011da24de9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -6699,10 +6699,11 @@ object SQLConf { buildConf("spark.sql.merge.nested.type.coercion.enabled") .internal() .doc("If enabled, allow MERGE INTO to coerce source nested types if they have less" + -        "nested fields than the target table's nested types.") +        " nested fields than the target table's nested types. This is experimental and " + +        "the semantics may change.") .version("4.1.0") .booleanConf -      .createWithDefault(true) +      .createWithDefault(false) /** * Holds information about keys that have been deprecated. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala index 680fa63e0929..7539506e8bfe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala @@ -3171,222 +3171,275 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase test("merge into schema evolution replace column with nested struct and set explicit columns") { Seq(true, false).foreach { withSchemaEvolution => - withTempView("source") { - createAndInitTable( - s"""pk INT NOT NULL, - |s STRUCT<c1 INT, c2 STRUCT<a: ARRAY<INT>, m: MAP<STRING, STRING>>>, - |dep STRING""".stripMargin, - """{ "pk": 1, "s": { "c1": 2, "c2": { "a": [1,2], "m": { "a": "b" } } }, "dep": "hr" }""") + Seq(true, false).foreach { coerceNestedTypes => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> + coerceNestedTypes.toString) { + withTempView("source") { + sql( + s"""CREATE TABLE $tableNameAsString ( + |pk INT NOT NULL, + |s STRUCT<c1 INT, c2 STRUCT<a: ARRAY<INT>, + | m: MAP<STRING, STRING>>>, + |dep STRING)""".stripMargin) - val sourceTableSchema = StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StructType(Seq( - // removed column 'a' - StructField("m", MapType(StringType, StringType)), - StructField("c3", BooleanType) // new column - ))) - ))), - StructField("dep", StringType) - )) - val data = Seq( - Row(1, Row(10, Row(Map("c" -> "d"), false)), "sales"), - Row(2, Row(20, Row(Map("e" -> "f"), true)), "engineering") - ) - spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) - .createOrReplaceTempView("source") + val targetSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", ArrayType(IntegerType)), + StructField("m", MapType(StringType, StringType)) + ))) + ))), + StructField("dep", StringType) + )) + val targetData = Seq( + Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr") + ) + spark.createDataFrame( + spark.sparkContext.parallelize(targetData), targetSchema) + .writeTo(tableNameAsString).append() - val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" - val mergeStmt = - s"""MERGE $schemaEvolutionClause - |INTO $tableNameAsString t - |USING source src - |ON t.pk = src.pk - |WHEN MATCHED THEN - | UPDATE SET s.c1 = -1, s.c2.m = map('k', 'v'), s.c2.a = array(-1), - | s.c2.c3 = 
src.s.c2.c3 - |WHEN NOT MATCHED THEN - | INSERT (pk, s, dep) VALUES (src.pk, - | named_struct('c1', src.s.c1, - | 'c2', named_struct('a', array(-2), 'm', map('g', 'h'), 'c3', true)), src.dep) - |""".stripMargin + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("m", MapType(StringType, StringType)), + StructField("c3", BooleanType) + ))) + ))), + StructField("dep", StringType) + )) + val data = Seq( + Row(1, Row(10, Row(Map("c" -> "d"), false)), "sales"), + Row(2, Row(20, Row(Map("e" -> "f"), true)), "engineering") + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") - if (withSchemaEvolution) { - sql(mergeStmt) - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq(Row(1, Row(-1, Row(Seq(-1), Map("k" -> "v"), false)), "hr"), - Row(2, Row(20, Row(Seq(-2), Map("g" -> "h"), true)), "engineering"))) - } else { - val exception = intercept[org.apache.spark.sql.AnalysisException] { - sql(mergeStmt) + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source src + |ON t.pk = src.pk + |WHEN MATCHED THEN + | UPDATE SET s.c1 = -1, s.c2.m = map('k', 'v'), s.c2.a = array(-1), + | s.c2.c3 = src.s.c2.c3 + |WHEN NOT MATCHED THEN + | INSERT (pk, s, dep) VALUES (src.pk, + | named_struct('c1', src.s.c1, + | 'c2', named_struct('a', array(-2), 'm', map('g', 'h'), 'c3', true)), src.dep) + |""".stripMargin + + if (withSchemaEvolution) { + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(1, Row(-1, Row(Seq(-1), Map("k" -> "v"), false)), "hr"), + Row(2, Row(20, Row(Seq(-2), Map("g" -> "h"), true)), "engineering"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == "FIELD_NOT_FOUND") + assert(exception.getMessage.contains("No such struct field `c3` in `a`, `m`. ")) + } } - assert(exception.errorClass.get == "FIELD_NOT_FOUND") - assert(exception.getMessage.contains("No such struct field `c3` in `a`, `m`. 
")) + sql(s"DROP TABLE IF EXISTS $tableNameAsString") } } - sql(s"DROP TABLE IF EXISTS $tableNameAsString") } } test("merge into schema evolution replace column with nested struct and set all columns") { Seq(true, false).foreach { withSchemaEvolution => - withTempView("source") { - // Create table using Spark SQL - sql( - s"""CREATE TABLE $tableNameAsString ( - |pk INT NOT NULL, - |s STRUCT, m: MAP>>, - |dep STRING) - |PARTITIONED BY (dep) - |""".stripMargin) - // Insert data using DataFrame API with objects - val tableSchema = StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StructType(Seq( - StructField("a", ArrayType(IntegerType)), - StructField("m", MapType(StringType, StringType)) - ))) - ))), - StructField("dep", StringType) - )) - val targetData = Seq( - Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr") - ) - spark.createDataFrame(spark.sparkContext.parallelize(targetData), tableSchema) - .coalesce(1).writeTo(tableNameAsString).append() + Seq(true, false).foreach { coerceNestedTypes => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> + coerceNestedTypes.toString) { + withTempView("source") { + // Create table using Spark SQL + sql( + s"""CREATE TABLE $tableNameAsString ( + |pk INT NOT NULL, + |s STRUCT, m: MAP>>, + |dep STRING) + |PARTITIONED BY (dep) + |""".stripMargin) + // Insert data using DataFrame API with objects + val tableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", ArrayType(IntegerType)), + StructField("m", MapType(StringType, StringType)) + ))) + ))), + StructField("dep", StringType) + )) + val targetData = Seq( + Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr") + ) + spark.createDataFrame(spark.sparkContext.parallelize(targetData), tableSchema) + .coalesce(1).writeTo(tableNameAsString).append() - val sourceTableSchema = StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StructType(Seq( - // missing column 'a' - StructField("m", MapType(StringType, StringType)), - StructField("c3", BooleanType) // new column - ))) - ))), - StructField("dep", StringType) - )) - val sourceData = Seq( - Row(1, Row(10, Row(Map("c" -> "d"), false)), "sales"), - Row(2, Row(20, Row(Map("e" -> "f"), true)), "engineering") - ) - spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema) - .createOrReplaceTempView("source") + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + // missing column 'a' + StructField("m", MapType(StringType, StringType)), + StructField("c3", BooleanType) // new column + ))) + ))), + StructField("dep", StringType) + )) + val sourceData = Seq( + Row(1, Row(10, Row(Map("c" -> "d"), false)), "sales"), + Row(2, Row(20, Row(Map("e" -> "f"), true)), "engineering") + ) + spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema) + .createOrReplaceTempView("source") - val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" - val mergeStmt = - s"""MERGE $schemaEvolutionClause - |INTO $tableNameAsString t - |USING source src - |ON t.pk = src.pk - |WHEN MATCHED 
THEN - | UPDATE SET * - |WHEN NOT MATCHED THEN - | INSERT * - |""".stripMargin - if (withSchemaEvolution) { - sql(mergeStmt) - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(1, Row(10, Row(Seq(1, 2), Map("c" -> "d"), false)), "sales"), - Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering"))) - } else { - val exception = intercept[org.apache.spark.sql.AnalysisException] { - sql(mergeStmt) + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source src + |ON t.pk = src.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin + if (coerceNestedTypes) { + if (withSchemaEvolution) { + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "sales"), + Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == + "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") + assert(exception.getMessage.contains( + "Cannot write extra fields `c3` to the struct `s`.`c2`")) + } + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") + assert(exception.getMessage.contains( + "Cannot find data for the output column `s`.`c2`.`a`")) + } } - assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") - assert(exception.getMessage.contains( - "Cannot write extra fields `c3` to the struct `s`.`c2`")) + sql(s"DROP TABLE IF EXISTS $tableNameAsString") } } - sql(s"DROP TABLE IF EXISTS $tableNameAsString") } } test("merge into schema evolution replace column with nested struct and update " + "top level struct") { Seq(true, false).foreach { withSchemaEvolution => - withTempView("source") { - // Create table using Spark SQL - sql( - s"""CREATE TABLE $tableNameAsString ( - |pk INT NOT NULL, - |s STRUCT, m: MAP>>, - |dep STRING) - |PARTITIONED BY (dep) - |""".stripMargin) + Seq(true, false).foreach { coerceNestedTypes => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> + coerceNestedTypes.toString) { + withTempView("source") { + // Create table using Spark SQL + sql( + s"""CREATE TABLE $tableNameAsString ( + |pk INT NOT NULL, + |s STRUCT, m: MAP>>, + |dep STRING) + |PARTITIONED BY (dep) + |""".stripMargin) - // Insert data using DataFrame API with objects - val tableSchema = StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StructType(Seq( - StructField("a", ArrayType(IntegerType)), - StructField("m", MapType(StringType, StringType)) - ))) - ))), - StructField("dep", StringType) - )) - val targetData = Seq( - Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr") - ) - spark.createDataFrame(spark.sparkContext.parallelize(targetData), tableSchema) - .coalesce(1).writeTo(tableNameAsString).append() + // Insert data using DataFrame API with objects + val tableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", ArrayType(IntegerType)), + StructField("m", MapType(StringType, 
StringType)) + ))) + ))), + StructField("dep", StringType) + )) + val targetData = Seq( + Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr") + ) + spark.createDataFrame(spark.sparkContext.parallelize(targetData), tableSchema) + .coalesce(1).writeTo(tableNameAsString).append() - val sourceTableSchema = StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StructType(Seq( - // missing column 'a' - StructField("m", MapType(StringType, StringType)), - StructField("c3", BooleanType) // new column - ))) - ))), - StructField("dep", StringType) - )) - val sourceData = Seq( - Row(1, Row(10, Row(Map("c" -> "d"), false)), "sales"), - Row(2, Row(20, Row(Map("e" -> "f"), true)), "engineering") - ) - spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema) - .createOrReplaceTempView("source") + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + // missing column 'a' + StructField("m", MapType(StringType, StringType)), + StructField("c3", BooleanType) // new column + ))) + ))), + StructField("dep", StringType) + )) + val sourceData = Seq( + Row(1, Row(10, Row(Map("c" -> "d"), false)), "sales"), + Row(2, Row(20, Row(Map("e" -> "f"), true)), "engineering") + ) + spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema) + .createOrReplaceTempView("source") - val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" - val mergeStmt = - s"""MERGE $schemaEvolutionClause - |INTO $tableNameAsString t - |USING source src - |ON t.pk = src.pk - |WHEN MATCHED THEN - | UPDATE SET s = src.s - |WHEN NOT MATCHED THEN - | INSERT * - |""".stripMargin - if (withSchemaEvolution) { - sql(mergeStmt) - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "hr"), - Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering"))) - } else { - val exception = intercept[org.apache.spark.sql.AnalysisException] { - sql(mergeStmt) + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source src + |ON t.pk = src.pk + |WHEN MATCHED THEN + | UPDATE SET s = src.s + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin + if (coerceNestedTypes) { + if (withSchemaEvolution) { + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "hr"), + Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == + "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") + assert(exception.getMessage.contains( + "Cannot write extra fields `c3` to the struct `s`.`c2`")) + } + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") + assert(exception.getMessage.contains( + "Cannot find data for the output column `s`.`c2`.`a`")) + } } - assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") - assert(exception.getMessage.contains( - "Cannot write extra fields `c3` to 
the struct `s`.`c2`")) + sql(s"DROP TABLE IF EXISTS $tableNameAsString") } } - sql(s"DROP TABLE IF EXISTS $tableNameAsString") } } @@ -3513,249 +3566,309 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase test("merge into schema evolution replace column for struct in map and set all columns") { Seq(true, false).foreach { withSchemaEvolution => - withTempView("source") { - val schema = - StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("m", MapType( - StructType(Seq(StructField("c1", IntegerType), StructField("c2", IntegerType))), - StructType(Seq(StructField("c4", StringType), StructField("c5", StringType))))), - StructField("dep", StringType))) - createTable(CatalogV2Util.structTypeToV2Columns(schema)) + Seq(true, false).foreach { coerceNestedTypes => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> + coerceNestedTypes.toString) { + withTempView("source") { + val schema = + StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("m", MapType( + StructType(Seq(StructField("c1", IntegerType), StructField("c2", IntegerType))), + StructType(Seq(StructField("c4", StringType), StructField("c5", StringType))))), + StructField("dep", StringType))) + createTable(CatalogV2Util.structTypeToV2Columns(schema)) - val data = Seq( - Row(0, Map(Row(10, 10) -> Row("c", "c")), "hr"), - Row(1, Map(Row(20, 20) -> Row("d", "d")), "sales")) - spark.createDataFrame(spark.sparkContext.parallelize(data), schema) - .writeTo(tableNameAsString).append() + val data = Seq( + Row(0, Map(Row(10, 10) -> Row("c", "c")), "hr"), + Row(1, Map(Row(20, 20) -> Row("d", "d")), "sales")) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + .writeTo(tableNameAsString).append() - val sourceTableSchema = StructType(Seq( - StructField("pk", IntegerType), - StructField("m", MapType( - StructType(Seq(StructField("c1", IntegerType), StructField("c3", BooleanType))), - StructType(Seq(StructField("c4", StringType), StructField("c6", BooleanType))))), - StructField("dep", StringType))) - val sourceData = Seq( - Row(1, Map(Row(10, true) -> Row("y", false)), "sales"), - Row(2, Map(Row(20, false) -> Row("z", true)), "engineering") - ) - spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema) - .createOrReplaceTempView("source") + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType), + StructField("m", MapType( + StructType(Seq(StructField("c1", IntegerType), StructField("c3", BooleanType))), + StructType(Seq(StructField("c4", StringType), StructField("c6", BooleanType))))), + StructField("dep", StringType))) + val sourceData = Seq( + Row(1, Map(Row(10, true) -> Row("y", false)), "sales"), + Row(2, Map(Row(20, false) -> Row("z", true)), "engineering") + ) + spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema) + .createOrReplaceTempView("source") - val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" - val mergeStmt = - s"""MERGE $schemaEvolutionClause - |INTO $tableNameAsString t - |USING source src - |ON t.pk = src.pk - |WHEN MATCHED THEN - | UPDATE SET * - |WHEN NOT MATCHED THEN - | INSERT * - |""".stripMargin + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source src + |ON t.pk = src.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin - 
if (withSchemaEvolution) { - sql(mergeStmt) - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq(Row(0, Map(Row(10, 10, null) -> Row("c", "c", null)), "hr"), - Row(1, Map(Row(10, null, true) -> Row("y", null, false)), "sales"), - Row(2, Map(Row(20, null, false) -> Row("z", null, true)), "engineering"))) - } else { - val exception = intercept[org.apache.spark.sql.AnalysisException] { - sql(mergeStmt) + if (coerceNestedTypes) { + if (withSchemaEvolution) { + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(0, Map(Row(10, 10, null) -> Row("c", "c", null)), "hr"), + Row(1, Map(Row(10, null, true) -> Row("y", null, false)), "sales"), + Row(2, Map(Row(20, null, false) -> Row("z", null, true)), "engineering"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == + "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") + assert(exception.getMessage.contains( + "Cannot write extra fields `c3` to the struct `m`.`key`")) + } + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") + assert(exception.getMessage.contains( + "Cannot find data for the output column `m`.`key`.`c2`")) + } } - assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") - assert(exception.getMessage.contains( - "Cannot write extra fields `c3` to the struct `m`.`key`")) + sql(s"DROP TABLE IF EXISTS $tableNameAsString") } } - sql(s"DROP TABLE IF EXISTS $tableNameAsString") } } test("merge into schema evolution replace column for struct in map and set explicit columns") { Seq(true, false).foreach { withSchemaEvolution => - withTempView("source") { - val schema = - StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("m", MapType( - StructType(Seq(StructField("c1", IntegerType), StructField("c2", IntegerType))), - StructType(Seq(StructField("c4", StringType), StructField("c5", StringType))))), - StructField("dep", StringType))) - createTable(CatalogV2Util.structTypeToV2Columns(schema)) + Seq(true, false).foreach { coerceNestedTypes => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> + coerceNestedTypes.toString) { + withTempView("source") { + val schema = + StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("m", MapType( + StructType(Seq(StructField("c1", IntegerType), StructField("c2", IntegerType))), + StructType(Seq(StructField("c4", StringType), StructField("c5", StringType))))), + StructField("dep", StringType))) + createTable(CatalogV2Util.structTypeToV2Columns(schema)) - val data = Seq( - Row(0, Map(Row(10, 10) -> Row("c", "c")), "hr"), - Row(1, Map(Row(20, 20) -> Row("d", "d")), "sales")) - spark.createDataFrame(spark.sparkContext.parallelize(data), schema) - .writeTo(tableNameAsString).append() + val data = Seq( + Row(0, Map(Row(10, 10) -> Row("c", "c")), "hr"), + Row(1, Map(Row(20, 20) -> Row("d", "d")), "sales")) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + .writeTo(tableNameAsString).append() - val sourceTableSchema = StructType(Seq( - StructField("pk", IntegerType), - StructField("m", MapType( - StructType(Seq(StructField("c1", IntegerType), StructField("c3", BooleanType))), - StructType(Seq(StructField("c4", StringType), StructField("c6", BooleanType))))), - StructField("dep", StringType))) - val sourceData = Seq( - Row(1, 
Map(Row(10, true) -> Row("y", false)), "sales"), - Row(2, Map(Row(20, false) -> Row("z", true)), "engineering") - ) - spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema) - .createOrReplaceTempView("source") + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType), + StructField("m", MapType( + StructType(Seq(StructField("c1", IntegerType), StructField("c3", BooleanType))), + StructType(Seq(StructField("c4", StringType), StructField("c6", BooleanType))))), + StructField("dep", StringType))) + val sourceData = Seq( + Row(1, Map(Row(10, true) -> Row("y", false)), "sales"), + Row(2, Map(Row(20, false) -> Row("z", true)), "engineering") + ) + spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema) + .createOrReplaceTempView("source") - val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" - val mergeStmt = - s"""MERGE $schemaEvolutionClause - |INTO $tableNameAsString t - |USING source src - |ON t.pk = src.pk - |WHEN MATCHED THEN - | UPDATE SET t.m = src.m, t.dep = 'my_old_dep' - |WHEN NOT MATCHED THEN - | INSERT (pk, m, dep) VALUES (src.pk, src.m, 'my_new_dep') - |""".stripMargin + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source src + |ON t.pk = src.pk + |WHEN MATCHED THEN + | UPDATE SET t.m = src.m, t.dep = 'my_old_dep' + |WHEN NOT MATCHED THEN + | INSERT (pk, m, dep) VALUES (src.pk, src.m, 'my_new_dep') + |""".stripMargin - if (withSchemaEvolution) { - sql(mergeStmt) - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq(Row(0, Map(Row(10, 10, null) -> Row("c", "c", null)), "hr"), - Row(1, Map(Row(10, null, true) -> Row("y", null, false)), "my_old_dep"), - Row(2, Map(Row(20, null, false) -> Row("z", null, true)), "my_new_dep"))) - } else { - val exception = intercept[org.apache.spark.sql.AnalysisException] { - sql(mergeStmt) + if (coerceNestedTypes) { + if (withSchemaEvolution) { + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(0, Map(Row(10, 10, null) -> Row("c", "c", null)), "hr"), + Row(1, Map(Row(10, null, true) -> Row("y", null, false)), "my_old_dep"), + Row(2, Map(Row(20, null, false) -> Row("z", null, true)), "my_new_dep"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == + "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") + assert(exception.getMessage.contains( + "Cannot write extra fields `c3` to the struct `m`.`key`")) + } + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") + assert(exception.getMessage.contains( + "Cannot find data for the output column `m`.`key`.`c2`")) + } } - assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") - assert(exception.getMessage.contains( - "Cannot write extra fields `c3` to the struct `m`.`key`")) + sql(s"DROP TABLE IF EXISTS $tableNameAsString") } } - sql(s"DROP TABLE IF EXISTS $tableNameAsString") } } test("merge into schema evolution replace column for struct in array and set all columns") { Seq(true, false).foreach { withSchemaEvolution => - withTempView("source") { - val schema = - StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("a", ArrayType( - 
StructType(Seq(StructField("c1", IntegerType), StructField("c2", IntegerType))))), - StructField("dep", StringType))) - createTable(CatalogV2Util.structTypeToV2Columns(schema)) + Seq(true, false).foreach { coerceNestedTypes => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> + coerceNestedTypes.toString) { + withTempView("source") { + val schema = + StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("a", ArrayType( + StructType(Seq(StructField("c1", IntegerType), StructField("c2", IntegerType))))), + StructField("dep", StringType))) + createTable(CatalogV2Util.structTypeToV2Columns(schema)) - val data = Seq( - Row(0, Array(Row(10, 10)), "hr"), - Row(1, Array(Row(20, 20)), "sales")) - spark.createDataFrame(spark.sparkContext.parallelize(data), schema) - .writeTo(tableNameAsString).append() + val data = Seq( + Row(0, Array(Row(10, 10)), "hr"), + Row(1, Array(Row(20, 20)), "sales")) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + .writeTo(tableNameAsString).append() - val sourceTableSchema = StructType(Seq( - StructField("pk", IntegerType), - StructField("a", ArrayType( - StructType(Seq(StructField("c1", IntegerType), StructField("c3", BooleanType))))), - StructField("dep", StringType))) - val sourceData = Seq( - Row(1, Array(Row(10, true)), "sales"), - Row(2, Array(Row(20, false)), "engineering") - ) - spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema) - .createOrReplaceTempView("source") + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType), + StructField("a", ArrayType( + StructType(Seq(StructField("c1", IntegerType), StructField("c3", BooleanType))))), + StructField("dep", StringType))) + val sourceData = Seq( + Row(1, Array(Row(10, true)), "sales"), + Row(2, Array(Row(20, false)), "engineering") + ) + spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema) + .createOrReplaceTempView("source") - val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" - val mergeStmt = - s"""MERGE $schemaEvolutionClause - |INTO $tableNameAsString t - |USING source src - |ON t.pk = src.pk - |WHEN MATCHED THEN - | UPDATE SET * - |WHEN NOT MATCHED THEN - | INSERT * - |""".stripMargin + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source src + |ON t.pk = src.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin - if (withSchemaEvolution) { - sql(mergeStmt) - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq(Row(0, Array(Row(10, 10, null)), "hr"), - Row(1, Array(Row(10, null, true)), "sales"), - Row(2, Array(Row(20, null, false)), "engineering"))) - } else { - val exception = intercept[org.apache.spark.sql.AnalysisException] { - sql(mergeStmt) + if (coerceNestedTypes) { + if (withSchemaEvolution) { + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(0, Array(Row(10, 10, null)), "hr"), + Row(1, Array(Row(10, null, true)), "sales"), + Row(2, Array(Row(20, null, false)), "engineering"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == + "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") + assert(exception.getMessage.contains( + "Cannot write extra fields `c3` to the struct `a`.`element`")) + } + } else { + val exception = 
intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") + assert(exception.getMessage.contains( + "Cannot find data for the output column `a`.`element`.`c2`")) + } } - assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") - assert(exception.getMessage.contains( - "Cannot write extra fields `c3` to the struct `a`.`element`")) + sql(s"DROP TABLE IF EXISTS $tableNameAsString") } } - sql(s"DROP TABLE IF EXISTS $tableNameAsString") } } test("merge into schema evolution replace column for struct in array and set explicit columns") { Seq(true, false).foreach { withSchemaEvolution => - withTempView("source") { - val schema = - StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("a", ArrayType( - StructType(Seq(StructField("c1", IntegerType), StructField("c2", IntegerType))))), - StructField("dep", StringType))) - createTable(CatalogV2Util.structTypeToV2Columns(schema)) + Seq(true, false).foreach { coerceNestedTypes => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> + coerceNestedTypes.toString) { + withTempView("source") { + val schema = + StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("a", ArrayType( + StructType(Seq(StructField("c1", IntegerType), StructField("c2", IntegerType))))), + StructField("dep", StringType))) + createTable(CatalogV2Util.structTypeToV2Columns(schema)) - val data = Seq( - Row(0, Array(Row(10, 10)), "hr"), - Row(1, Array(Row(20, 20)), "sales")) - spark.createDataFrame(spark.sparkContext.parallelize(data), schema) - .writeTo(tableNameAsString).append() + val data = Seq( + Row(0, Array(Row(10, 10)), "hr"), + Row(1, Array(Row(20, 20)), "sales")) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + .writeTo(tableNameAsString).append() - val sourceTableSchema = StructType(Seq( - StructField("pk", IntegerType), - StructField("a", ArrayType( - StructType(Seq(StructField("c1", IntegerType), StructField("c3", BooleanType))))), - StructField("dep", StringType))) - val sourceData = Seq( - Row(1, Array(Row(10, true)), "sales"), - Row(2, Array(Row(20, false)), "engineering") - ) - spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema) - .createOrReplaceTempView("source") + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType), + StructField("a", ArrayType( + StructType(Seq(StructField("c1", IntegerType), StructField("c3", BooleanType))))), + StructField("dep", StringType))) + val sourceData = Seq( + Row(1, Array(Row(10, true)), "sales"), + Row(2, Array(Row(20, false)), "engineering") + ) + spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema) + .createOrReplaceTempView("source") - val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" - val mergeStmt = - s"""MERGE $schemaEvolutionClause - |INTO $tableNameAsString t - |USING source src - |ON t.pk = src.pk - |WHEN MATCHED THEN - | UPDATE SET t.a = src.a, t.dep = 'my_old_dep' - |WHEN NOT MATCHED THEN - | INSERT (pk, a, dep) VALUES (src.pk, src.a, 'my_new_dep') - |""".stripMargin + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source src + |ON t.pk = src.pk + |WHEN MATCHED THEN + | UPDATE SET t.a = src.a, t.dep = 'my_old_dep' + |WHEN NOT MATCHED THEN + | INSERT (pk, a, 
dep) VALUES (src.pk, src.a, 'my_new_dep') + |""".stripMargin - if (withSchemaEvolution) { - sql(mergeStmt) - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq(Row(0, Array(Row(10, 10, null)), "hr"), - Row(1, Array(Row(10, null, true)), "my_old_dep"), - Row(2, Array(Row(20, null, false)), "my_new_dep"))) - } else { - val exception = intercept[org.apache.spark.sql.AnalysisException] { - sql(mergeStmt) + if (coerceNestedTypes) { + if (withSchemaEvolution) { + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(0, Array(Row(10, 10, null)), "hr"), + Row(1, Array(Row(10, null, true)), "my_old_dep"), + Row(2, Array(Row(20, null, false)), "my_new_dep"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == + "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") + assert(exception.getMessage.contains( + "Cannot write extra fields `c3` to the struct `a`.`element`")) + } + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") + assert(exception.getMessage.contains( + "Cannot find data for the output column `a`.`element`.`c2`")) + } } - assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") - assert(exception.getMessage.contains( - "Cannot write extra fields `c3` to the struct `a`.`element`")) + sql(s"DROP TABLE IF EXISTS $tableNameAsString") } } - sql(s"DROP TABLE IF EXISTS $tableNameAsString") } } test("merge into empty table with NOT MATCHED clause schema evolution") { @@ -4447,200 +4560,268 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase } test("merge into with source missing fields in struct nested in array") { - withTempView("source") { - // Target table has struct with 3 fields (c1, c2, c3) in array - createAndInitTable( - s"""pk INT NOT NULL, - |a ARRAY<STRUCT<c1 INT, c2 STRING, c3 BOOLEAN>>, - |dep STRING""".stripMargin, - """{ "pk": 0, "a": [ { "c1": 1, "c2": "a", "c3": true } ], "dep": "sales" } - |{ "pk": 1, "a": [ { "c1": 2, "c2": "b", "c3": false } ], "dep": "sales" }""" - .stripMargin) + Seq(true, false).foreach { coerceNestedTypes => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> + coerceNestedTypes.toString) { + withTempView("source") { + // Target table has struct with 3 fields (c1, c2, c3) in array + createAndInitTable( + s"""pk INT NOT NULL, + |a ARRAY<STRUCT<c1 INT, c2 STRING, c3 BOOLEAN>>, + |dep STRING""".stripMargin, + """{ "pk": 0, "a": [ { "c1": 1, "c2": "a", "c3": true } ], "dep": "sales" } + |{ "pk": 1, "a": [ { "c1": 2, "c2": "b", "c3": false } ], "dep": "sales" }""" + .stripMargin) - // Source table has struct with only 2 fields (c1, c2) - missing c3 - val sourceTableSchema = StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("a", ArrayType( - StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StringType))))), // missing c3 field - StructField("dep", StringType))) - val data = Seq( - Row(1, Array(Row(10, "c")), "hr"), - Row(2, Array(Row(30, "e")), "engineering") - ) - spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) - .createOrReplaceTempView("source") + // Source table has struct with only 2 fields (c1, c2) - missing c3 + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("a", ArrayType( + StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StringType))))), // 
missing c3 field + StructField("dep", StringType))) + val data = Seq( + Row(1, Array(Row(10, "c")), "hr"), + Row(2, Array(Row(30, "e")), "engineering") + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") - sql( - s"""MERGE INTO $tableNameAsString t - |USING source src - |ON t.pk = src.pk - |WHEN MATCHED THEN - | UPDATE SET * - |WHEN NOT MATCHED THEN - | INSERT * - |""".stripMargin) + val mergeStmt = + s"""MERGE INTO $tableNameAsString t + |USING source src + |ON t.pk = src.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin - // Missing field c3 should be filled with NULL - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(0, Array(Row(1, "a", true)), "sales"), - Row(1, Array(Row(10, "c", null)), "hr"), - Row(2, Array(Row(30, "e", null)), "engineering"))) + if (coerceNestedTypes) { + sql(mergeStmt) + // Missing field c3 should be filled with NULL + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(0, Array(Row(1, "a", true)), "sales"), + Row(1, Array(Row(10, "c", null)), "hr"), + Row(2, Array(Row(30, "e", null)), "engineering"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == + "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") + assert(exception.getMessage.contains( + "Cannot write incompatible data for the table ``: " + + "Cannot find data for the output column `a`.`element`.`c3`.")) + } + } + sql(s"DROP TABLE IF EXISTS $tableNameAsString") + } } - sql(s"DROP TABLE IF EXISTS $tableNameAsString") } test("merge into with source missing fields in struct nested in map key") { - withTempView("source") { - // Target table has struct with 2 fields in map key - val targetSchema = - StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("m", MapType( - StructType(Seq(StructField("c1", IntegerType), StructField("c2", BooleanType))), - StructType(Seq(StructField("c3", StringType))))), - StructField("dep", StringType))) - createTable(CatalogV2Util.structTypeToV2Columns(targetSchema)) + Seq(true, false).foreach { coerceNestedTypes => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> + coerceNestedTypes.toString) { + withTempView("source") { + // Target table has struct with 2 fields in map key + val targetSchema = + StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("m", MapType( + StructType(Seq(StructField("c1", IntegerType), StructField("c2", BooleanType))), + StructType(Seq(StructField("c3", StringType))))), + StructField("dep", StringType))) + createTable(CatalogV2Util.structTypeToV2Columns(targetSchema)) - val targetData = Seq( - Row(0, Map(Row(10, true) -> Row("x")), "hr"), - Row(1, Map(Row(20, false) -> Row("y")), "sales")) - spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) - .writeTo(tableNameAsString).append() + val targetData = Seq( + Row(0, Map(Row(10, true) -> Row("x")), "hr"), + Row(1, Map(Row(20, false) -> Row("y")), "sales")) + spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) + .writeTo(tableNameAsString).append() - // Source table has struct with only 1 field (c1) in map key - missing c2 - val sourceTableSchema = StructType(Seq( - StructField("pk", IntegerType), - StructField("m", MapType( - StructType(Seq(StructField("c1", IntegerType))), // missing c2 - StructType(Seq(StructField("c3", StringType))))), - 
StructField("dep", StringType))) - val sourceData = Seq( - Row(1, Map(Row(10) -> Row("z")), "sales"), - Row(2, Map(Row(20) -> Row("w")), "engineering") - ) - spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema) - .createOrReplaceTempView("source") + // Source table has struct with only 1 field (c1) in map key - missing c2 + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType), + StructField("m", MapType( + StructType(Seq(StructField("c1", IntegerType))), // missing c2 + StructType(Seq(StructField("c3", StringType))))), + StructField("dep", StringType))) + val sourceData = Seq( + Row(1, Map(Row(10) -> Row("z")), "sales"), + Row(2, Map(Row(20) -> Row("w")), "engineering") + ) + spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema) + .createOrReplaceTempView("source") - sql( - s"""MERGE INTO $tableNameAsString t - |USING source src - |ON t.pk = src.pk - |WHEN MATCHED THEN - | UPDATE SET * - |WHEN NOT MATCHED THEN - | INSERT * - |""".stripMargin) + val mergeStmt = + s"""MERGE INTO $tableNameAsString t + |USING source src + |ON t.pk = src.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin - // Missing field c2 should be filled with NULL - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(0, Map(Row(10, true) -> Row("x")), "hr"), - Row(1, Map(Row(10, null) -> Row("z")), "sales"), - Row(2, Map(Row(20, null) -> Row("w")), "engineering"))) + if (coerceNestedTypes) { + sql(mergeStmt) + // Missing field c2 should be filled with NULL + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(0, Map(Row(10, true) -> Row("x")), "hr"), + Row(1, Map(Row(10, null) -> Row("z")), "sales"), + Row(2, Map(Row(20, null) -> Row("w")), "engineering"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == + "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") + assert(exception.getMessage.contains( + "Cannot write incompatible data for the table ``: " + + "Cannot find data for the output column `m`.`key`.`c2`.")) + } + } + sql(s"DROP TABLE IF EXISTS $tableNameAsString") + } } - sql(s"DROP TABLE IF EXISTS $tableNameAsString") } test("merge into with source missing fields in struct nested in map value") { - withTempView("source") { - // Target table has struct with 2 fields in map value - val targetSchema = - StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("m", MapType( - StructType(Seq(StructField("c1", IntegerType))), - StructType(Seq(StructField("c1", StringType), StructField("c2", BooleanType))))), - StructField("dep", StringType))) - createTable(CatalogV2Util.structTypeToV2Columns(targetSchema)) + Seq(true, false).foreach { coerceNestedTypes => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> + coerceNestedTypes.toString) { + withTempView("source") { + // Target table has struct with 2 fields in map value + val targetSchema = + StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("m", MapType( + StructType(Seq(StructField("c1", IntegerType))), + StructType(Seq(StructField("c1", StringType), StructField("c2", BooleanType))))), + StructField("dep", StringType))) + createTable(CatalogV2Util.structTypeToV2Columns(targetSchema)) - val targetData = Seq( - Row(0, Map(Row(10) -> Row("x", true)), "hr"), - Row(1, Map(Row(20) -> Row("y", false)), "sales")) - 
spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) - .writeTo(tableNameAsString).append() + spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) + .writeTo(tableNameAsString).append() - // Source table has struct with only 1 field (c1) in map value - missing c2 - val sourceTableSchema = StructType(Seq( - StructField("pk", IntegerType), - StructField("m", MapType( - StructType(Seq(StructField("c1", IntegerType))), - StructType(Seq(StructField("c1", StringType))))), // missing c2 - StructField("dep", StringType))) - val sourceData = Seq( - Row(1, Map(Row(10) -> Row("z")), "sales"), - Row(2, Map(Row(20) -> Row("w")), "engineering") - ) - spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema) - .createOrReplaceTempView("source") + // Source table has struct with only 1 field (c1) in map value - missing c2 + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType), + StructField("m", MapType( + StructType(Seq(StructField("c1", IntegerType))), + StructType(Seq(StructField("c1", StringType))))), // missing c2 + StructField("dep", StringType))) + val sourceData = Seq( + Row(1, Map(Row(10) -> Row("z")), "sales"), + Row(2, Map(Row(20) -> Row("w")), "engineering") + ) + spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema) + .createOrReplaceTempView("source") - sql( - s"""MERGE INTO $tableNameAsString t - |USING source src - |ON t.pk = src.pk - |WHEN MATCHED THEN - | UPDATE SET * - |WHEN NOT MATCHED THEN - | INSERT * - |""".stripMargin) + val mergeStmt = + s"""MERGE INTO $tableNameAsString t + |USING source src + |ON t.pk = src.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin - // Missing field c2 should be filled with NULL - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(0, Map(Row(10) -> Row("x", true)), "hr"), - Row(1, Map(Row(10) -> Row("z", null)), "sales"), - Row(2, Map(Row(20) -> Row("w", null)), "engineering"))) + if (coerceNestedTypes) { + sql(mergeStmt) + // Missing field c2 should be filled with NULL + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(0, Map(Row(10) -> Row("x", true)), "hr"), + Row(1, Map(Row(10) -> Row("z", null)), "sales"), + Row(2, Map(Row(20) -> Row("w", null)), "engineering"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == + "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") + assert(exception.getMessage.contains( + "Cannot write incompatible data for the table ``: " + + "Cannot find data for the output column `m`.`value`.`c2`.")) + } + } + sql(s"DROP TABLE IF EXISTS $tableNameAsString") + } } - sql(s"DROP TABLE IF EXISTS $tableNameAsString") } test("merge into with source missing fields in top-level struct") { - withTempView("source") { - // Target table has struct with 3 fields at top level - createAndInitTable( - s"""pk INT NOT NULL, - |s STRUCT<c1 INT, c2 STRING, c3 BOOLEAN>, - |dep STRING""".stripMargin, - """{ "pk": 0, "s": { "c1": 1, "c2": "a", "c3": true }, "dep": "sales"}""") + Seq(true, false).foreach { coerceNestedTypes => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> + coerceNestedTypes.toString) { + withTempView("source") { + // Target table has struct with 3 fields at top level + createAndInitTable( + s"""pk INT NOT NULL, + |s STRUCT<c1 INT, c2 STRING, c3 BOOLEAN>, + |dep STRING""".stripMargin, + """{ "pk": 0, "s": { "c1": 1, "c2": "a", "c3": true }, "dep": "sales"}""") - // Source table has struct with only 2 fields (c1, c2) - missing c3 - val sourceTableSchema = StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StringType)))), // missing c3 field - StructField("dep", StringType))) - val data = Seq( - Row(1, Row(10, "b"), "hr"), - Row(2, Row(20, "c"), "engineering") - ) - spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) - .createOrReplaceTempView("source") + // Source table has struct with only 2 fields (c1, c2) - missing c3 + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StringType)))), // missing c3 field + StructField("dep", StringType))) + val data = Seq( + Row(1, Row(10, "b"), "hr"), + Row(2, Row(20, "c"), "engineering") + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") - sql( - s"""MERGE INTO $tableNameAsString t - |USING source src - |ON t.pk = src.pk - |WHEN MATCHED THEN - | UPDATE SET * - |WHEN NOT MATCHED THEN - | INSERT * - |""".stripMargin) + val mergeStmt = + s"""MERGE INTO $tableNameAsString t + |USING source src + |ON t.pk = src.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin - // Missing field c3 should be filled with NULL - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(0, Row(1, "a", true), "sales"), - Row(1, Row(10, "b", null), "hr"), - Row(2, Row(20, "c", null), "engineering"))) + if (coerceNestedTypes) { + sql(mergeStmt) + // Missing field c3 should be filled with NULL + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(0, Row(1, "a", true), "sales"), + Row(1, Row(10, "b", null), "hr"), + Row(2, Row(20, "c", null), "engineering"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == + "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") + assert(exception.getMessage.contains( + "Cannot write incompatible data for the table ``: " + + "Cannot find data for the output column `s`.`c3`.")) + } + } + sql(s"DROP TABLE IF EXISTS $tableNameAsString") + } } - sql(s"DROP TABLE IF EXISTS $tableNameAsString") } test("merge with null struct") { @@ -4934,128 +5115,142 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase test("merge null struct with non-nullable nested field - source with missing " + "and extra nested fields") { - withSQLConf( - SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> "true") { - withTempView("source") { - // Target table has nested struct with NON-NULLABLE field b - createAndInitTable( - s"""pk INT NOT NULL, - |s STRUCT<c1 INT, c2 STRUCT<a INT, b STRING NOT NULL>>, - |dep STRING""".stripMargin, - """{ "pk": 0, "s": { "c1": 1, "c2": { "a": 10, "b": "x" } }, "dep": "sales" } - |{ "pk": 1, "s": { "c1": 2, "c2": { "a": 20, "b": "y" } }, "dep": "hr" }""" - .stripMargin) + Seq(true, false).foreach { withSchemaEvolution => + Seq(true, false).foreach { coerceNestedTypes => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> + coerceNestedTypes.toString) { + withTempView("source") { + createAndInitTable( + s"""pk INT NOT NULL, + |s STRUCT<c1 INT, c2 STRUCT<a INT, b STRING NOT NULL>>, + |dep STRING""".stripMargin, + """{ "pk": 0, "s": { "c1": 1, "c2": { "a": 10, "b": "x" } }, "dep": "sales" } + |{ "pk": 1, "s": { "c1": 2, "c2": { "a": 20, "b": "y" } }, "dep": "hr" }""" + .stripMargin) - // Source table has missing field 'b' and extra field 'c' in nested struct - val sourceTableSchema = StructType(Seq( - StructField("pk", IntegerType), - StructField("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StructType(Seq( - StructField("a", IntegerType), - // missing field 'b' (which is non-nullable in target) - StructField("c", StringType) // extra field 'c' - ))) - ))), - StructField("dep", StringType) - )) + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", IntegerType), + StructField("c", StringType) + ))) + ))), + StructField("dep", StringType) + )) - val data = Seq( - Row(1, null, "engineering"), - Row(2, null, "finance") - ) - spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) - .createOrReplaceTempView("source") + val data = Seq( + Row(1, null, "engineering"), + Row(2, null, "finance") + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") - val mergeStmt = - s"""MERGE WITH SCHEMA EVOLUTION - |INTO $tableNameAsString t USING source - |ON t.pk = source.pk - |WHEN MATCHED THEN - | UPDATE SET * - |WHEN NOT MATCHED THEN - | INSERT * - |""".stripMargin + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t USING source + |ON t.pk = source.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin - // All cases should fail due to non-nullable constraint violation - val exception = intercept[org.apache.spark.sql.AnalysisException] { - sql(mergeStmt) + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == + "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") + assert(exception.getMessage.contains( + "Cannot find data for the output column `s`.`c2`.`b`")) + } + sql(s"DROP TABLE IF EXISTS $tableNameAsString") } - assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") - assert(exception.getMessage.contains("Cannot write incompatible data for the table ``: " + - "Cannot find data for the output column `s`.`c2`.`b`.")) } - sql(s"DROP TABLE IF EXISTS $tableNameAsString") } } test("merge with null struct using default value") { - withTempView("source") { - // Target table has nested struct with a default value - sql( - s"""CREATE TABLE $tableNameAsString ( - | pk INT NOT NULL, - | s STRUCT<c1 INT, c2 STRUCT<a INT, b STRING>> DEFAULT - | named_struct('c1', 999, 'c2', named_struct('a', 999, 'b', 'default')), - | dep STRING) - |PARTITIONED BY (dep) - |""".stripMargin) + Seq(true, false).foreach { coerceNestedTypes => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> + coerceNestedTypes.toString) { + withTempView("source") { + sql( + s"""CREATE TABLE $tableNameAsString ( + | pk INT NOT NULL, + | s STRUCT<c1 INT, c2 STRUCT<a INT, b STRING>> DEFAULT + | named_struct('c1', 999, 'c2', named_struct('a', 999, 'b', 'default')), + | dep STRING) + |PARTITIONED BY (dep) + |""".stripMargin) - // Insert initial data using DataFrame API - val initialSchema = StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("s", StructType(Seq( - StructField("c1", IntegerType), - 
StructField("c2", StructType(Seq( - StructField("a", IntegerType), - StructField("b", StringType) - ))) - ))), - StructField("dep", StringType) - )) - val initialData = Seq( - Row(0, Row(1, Row(10, "x")), "sales"), - Row(1, Row(2, Row(20, "y")), "hr") - ) - spark.createDataFrame(spark.sparkContext.parallelize(initialData), initialSchema) - .writeTo(tableNameAsString).append() + val initialSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", IntegerType), + StructField("b", StringType) + ))) + ))), + StructField("dep", StringType) + )) + val initialData = Seq( + Row(0, Row(1, Row(10, "x")), "sales"), + Row(1, Row(2, Row(20, "y")), "hr") + ) + spark.createDataFrame(spark.sparkContext.parallelize(initialData), initialSchema) + .writeTo(tableNameAsString).append() - // Source table has null for the nested struct - val sourceTableSchema = StructType(Seq( - StructField("pk", IntegerType), - StructField("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StructType(Seq( - StructField("a", IntegerType) - // missing field 'b' - ))) - ))), - StructField("dep", StringType) - )) + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", IntegerType) + ))) + ))), + StructField("dep", StringType) + )) - val data = Seq( - Row(1, null, "engineering"), - Row(2, null, "finance") - ) - spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) - .createOrReplaceTempView("source") + val data = Seq( + Row(1, null, "engineering"), + Row(2, null, "finance") + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") - sql( - s"""MERGE INTO $tableNameAsString t USING source - |ON t.pk = source.pk - |WHEN MATCHED THEN - | UPDATE SET * - |WHEN NOT MATCHED THEN - | INSERT * - |""".stripMargin) - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(0, Row(1, Row(10, "x")), "sales"), - Row(1, null, "engineering"), - Row(2, null, "finance"))) + val mergeStmt = + s"""MERGE INTO $tableNameAsString t USING source + |ON t.pk = source.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin + + if (coerceNestedTypes) { + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(0, Row(1, Row(10, "x")), "sales"), + Row(1, null, "engineering"), + Row(2, null, "finance"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") + assert(exception.getMessage.contains( + "Cannot find data for the output column `s`.`c2`.`b`")) + } + } + sql(s"DROP TABLE IF EXISTS $tableNameAsString") + } } - sql(s"DROP TABLE IF EXISTS $tableNameAsString") } @@ -5170,8 +5365,8 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Seq( - Row(1, Row(10, Row(20, true)), "sales"), - Row(2, Row(20, Row(30, false)), "engineering"))) + Row(1, Row(10, Row(20, null)), "sales"), + Row(2, Row(20, Row(30, null)), "engineering"))) } else { val exception = intercept[Exception] { sql(mergeStmt) @@ -5254,222 +5449,46 @@ abstract class MergeIntoTableSuiteBase extends 
RowLevelOperationSuiteBase .writeTo(tableNameAsString).append() val sourceIdent = Identifier.of(Array("ns1"), "source_table") - val columns = Array( - Column.create("pk", IntegerType, false), - Column.create("salary", IntegerType), - Column.create("dep", StringType), - Column.create("new_col", IntegerType)) - val tableInfo = new TableInfo.Builder() - .withColumns(columns) - .withProperties(extraTableProps) - .build() - catalog.createTable(sourceIdent, tableInfo) - - sql(s"INSERT INTO $sourceTable VALUES (1, 101, 'support', 1)," + - s"(3, 301, 'support', 3), (4, 401, 'finance', 4)") - - val mergeBuilder = spark.table(sourceTable) - .mergeInto(tableNameAsString, - $"source_table.pk" === col(tableNameAsString + ".pk")) - .whenMatched() - .updateAll() - .whenNotMatched() - .insertAll() - - if (withSchemaEvolution) { - mergeBuilder.withSchemaEvolution().merge() - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(1, 101, "support", 1), - Row(2, 200, "software", null), - Row(3, 301, "support", 3), - Row(4, 401, "finance", 4))) - } else { - mergeBuilder.merge() - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(1, 101, "support"), - Row(2, 200, "software"), - Row(3, 301, "support"), - Row(4, 401, "finance"))) - } - - sql(s"DROP TABLE $tableNameAsString") - } - } - } - - test("merge schema evolution new column with set explicit column using dataframe API") { - Seq(true, false).foreach { withSchemaEvolution => - val sourceTable = "cat.ns1.source_table" - withTable(sourceTable) { - sql(s"CREATE TABLE $tableNameAsString (pk INT NOT NULL, salary INT, dep STRING)") - - val targetData = Seq( - Row(1, 100, "hr"), - Row(2, 200, "software"), - Row(3, 300, "hr"), - Row(4, 400, "marketing"), - Row(5, 500, "executive") - ) - val targetSchema = StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("salary", IntegerType), - StructField("dep", StringType) - )) - spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) - .writeTo(tableNameAsString).append() - - val sourceIdent = Identifier.of(Array("ns1"), "source_table") - val columns = Array( - Column.create("pk", IntegerType, false), - Column.create("salary", IntegerType), - Column.create("dep", StringType), - Column.create("active", BooleanType)) - val tableInfo = new TableInfo.Builder() - .withColumns(columns) - .withProperties(extraTableProps) - .build() - catalog.createTable(sourceIdent, tableInfo) - - sql(s"INSERT INTO $sourceTable VALUES (4, 150, 'dummy', true)," + - s"(5, 250, 'dummy', true), (6, 350, 'dummy', false)") - - val mergeBuilder = spark.table(sourceTable) - .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk")) - .whenMatched() - .update(Map("dep" -> lit("software"), "active" -> col("source_table.active"))) - .whenNotMatched() - .insert(Map("pk" -> col("source_table.pk"), "salary" -> lit(0), - "dep" -> col("source_table.dep"), "active" -> col("source_table.active"))) - - if (withSchemaEvolution) { - mergeBuilder.withSchemaEvolution().merge() - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(1, 100, "hr", null), - Row(2, 200, "software", null), - Row(3, 300, "hr", null), - Row(4, 400, "software", true), - Row(5, 500, "software", true), - Row(6, 0, "dummy", false))) - } else { - val e = intercept[org.apache.spark.sql.AnalysisException] { - mergeBuilder.merge() - } - assert(e.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION") - assert(e.getMessage.contains("A column, variable, or function parameter 
with name " + - "`active` cannot be resolved")) - } - - sql(s"DROP TABLE $tableNameAsString") - } - } - } - - test("merge schema evolution add column with nested struct and set explicit columns " + - "using dataframe API") { - Seq(true, false).foreach { withSchemaEvolution => - val sourceTable = "cat.ns1.source_table" - withTable(sourceTable) { - sql( - s"""CREATE TABLE $tableNameAsString ( - |pk INT NOT NULL, - |s STRUCT, m: MAP>>, - |dep STRING)""".stripMargin) - - val targetData = Seq( - Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr") - ) - val targetSchema = StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StructType(Seq( - StructField("a", ArrayType(IntegerType)), - StructField("m", MapType(StringType, StringType)) - ))) - ))), - StructField("dep", StringType) - )) - spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) - .writeTo(tableNameAsString).append() - - val sourceIdent = Identifier.of(Array("ns1"), "source_table") - val columns = Array( - Column.create("pk", IntegerType, false), - Column.create("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StructType(Seq( - StructField("a", ArrayType(IntegerType)), - StructField("m", MapType(StringType, StringType)), - StructField("c3", BooleanType) // new column - ))) - ))), - Column.create("dep", StringType)) - val tableInfo = new TableInfo.Builder() - .withColumns(columns) - .withProperties(extraTableProps) - .build() - catalog.createTable(sourceIdent, tableInfo) - - val data = Seq( - Row(1, Row(10, Row(Array(3, 4), Map("c" -> "d"), false)), "sales"), - Row(2, Row(20, Row(Array(4, 5), Map("e" -> "f"), true)), "engineering") - ) - val sourceTableSchema = StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StructType(Seq( - StructField("a", ArrayType(IntegerType)), - StructField("m", MapType(StringType, StringType)), - StructField("c3", BooleanType) - ))) - ))), - StructField("dep", StringType) - )) - spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) - .createOrReplaceTempView("source_temp") + val columns = Array( + Column.create("pk", IntegerType, false), + Column.create("salary", IntegerType), + Column.create("dep", StringType), + Column.create("new_col", IntegerType)) + val tableInfo = new TableInfo.Builder() + .withColumns(columns) + .withProperties(extraTableProps) + .build() + catalog.createTable(sourceIdent, tableInfo) - sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp") + sql(s"INSERT INTO $sourceTable VALUES (1, 101, 'support', 1)," + + s"(3, 301, 'support', 3), (4, 401, 'finance', 4)") val mergeBuilder = spark.table(sourceTable) - .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk")) + .mergeInto(tableNameAsString, + $"source_table.pk" === col(tableNameAsString + ".pk")) .whenMatched() - .update(Map( - "s.c1" -> lit(-1), - "s.c2.m" -> map(lit("k"), lit("v")), - "s.c2.a" -> array(lit(-1)), - "s.c2.c3" -> col("source_table.s.c2.c3"))) + .updateAll() .whenNotMatched() - .insert(Map( - "pk" -> col("source_table.pk"), - "s" -> struct( - col("source_table.s.c1").as("c1"), - struct( - col("source_table.s.c2.a").as("a"), - map(lit("g"), lit("h")).as("m"), - lit(true).as("c3") - ).as("c2") - ), - "dep" -> col("source_table.dep"))) + .insertAll() if (withSchemaEvolution) { 
mergeBuilder.withSchemaEvolution().merge() checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), - Seq(Row(1, Row(-1, Row(Seq(-1), Map("k" -> "v"), false)), "hr"), - Row(2, Row(20, Row(Seq(4, 5), Map("g" -> "h"), true)), "engineering"))) + Seq( + Row(1, 101, "support", 1), + Row(2, 200, "software", null), + Row(3, 301, "support", 3), + Row(4, 401, "finance", 4))) } else { - val exception = intercept[org.apache.spark.sql.AnalysisException] { - mergeBuilder.merge() - } - assert(exception.errorClass.get == "FIELD_NOT_FOUND") - assert(exception.getMessage.contains("No such struct field `c3` in `a`, `m`. ")) + mergeBuilder.merge() + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 101, "support"), + Row(2, 200, "software"), + Row(3, 301, "support"), + Row(4, 401, "finance"))) } sql(s"DROP TABLE $tableNameAsString") @@ -5477,29 +5496,22 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase } } - test("merge schema evolution add column with nested struct and set all columns " + - "using dataframe API") { + test("merge schema evolution new column with set explicit column using dataframe API") { Seq(true, false).foreach { withSchemaEvolution => val sourceTable = "cat.ns1.source_table" withTable(sourceTable) { - sql( - s"""CREATE TABLE $tableNameAsString ( - |pk INT NOT NULL, - |s STRUCT, m: MAP>>, - |dep STRING)""".stripMargin) + sql(s"CREATE TABLE $tableNameAsString (pk INT NOT NULL, salary INT, dep STRING)") val targetData = Seq( - Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr") + Row(1, 100, "hr"), + Row(2, 200, "software"), + Row(3, 300, "hr"), + Row(4, 400, "marketing"), + Row(5, 500, "executive") ) val targetSchema = StructType(Seq( StructField("pk", IntegerType, nullable = false), - StructField("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StructType(Seq( - StructField("a", ArrayType(IntegerType)), - StructField("m", MapType(StringType, StringType)) - ))) - ))), + StructField("salary", IntegerType), StructField("dep", StringType) )) spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) @@ -5508,62 +5520,44 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase val sourceIdent = Identifier.of(Array("ns1"), "source_table") val columns = Array( Column.create("pk", IntegerType, false), - Column.create("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StructType(Seq( - StructField("a", ArrayType(IntegerType)), - StructField("m", MapType(StringType, StringType)), - StructField("c3", BooleanType) // new column - ))) - ))), - Column.create("dep", StringType)) + Column.create("salary", IntegerType), + Column.create("dep", StringType), + Column.create("active", BooleanType)) val tableInfo = new TableInfo.Builder() .withColumns(columns) .withProperties(extraTableProps) .build() catalog.createTable(sourceIdent, tableInfo) - val data = Seq( - Row(1, Row(10, Row(Array(3, 4), Map("c" -> "d"), false)), "sales"), - Row(2, Row(20, Row(Array(4, 5), Map("e" -> "f"), true)), "engineering") - ) - val sourceTableSchema = StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StructType(Seq( - StructField("a", ArrayType(IntegerType)), - StructField("m", MapType(StringType, StringType)), - StructField("c3", BooleanType) - ))) - ))), - StructField("dep", StringType) - )) - spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) - 
.createOrReplaceTempView("source_temp") - - sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp") + sql(s"INSERT INTO $sourceTable VALUES (4, 150, 'dummy', true)," + + s"(5, 250, 'dummy', true), (6, 350, 'dummy', false)") val mergeBuilder = spark.table(sourceTable) .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk")) .whenMatched() - .updateAll() + .update(Map("dep" -> lit("software"), "active" -> col("source_table.active"))) .whenNotMatched() - .insertAll() + .insert(Map("pk" -> col("source_table.pk"), "salary" -> lit(0), + "dep" -> col("source_table.dep"), "active" -> col("source_table.active"))) if (withSchemaEvolution) { mergeBuilder.withSchemaEvolution().merge() checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), - Seq(Row(1, Row(10, Row(Seq(3, 4), Map("c" -> "d"), false)), "sales"), - Row(2, Row(20, Row(Seq(4, 5), Map("e" -> "f"), true)), "engineering"))) + Seq( + Row(1, 100, "hr", null), + Row(2, 200, "software", null), + Row(3, 300, "hr", null), + Row(4, 400, "software", true), + Row(5, 500, "software", true), + Row(6, 0, "dummy", false))) } else { - val exception = intercept[org.apache.spark.sql.AnalysisException] { + val e = intercept[org.apache.spark.sql.AnalysisException] { mergeBuilder.merge() } - assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") - assert(exception.getMessage.contains( - "Cannot write extra fields `c3` to the struct `s`.`c2`")) + assert(e.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + assert(e.getMessage.contains("A column, variable, or function parameter with name " + + "`active` cannot be resolved")) } sql(s"DROP TABLE $tableNameAsString") @@ -5571,8 +5565,8 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase } } - test("merge schema evolution replace column with nested struct and " + - "set explicit columns using dataframe API") { + test("merge schema evolution add column with nested struct and set explicit columns " + + "using dataframe API") { Seq(true, false).foreach { withSchemaEvolution => val sourceTable = "cat.ns1.source_table" withTable(sourceTable) { @@ -5605,7 +5599,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase Column.create("s", StructType(Seq( StructField("c1", IntegerType), StructField("c2", StructType(Seq( - // removed column 'a' + StructField("a", ArrayType(IntegerType)), StructField("m", MapType(StringType, StringType)), StructField("c3", BooleanType) // new column ))) @@ -5618,14 +5612,15 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase catalog.createTable(sourceIdent, tableInfo) val data = Seq( - Row(1, Row(10, Row(Map("c" -> "d"), false)), "sales"), - Row(2, Row(20, Row(Map("e" -> "f"), true)), "engineering") + Row(1, Row(10, Row(Array(3, 4), Map("c" -> "d"), false)), "sales"), + Row(2, Row(20, Row(Array(4, 5), Map("e" -> "f"), true)), "engineering") ) val sourceTableSchema = StructType(Seq( StructField("pk", IntegerType, nullable = false), StructField("s", StructType(Seq( StructField("c1", IntegerType), StructField("c2", StructType(Seq( + StructField("a", ArrayType(IntegerType)), StructField("m", MapType(StringType, StringType)), StructField("c3", BooleanType) ))) @@ -5651,7 +5646,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase "s" -> struct( col("source_table.s.c1").as("c1"), struct( - array(lit(-2)).as("a"), + col("source_table.s.c2.a").as("a"), map(lit("g"), lit("h")).as("m"), lit(true).as("c3") ).as("c2") @@ -5663,7 +5658,7 @@ 
abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Seq(Row(1, Row(-1, Row(Seq(-1), Map("k" -> "v"), false)), "hr"), - Row(2, Row(20, Row(Seq(-2), Map("g" -> "h"), true)), "engineering"))) + Row(2, Row(20, Row(Seq(4, 5), Map("g" -> "h"), true)), "engineering"))) } else { val exception = intercept[org.apache.spark.sql.AnalysisException] { mergeBuilder.merge() @@ -5677,7 +5672,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase } } - test("merge schema evolution replace column with nested struct and set all columns " + + test("merge schema evolution add column with nested struct and set all columns " + "using dataframe API") { Seq(true, false).foreach { withSchemaEvolution => val sourceTable = "cat.ns1.source_table" @@ -5686,26 +5681,24 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase s"""CREATE TABLE $tableNameAsString ( |pk INT NOT NULL, |s STRUCT, m: MAP>>, - |dep STRING) - |PARTITIONED BY (dep) - |""".stripMargin) + |dep STRING)""".stripMargin) - val tableSchema = StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StructType(Seq( - StructField("a", ArrayType(IntegerType)), - StructField("m", MapType(StringType, StringType)) - ))) - ))), - StructField("dep", StringType) - )) - val targetData = Seq( - Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr") - ) - spark.createDataFrame(spark.sparkContext.parallelize(targetData), tableSchema) - .coalesce(1).writeTo(tableNameAsString).append() + val targetData = Seq( + Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr") + ) + val targetSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", ArrayType(IntegerType)), + StructField("m", MapType(StringType, StringType)) + ))) + ))), + StructField("dep", StringType) + )) + spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) + .writeTo(tableNameAsString).append() val sourceIdent = Identifier.of(Array("ns1"), "source_table") val columns = Array( @@ -5713,7 +5706,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase Column.create("s", StructType(Seq( StructField("c1", IntegerType), StructField("c2", StructType(Seq( - // missing column 'a' + StructField("a", ArrayType(IntegerType)), StructField("m", MapType(StringType, StringType)), StructField("c3", BooleanType) // new column ))) @@ -5725,22 +5718,23 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase .build() catalog.createTable(sourceIdent, tableInfo) - val sourceData = Seq( - Row(1, Row(10, Row(Map("c" -> "d"), false)), "sales"), - Row(2, Row(20, Row(Map("e" -> "f"), true)), "engineering") + val data = Seq( + Row(1, Row(10, Row(Array(3, 4), Map("c" -> "d"), false)), "sales"), + Row(2, Row(20, Row(Array(4, 5), Map("e" -> "f"), true)), "engineering") ) val sourceTableSchema = StructType(Seq( StructField("pk", IntegerType, nullable = false), StructField("s", StructType(Seq( StructField("c1", IntegerType), StructField("c2", StructType(Seq( + StructField("a", ArrayType(IntegerType)), StructField("m", MapType(StringType, StringType)), StructField("c3", BooleanType) ))) ))), StructField("dep", StringType) )) - spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema) + 
spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) .createOrReplaceTempView("source_temp") sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp") @@ -5748,105 +5742,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase val mergeBuilder = spark.table(sourceTable) .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk")) .whenMatched() - .updateAll() - .whenNotMatched() - .insertAll() - - if (withSchemaEvolution) { - mergeBuilder.withSchemaEvolution().merge() - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(1, Row(10, Row(Seq(1, 2), Map("c" -> "d"), false)), "sales"), - Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering"))) - } else { - val exception = intercept[org.apache.spark.sql.AnalysisException] { - mergeBuilder.merge() - } - assert(exception.errorClass.get == - "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") - assert(exception.getMessage.contains( - "Cannot write extra fields `c3` to the struct `s`.`c2`")) - } - - sql(s"DROP TABLE $tableNameAsString") - } - } - } - - test("merge schema evolution replace column with nested struct and " + - "update top level struct using dataframe API") { - Seq(true, false).foreach { withSchemaEvolution => - val sourceTable = "cat.ns1.source_table" - withTable(sourceTable) { - sql( - s"""CREATE TABLE $tableNameAsString ( - |pk INT NOT NULL, - |s STRUCT, m: MAP>>, - |dep STRING) - |PARTITIONED BY (dep) - |""".stripMargin) - - val tableSchema = StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StructType(Seq( - StructField("a", ArrayType(IntegerType)), - StructField("m", MapType(StringType, StringType)) - ))) - ))), - StructField("dep", StringType) - )) - val targetData = Seq( - Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr") - ) - spark.createDataFrame(spark.sparkContext.parallelize(targetData), tableSchema) - .coalesce(1).writeTo(tableNameAsString).append() - - // Create source table - val sourceIdent = Identifier.of(Array("ns1"), "source_table") - val columns = Array( - Column.create("pk", IntegerType, false), - Column.create("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StructType(Seq( - // missing column 'a' - StructField("m", MapType(StringType, StringType)), - StructField("c3", BooleanType) // new column - ))) - ))), - Column.create("dep", StringType)) - val tableInfo = new TableInfo.Builder() - .withColumns(columns) - .withProperties(extraTableProps) - .build() - catalog.createTable(sourceIdent, tableInfo) - - val sourceData = Seq( - Row(1, Row(10, Row(Map("c" -> "d"), false)), "sales"), - Row(2, Row(20, Row(Map("e" -> "f"), true)), "engineering") - ) - val sourceTableSchema = StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StructType(Seq( - StructField("m", MapType(StringType, StringType)), - StructField("c3", BooleanType) - ))) - ))), - StructField("dep", StringType) - )) - spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema) - .createOrReplaceTempView("source_temp") - - sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp") - - val mergeBuilder = spark.table(sourceTable) - .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk")) - .whenMatched() - .update(Map("s" -> col("source_table.s"))) + .updateAll() 
.whenNotMatched() .insertAll() @@ -5854,15 +5750,13 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase mergeBuilder.withSchemaEvolution().merge() checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "hr"), - Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering"))) + Seq(Row(1, Row(10, Row(Seq(3, 4), Map("c" -> "d"), false)), "sales"), + Row(2, Row(20, Row(Seq(4, 5), Map("e" -> "f"), true)), "engineering"))) } else { val exception = intercept[org.apache.spark.sql.AnalysisException] { mergeBuilder.merge() } - assert(exception.errorClass.get == - "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") + assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") assert(exception.getMessage.contains( "Cannot write extra fields `c3` to the struct `s`.`c2`")) } @@ -5872,6 +5766,339 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase } } + test("merge schema evolution replace column with nested struct and " + + "set explicit columns using dataframe API") { + Seq(true, false).foreach { withSchemaEvolution => + Seq(true, false).foreach { coerceNestedTypes => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> + coerceNestedTypes.toString) { + val sourceTable = "cat.ns1.source_table" + withTable(sourceTable) { + sql( + s"""CREATE TABLE $tableNameAsString ( + |pk INT NOT NULL, + |s STRUCT, m: MAP>>, + |dep STRING)""".stripMargin) + + val targetData = Seq( + Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr") + ) + val targetSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", ArrayType(IntegerType)), + StructField("m", MapType(StringType, StringType)) + ))) + ))), + StructField("dep", StringType) + )) + spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) + .writeTo(tableNameAsString).append() + + val sourceIdent = Identifier.of(Array("ns1"), "source_table") + val columns = Array( + Column.create("pk", IntegerType, false), + Column.create("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + // removed column 'a' + StructField("m", MapType(StringType, StringType)), + StructField("c3", BooleanType) // new column + ))) + ))), + Column.create("dep", StringType)) + val tableInfo = new TableInfo.Builder() + .withColumns(columns) + .withProperties(extraTableProps) + .build() + catalog.createTable(sourceIdent, tableInfo) + + val data = Seq( + Row(1, Row(10, Row(Map("c" -> "d"), false)), "sales"), + Row(2, Row(20, Row(Map("e" -> "f"), true)), "engineering") + ) + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("m", MapType(StringType, StringType)), + StructField("c3", BooleanType) + ))) + ))), + StructField("dep", StringType) + )) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source_temp") + + sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp") + + val mergeBuilder = spark.table(sourceTable) + .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk")) + .whenMatched() + .update(Map( + "s.c1" -> lit(-1), + "s.c2.m" -> map(lit("k"), lit("v")), + "s.c2.a" -> array(lit(-1)), + "s.c2.c3" 
-> col("source_table.s.c2.c3"))) + .whenNotMatched() + .insert(Map( + "pk" -> col("source_table.pk"), + "s" -> struct( + col("source_table.s.c1").as("c1"), + struct( + array(lit(-2)).as("a"), + map(lit("g"), lit("h")).as("m"), + lit(true).as("c3") + ).as("c2") + ), + "dep" -> col("source_table.dep"))) + + if (coerceNestedTypes && withSchemaEvolution) { + mergeBuilder.withSchemaEvolution().merge() + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(1, Row(-1, Row(Seq(-1), Map("k" -> "v"), false)), "hr"), + Row(2, Row(20, Row(Seq(-2), Map("g" -> "h"), true)), "engineering"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + mergeBuilder.merge() + } + assert(exception.errorClass.get == "FIELD_NOT_FOUND") + assert(exception.getMessage.contains("No such struct field `c3` in `a`, `m`. ")) + } + } + sql(s"DROP TABLE $tableNameAsString") + } + } + } + } + + test("merge schema evolution replace column with nested struct and set all columns " + + "using dataframe API") { + Seq(true, false).foreach { withSchemaEvolution => + Seq(true, false).foreach { coerceNestedTypes => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> + coerceNestedTypes.toString) { + val sourceTable = "cat.ns1.source_table" + withTable(sourceTable) { + sql( + s"""CREATE TABLE $tableNameAsString ( + |pk INT NOT NULL, + |s STRUCT, m: MAP>>, + |dep STRING) + |PARTITIONED BY (dep) + |""".stripMargin) + + val tableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", ArrayType(IntegerType)), + StructField("m", MapType(StringType, StringType)) + ))) + ))), + StructField("dep", StringType) + )) + val targetData = Seq( + Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr") + ) + spark.createDataFrame(spark.sparkContext.parallelize(targetData), tableSchema) + .coalesce(1).writeTo(tableNameAsString).append() + + val sourceIdent = Identifier.of(Array("ns1"), "source_table") + val columns = Array( + Column.create("pk", IntegerType, false), + Column.create("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + // missing column 'a' + StructField("m", MapType(StringType, StringType)), + StructField("c3", BooleanType) // new column + ))) + ))), + Column.create("dep", StringType)) + val tableInfo = new TableInfo.Builder() + .withColumns(columns) + .withProperties(extraTableProps) + .build() + catalog.createTable(sourceIdent, tableInfo) + + val sourceData = Seq( + Row(1, Row(10, Row(Map("c" -> "d"), false)), "sales"), + Row(2, Row(20, Row(Map("e" -> "f"), true)), "engineering") + ) + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("m", MapType(StringType, StringType)), + StructField("c3", BooleanType) + ))) + ))), + StructField("dep", StringType) + )) + spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema) + .createOrReplaceTempView("source_temp") + + sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp") + + val mergeBuilder = spark.table(sourceTable) + .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk")) + .whenMatched() + .updateAll() + .whenNotMatched() + .insertAll() + + if (coerceNestedTypes) { + if (withSchemaEvolution) { + 
mergeBuilder.withSchemaEvolution().merge() + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "sales"), + Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + mergeBuilder.merge() + } + assert(exception.errorClass.get == + "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") + assert(exception.getMessage.contains( + "Cannot write extra fields `c3` to the struct `s`.`c2`")) + } + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + mergeBuilder.merge() + } + assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") + assert(exception.getMessage.contains( + "Cannot find data for the output column `s`.`c2`.`a`")) + } + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + } + } + + test("merge schema evolution replace column with nested struct and " + + "update top level struct using dataframe API") { + Seq(true, false).foreach { withSchemaEvolution => + Seq(true, false).foreach { coerceNestedTypes => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> + coerceNestedTypes.toString) { + val sourceTable = "cat.ns1.source_table" + withTable(sourceTable) { + sql( + s"""CREATE TABLE $tableNameAsString ( + |pk INT NOT NULL, + |s STRUCT, m: MAP>>, + |dep STRING) + |PARTITIONED BY (dep) + |""".stripMargin) + + val tableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", ArrayType(IntegerType)), + StructField("m", MapType(StringType, StringType)) + ))) + ))), + StructField("dep", StringType) + )) + val targetData = Seq( + Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr") + ) + spark.createDataFrame(spark.sparkContext.parallelize(targetData), tableSchema) + .coalesce(1).writeTo(tableNameAsString).append() + + // Create source table + val sourceIdent = Identifier.of(Array("ns1"), "source_table") + val columns = Array( + Column.create("pk", IntegerType, false), + Column.create("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + // missing column 'a' + StructField("m", MapType(StringType, StringType)), + StructField("c3", BooleanType) // new column + ))) + ))), + Column.create("dep", StringType)) + val tableInfo = new TableInfo.Builder() + .withColumns(columns) + .withProperties(extraTableProps) + .build() + catalog.createTable(sourceIdent, tableInfo) + + val sourceData = Seq( + Row(1, Row(10, Row(Map("c" -> "d"), false)), "sales"), + Row(2, Row(20, Row(Map("e" -> "f"), true)), "engineering") + ) + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("m", MapType(StringType, StringType)), + StructField("c3", BooleanType) + ))) + ))), + StructField("dep", StringType) + )) + spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema) + .createOrReplaceTempView("source_temp") + + sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp") + + val mergeBuilder = spark.table(sourceTable) + .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk")) + .whenMatched() + .update(Map("s" -> col("source_table.s"))) + .whenNotMatched() + .insertAll() + + if (coerceNestedTypes) { + if 
(withSchemaEvolution) { + mergeBuilder.withSchemaEvolution().merge() + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "hr"), + Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + mergeBuilder.merge() + } + assert(exception.errorClass.get == + "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") + assert(exception.getMessage.contains( + "Cannot write extra fields `c3` to the struct `s`.`c2`")) + } + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + mergeBuilder.merge() + } + assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") + assert(exception.getMessage.contains( + "Cannot find data for the output column `s`.`c2`.`a`")) + } + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + } + } + test("merge schema evolution should not evolve referencing new column " + "via transform using dataframe API") { Seq(true, false).foreach { withSchemaEvolution => @@ -5933,76 +6160,98 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase } test("merge into with source missing fields in top-level struct using dataframe API") { - val sourceTable = "cat.ns1.source_table" - withTable(sourceTable) { - // Target table has struct with 3 fields at top level - sql( - s"""CREATE TABLE $tableNameAsString ( - |pk INT NOT NULL, - |s STRUCT, - |dep STRING)""".stripMargin) - - val targetData = Seq( - Row(0, Row(1, "a", true), "sales") - ) - val targetSchema = StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StringType), - StructField("c3", BooleanType) - ))), - StructField("dep", StringType) - )) - spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) - .writeTo(tableNameAsString).append() + Seq(true, false).foreach { coerceNestedTypes => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> + coerceNestedTypes.toString) { + val sourceTable = "cat.ns1.source_table" + withTable(sourceTable) { + // Target table has struct with 3 fields at top level + sql( + s"""CREATE TABLE $tableNameAsString ( + |pk INT NOT NULL, + |s STRUCT, + |dep STRING)""".stripMargin) - // Create source table with struct having only 2 fields (c1, c2) - missing c3 - val sourceIdent = Identifier.of(Array("ns1"), "source_table") - val columns = Array( - Column.create("pk", IntegerType, false), - Column.create("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StringType)))), // missing c3 field - Column.create("dep", StringType)) - val tableInfo = new TableInfo.Builder() - .withColumns(columns) - .withProperties(extraTableProps) - .build() - catalog.createTable(sourceIdent, tableInfo) + val targetData = Seq( + Row(0, Row(1, "a", true), "sales") + ) + val targetSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StringType), + StructField("c3", BooleanType) + ))), + StructField("dep", StringType) + )) + spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) + .writeTo(tableNameAsString).append() - val data = Seq( - Row(1, Row(10, "b"), "hr"), - Row(2, Row(20, "c"), "engineering") - ) - val sourceTableSchema = StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("s", 
StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StringType)))), - StructField("dep", StringType))) - spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) - .createOrReplaceTempView("source_temp") + // Create source table with struct having only 2 fields (c1, c2) - missing c3 + val sourceIdent = Identifier.of(Array("ns1"), "source_table") + val columns = Array( + Column.create("pk", IntegerType, false), + Column.create("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StringType)))), // missing c3 field + Column.create("dep", StringType)) + val tableInfo = new TableInfo.Builder() + .withColumns(columns) + .withProperties(extraTableProps) + .build() + catalog.createTable(sourceIdent, tableInfo) - sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp") + val data = Seq( + Row(1, Row(10, "b"), "hr"), + Row(2, Row(20, "c"), "engineering") + ) + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StringType)))), + StructField("dep", StringType))) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source_temp") - spark.table(sourceTable) - .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk")) - .whenMatched() - .updateAll() - .whenNotMatched() - .insertAll() - .merge() + sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp") - // Missing field c3 should be filled with NULL - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(0, Row(1, "a", true), "sales"), - Row(1, Row(10, "b", null), "hr"), - Row(2, Row(20, "c", null), "engineering"))) + if (coerceNestedTypes) { + spark.table(sourceTable) + .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk")) + .whenMatched() + .updateAll() + .whenNotMatched() + .insertAll() + .merge() + + // Missing field c3 should be filled with NULL + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(0, Row(1, "a", true), "sales"), + Row(1, Row(10, "b", null), "hr"), + Row(2, Row(20, "c", null), "engineering"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + spark.table(sourceTable) + .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk")) + .whenMatched() + .updateAll() + .whenNotMatched() + .insertAll() + .merge() + } + assert(exception.errorClass.get == + "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") + assert(exception.getMessage.contains( + "Cannot write incompatible data for the table ``: " + + "Cannot find data for the output column `s`.`c3`.")) + } - sql(s"DROP TABLE $tableNameAsString") + sql(s"DROP TABLE $tableNameAsString") + } + } } }
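
Reviewer note: every test touched by this patch follows the same toggle pattern. A minimal
sketch of the behavior being pinned down, assuming a target table `t` with
`s STRUCT<c1: INT, c2: STRING, c3: BOOLEAN>` and a source view `src` whose struct `s` lacks
`c3` (the table and view names here are illustrative, not part of the patch):

    withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> "true") {
      // Coercion on: the missing nested field s.c3 is filled with NULL.
      sql("MERGE INTO t USING src ON t.pk = src.pk WHEN MATCHED THEN UPDATE SET *")
    }
    withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> "false") {
      // Coercion off: analysis fails instead of silently NULL-filling.
      val e = intercept[org.apache.spark.sql.AnalysisException] {
        sql("MERGE INTO t USING src ON t.pk = src.pk WHEN MATCHED THEN UPDATE SET *")
      }
      assert(e.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
    }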