diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 5dbcf4a915cbf..5cc21eeaeaa94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -22,7 +22,7 @@ import java.util.Locale import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, RowOrdering} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, RowOrdering} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.command.DDLUtils @@ -178,7 +178,8 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi c.copy( tableDesc = existingTable, - query = Some(newQuery)) + query = Some(DDLPreprocessingUtils.castAndRenameQueryOutput( + newQuery, existingTable.schema.toAttributes, conf))) // Here we normalize partition, bucket and sort column names, w.r.t. the case sensitivity // config, and do various checks: @@ -316,7 +317,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi * table. It also does data type casting and field renaming, to make sure that the columns to be * inserted have the correct data type and fields have the correct names. */ -case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { +case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { private def preprocess( insert: InsertIntoTable, tblName: String, @@ -336,6 +337,8 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] wit s"including ${staticPartCols.size} partition column(s) having constant value(s).") } + val newQuery = DDLPreprocessingUtils.castAndRenameQueryOutput( + insert.query, expectedColumns, conf) if (normalizedPartSpec.nonEmpty) { if (normalizedPartSpec.size != partColNames.length) { throw new AnalysisException( @@ -346,37 +349,11 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] wit """.stripMargin) } - castAndRenameChildOutput(insert.copy(partition = normalizedPartSpec), expectedColumns) + insert.copy(query = newQuery, partition = normalizedPartSpec) } else { // All partition columns are dynamic because the InsertIntoTable command does // not explicitly specify partitioning columns. - castAndRenameChildOutput(insert, expectedColumns) - .copy(partition = partColNames.map(_ -> None).toMap) - } - } - - private def castAndRenameChildOutput( - insert: InsertIntoTable, - expectedOutput: Seq[Attribute]): InsertIntoTable = { - val newChildOutput = expectedOutput.zip(insert.query.output).map { - case (expected, actual) => - if (expected.dataType.sameType(actual.dataType) && - expected.name == actual.name && - expected.metadata == actual.metadata) { - actual - } else { - // Renaming is needed for handling the following cases like - // 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2 - // 2) Target tables have column metadata - Alias(cast(actual, expected.dataType), expected.name)( - explicitMetadata = Option(expected.metadata)) - } - } - - if (newChildOutput == insert.query.output) { - insert - } else { - insert.copy(query = Project(newChildOutput, insert.query)) + insert.copy(query = newQuery, partition = partColNames.map(_ -> None).toMap) } } @@ -491,3 +468,36 @@ object PreWriteCheck extends (LogicalPlan => Unit) { } } } + +object DDLPreprocessingUtils { + + /** + * Adjusts the name and data type of the input query output columns, to match the expectation. + */ + def castAndRenameQueryOutput( + query: LogicalPlan, + expectedOutput: Seq[Attribute], + conf: SQLConf): LogicalPlan = { + val newChildOutput = expectedOutput.zip(query.output).map { + case (expected, actual) => + if (expected.dataType.sameType(actual.dataType) && + expected.name == actual.name && + expected.metadata == actual.metadata) { + actual + } else { + // Renaming is needed for handling the following cases like + // 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2 + // 2) Target tables have column metadata + Alias( + Cast(actual, expected.dataType, Option(conf.sessionLocalTimeZone)), + expected.name)(explicitMetadata = Option(expected.metadata)) + } + } + + if (newChildOutput == query.output) { + query + } else { + Project(newChildOutput, query) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index ee3674ba17821..f76bfd2fda2b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -37,6 +37,8 @@ import org.apache.spark.util.Utils class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with BeforeAndAfterEach { + import testImplicits._ + override def afterEach(): Unit = { try { // drop all databases, tables and functions after each test @@ -132,6 +134,32 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with Befo checkAnswer(spark.table("t"), Row(Row("a", 1)) :: Nil) } } + + // TODO: This test is copied from HiveDDLSuite, unify it later. + test("SPARK-23348: append data to data source table with saveAsTable") { + withTable("t", "t1") { + Seq(1 -> "a").toDF("i", "j").write.saveAsTable("t") + checkAnswer(spark.table("t"), Row(1, "a")) + + sql("INSERT INTO t SELECT 2, 'b'") + checkAnswer(spark.table("t"), Row(1, "a") :: Row(2, "b") :: Nil) + + Seq(3 -> "c").toDF("i", "j").write.mode("append").saveAsTable("t") + checkAnswer(spark.table("t"), Row(1, "a") :: Row(2, "b") :: Row(3, "c") :: Nil) + + Seq("c" -> 3).toDF("i", "j").write.mode("append").saveAsTable("t") + checkAnswer(spark.table("t"), Row(1, "a") :: Row(2, "b") :: Row(3, "c") + :: Row(null, "3") :: Nil) + + Seq(4 -> "d").toDF("i", "j").write.saveAsTable("t1") + + val e = intercept[AnalysisException] { + Seq(5 -> "e").toDF("i", "j").write.mode("append").format("json").saveAsTable("t1") + } + assert(e.message.contains("The format of the existing table default.t1 is " + + "`ParquetFileFormat`. It doesn't match the specified format `JsonFileFormat`.")) + } + } } abstract class DDLSuite extends QueryTest with SQLTestUtils {