diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index a13bb2476ceab..9652f7c25a204 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -401,6 +401,16 @@ case class DataSource(
     val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
     PartitioningUtils.validatePartitionColumn(data.schema, partitionColumns, caseSensitive)
 
+    // SPARK-17230: Resolve the partition columns so InsertIntoHadoopFsRelationCommand does
+    // not need to have the query as child, to avoid to analyze an optimized query,
+    // because InsertIntoHadoopFsRelationCommand will be optimized first.
+    val partitionAttributes = partitionColumns.map { name =>
+      val plan = data.logicalPlan
+      plan.resolve(name :: Nil, data.sparkSession.sessionState.analyzer.resolver).getOrElse {
+        throw new AnalysisException(
+          s"Unable to resolve $name given [${plan.output.map(_.name).mkString(", ")}]")
+      }.asInstanceOf[Attribute]
+    }
     val fileIndex = catalogTable.map(_.identifier).map { tableIdent =>
       sparkSession.table(tableIdent).queryExecution.analyzed.collect {
         case LogicalRelation(t: HadoopFsRelation, _, _) => t.location
@@ -414,7 +424,7 @@ case class DataSource(
         outputPath = outputPath,
         staticPartitions = Map.empty,
         ifPartitionNotExists = false,
-        partitionColumns = partitionColumns,
+        partitionColumns = partitionAttributes,
         bucketSpec = bucketSpec,
         fileFormat = format,
         options = options,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 04ee081a0f9ce..0df1c38e60dc0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -188,13 +188,15 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast
           "Cannot overwrite a path that is also being read from.")
       }
 
+      val partitionSchema = actualQuery.resolve(
+        t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver)
       val staticPartitions = parts.filter(_._2.nonEmpty).map { case (k, v) => k -> v.get }
 
       InsertIntoHadoopFsRelationCommand(
         outputPath,
         staticPartitions,
         i.ifPartitionNotExists,
-        partitionColumns = t.partitionSchema.map(_.name),
+        partitionSchema,
         t.bucketSpec,
         t.fileFormat,
         t.options,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 2c31d2a84c258..e87cf8d0f84ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.execution.{QueryExecution, SortExec, SQLExecution}
-import org.apache.spark.sql.types.{StringType, StructType}
+import org.apache.spark.sql.types.StringType
 import org.apache.spark.util.{SerializableConfiguration, Utils}
 
 
@@ -101,7 +101,7 @@ object FileFormatWriter extends Logging {
       committer: FileCommitProtocol,
       outputSpec: OutputSpec,
       hadoopConf: Configuration,
-      partitionColumnNames: Seq[String],
+      partitionColumns: Seq[Attribute],
       bucketSpec: Option[BucketSpec],
       refreshFunction: (Seq[TablePartitionSpec]) => Unit,
       options: Map[String, String]): Unit = {
@@ -111,16 +111,9 @@ object FileFormatWriter extends Logging {
     job.setOutputValueClass(classOf[InternalRow])
     FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath))
 
-    val allColumns = queryExecution.executedPlan.output
-    // Get the actual partition columns as attributes after matching them by name with
-    // the given columns names.
-    val partitionColumns = partitionColumnNames.map { col =>
-      val nameEquality = sparkSession.sessionState.conf.resolver
-      allColumns.find(f => nameEquality(f.name, col)).getOrElse {
-        throw new RuntimeException(
-          s"Partition column $col not found in schema ${queryExecution.executedPlan.schema}")
-      }
-    }
+    // Pick the attributes from analyzed plan, as optimizer may not preserve the output schema
+    // names' case.
+    val allColumns = queryExecution.analyzed.output
     val partitionSet = AttributeSet(partitionColumns)
     val dataColumns = allColumns.filterNot(partitionSet.contains)
 
@@ -179,8 +172,13 @@ object FileFormatWriter extends Logging {
     val rdd = if (orderingMatched) {
       queryExecution.toRdd
     } else {
+      // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and
+      // the physical plan may have different attribute ids due to optimizer removing some
+      // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch.
+      val orderingExpr = requiredOrdering
+        .map(SortOrder(_, Ascending)).map(BindReferences.bindReference(_, allColumns))
       SortExec(
-        requiredOrdering.map(SortOrder(_, Ascending)),
+        orderingExpr,
         global = false,
         child = queryExecution.executedPlan).execute()
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
index ab35fdcbc1f25..c9d31449d3629 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
@@ -44,7 +44,7 @@ case class InsertIntoHadoopFsRelationCommand(
     outputPath: Path,
     staticPartitions: TablePartitionSpec,
     ifPartitionNotExists: Boolean,
-    partitionColumns: Seq[String],
+    partitionColumns: Seq[Attribute],
     bucketSpec: Option[BucketSpec],
     fileFormat: FileFormat,
     options: Map[String, String],
@@ -150,7 +150,7 @@ case class InsertIntoHadoopFsRelationCommand(
         outputSpec = FileFormatWriter.OutputSpec(
           qualifiedOutputPath.toString, customPartitionLocations),
         hadoopConf = hadoopConf,
-        partitionColumnNames = partitionColumns,
+        partitionColumns = partitionColumns,
         bucketSpec = bucketSpec,
         refreshFunction = refreshPartitionsCallback,
         options = options)
@@ -176,10 +176,10 @@ case class InsertIntoHadoopFsRelationCommand(
       customPartitionLocations: Map[TablePartitionSpec, String],
       committer: FileCommitProtocol): Unit = {
     val staticPartitionPrefix = if (staticPartitions.nonEmpty) {
-      "/" + partitionColumns.flatMap { col =>
-        staticPartitions.get(col) match {
+      "/" + partitionColumns.flatMap { p =>
+        staticPartitions.get(p.name) match {
           case Some(value) =>
-            Some(escapePathName(col) + "=" + escapePathName(value))
+            Some(escapePathName(p.name) + "=" + escapePathName(value))
           case None =>
             None
         }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
index 2a652920c10c9..6885d0bf67ccb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
@@ -111,6 +111,15 @@ class FileStreamSink(
         case _ => // Do nothing
       }
 
+      // Get the actual partition columns as attributes after matching them by name with
+      // the given columns names.
+      val partitionColumns: Seq[Attribute] = partitionColumnNames.map { col =>
+        val nameEquality = data.sparkSession.sessionState.conf.resolver
+        data.logicalPlan.output.find(f => nameEquality(f.name, col)).getOrElse {
+          throw new RuntimeException(s"Partition column $col not found in schema ${data.schema}")
+        }
+      }
+
       FileFormatWriter.write(
         sparkSession = sparkSession,
         queryExecution = data.queryExecution,
@@ -118,7 +127,7 @@ class FileStreamSink(
         committer = committer,
         outputSpec = FileFormatWriter.OutputSpec(path, Map.empty),
         hadoopConf = hadoopConf,
-        partitionColumnNames = partitionColumnNames,
+        partitionColumns = partitionColumns,
         bucketSpec = None,
         refreshFunction = _ => (),
         options = options)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 3fa538c51bde6..1f29cb752fe36 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1759,4 +1759,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       testData2.select(lit(7), 'a, 'b).orderBy(lit(1), lit(2), lit(3)),
       Seq(Row(7, 1, 1), Row(7, 1, 2), Row(7, 2, 1), Row(7, 2, 2), Row(7, 3, 1), Row(7, 3, 2)))
   }
+
+  test("SPARK-22252: FileFormatWriter should respect the input query schema") {
+    withTable("t1", "t2", "t3", "t4") {
+      spark.range(1).select('id as 'col1, 'id as 'col2).write.saveAsTable("t1")
+      spark.sql("select COL1, COL2 from t1").write.saveAsTable("t2")
+      checkAnswer(spark.table("t2"), Row(0, 0))
+
+      // Test picking part of the columns when writing.
+      spark.range(1).select('id, 'id as 'col1, 'id as 'col2).write.saveAsTable("t3")
+      spark.sql("select COL1, COL2 from t3").write.saveAsTable("t4")
+      checkAnswer(spark.table("t4"), Row(0, 0))
+    }
+  }
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index 66ee5d4581e7e..8032d7e728a04 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -314,6 +314,13 @@ case class InsertIntoHiveTable(
       outputPath = tmpLocation.toString,
       isAppend = false)
 
+    val partitionAttributes = partitionColumnNames.takeRight(numDynamicPartitions).map { name =>
+      query.resolve(name :: Nil, sparkSession.sessionState.analyzer.resolver).getOrElse {
+        throw new AnalysisException(
+          s"Unable to resolve $name given [${query.output.map(_.name).mkString(", ")}]")
+      }.asInstanceOf[Attribute]
+    }
+
     FileFormatWriter.write(
       sparkSession = sparkSession,
       queryExecution = Dataset.ofRows(sparkSession, query).queryExecution,
@@ -321,7 +328,7 @@ case class InsertIntoHiveTable(
       committer = committer,
       outputSpec = FileFormatWriter.OutputSpec(tmpLocation.toString, Map.empty),
       hadoopConf = hadoopConf,
-      partitionColumnNames = partitionColumnNames.takeRight(numDynamicPartitions),
+      partitionColumns = partitionAttributes,
       bucketSpec = None,
       refreshFunction = _ => (),
       options = Map.empty)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index 618e5b68ff8c0..d69615669348b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -468,7 +468,7 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
     }
   }
 
-  test("SPARK-21165: the query schema of INSERT is changed after optimization") {
+  test("SPARK-21165: FileFormatWriter should only rely on attributes from analyzed plan") {
    withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) {
      withTable("tab1", "tab2") {
        Seq(("a", "b", 3)).toDF("word", "first", "length").write.saveAsTable("tab1")