From 9ec7d36a198560441e3c3e96fa59789bdd36751b Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Fri, 20 Jan 2017 14:10:36 +0800
Subject: [PATCH 1/2] partitioned table should always put partition columns at
 the end of table schema

---
 .../sql/execution/datasources/rules.scala | 57 +++++++++++++------
 .../sql/hive/execution/HiveDDLSuite.scala | 30 ++++++++++
 2 files changed, 69 insertions(+), 18 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 5ca8226dbcdd6..09d872dcea7a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -199,31 +199,52 @@ case class AnalyzeCreateTable(sparkSession: SparkSession) extends Rule[LogicalPl
     //   * can't use all table columns as partition columns.
     //   * partition columns' type must be AtomicType.
     //   * sort columns' type must be orderable.
+    //   * reorder table schema or output of query plan, to put partition columns at the end.
     case c @ CreateTable(tableDesc, _, query) =>
-      val analyzedQuery = query.map { q =>
-        // Analyze the query in CTAS and then we can do the normalization and checking.
-        val qe = sparkSession.sessionState.executePlan(q)
+      if (query.isDefined) {
+        val qe = sparkSession.sessionState.executePlan(query.get)
         qe.assertAnalyzed()
-        qe.analyzed
-      }
-      val schema = if (analyzedQuery.isDefined) {
-        analyzedQuery.get.schema
-      } else {
-        tableDesc.schema
-      }
+        val analyzedQuery = qe.analyzed
+
+        val normalizedTable = normalizeCatalogTable(analyzedQuery.schema, tableDesc)
+
+        val output = analyzedQuery.output
+        val partitionAttrs = normalizedTable.partitionColumnNames.map { partCol =>
+          output.find(_.name == partCol).get
+        }
+        val newOutput = output.filterNot(partitionAttrs.contains) ++ partitionAttrs
+        val reorderedQuery = if (newOutput == output) {
+          analyzedQuery
+        } else {
+          Project(newOutput, analyzedQuery)
+        }
 
-      val columnNames = if (sparkSession.sessionState.conf.caseSensitiveAnalysis) {
-        schema.map(_.name)
+        c.copy(tableDesc = normalizedTable, query = Some(reorderedQuery))
       } else {
-        schema.map(_.name.toLowerCase)
+        val normalizedTable = normalizeCatalogTable(tableDesc.schema, tableDesc)
+
+        val partitionSchema = normalizedTable.partitionColumnNames.map { partCol =>
+          normalizedTable.schema.find(_.name == partCol).get
+        }
+
+        val reorderedSchema =
+          StructType(normalizedTable.schema.filterNot(partitionSchema.contains) ++ partitionSchema)
+
+        c.copy(tableDesc = normalizedTable.copy(schema = reorderedSchema))
       }
-      checkDuplication(columnNames, "table definition of " + tableDesc.identifier)
+  }
 
-      val normalizedTable = tableDesc.copy(
-        partitionColumnNames = normalizePartitionColumns(schema, tableDesc),
-        bucketSpec = normalizeBucketSpec(schema, tableDesc))
+  private def normalizeCatalogTable(schema: StructType, table: CatalogTable): CatalogTable = {
+    val columnNames = if (sparkSession.sessionState.conf.caseSensitiveAnalysis) {
+      schema.map(_.name)
+    } else {
+      schema.map(_.name.toLowerCase)
+    }
+    checkDuplication(columnNames, "table definition of " + table.identifier)
 
-      c.copy(tableDesc = normalizedTable, query = analyzedQuery)
+    table.copy(
+      partitionColumnNames = normalizePartitionColumns(schema, table),
+      bucketSpec = normalizeBucketSpec(schema, table))
   }
 
   private def normalizePartitionColumns(schema: StructType, table: CatalogTable): Seq[String] = {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
index edef30823b55c..7f58603d327af 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
@@ -1384,4 +1384,34 @@ class HiveDDLSuite
       assert(e2.message.contains("Hive data source can only be used with tables"))
     }
   }
+
+  test("partitioned table should always put partition columns at the end of table schema") {
+    def getTableColumns(tblName: String): Seq[String] = {
+      spark.sessionState.catalog.getTableMetadata(TableIdentifier(tblName)).schema.map(_.name)
+    }
+
+    withTable("t", "t1", "t2", "t3", "t4") {
+      sql("CREATE TABLE t(a int, b int, c int, d int) USING parquet PARTITIONED BY (d, b)")
+      assert(getTableColumns("t") == Seq("a", "c", "d", "b"))
+
+      sql("CREATE TABLE t1 USING parquet PARTITIONED BY (d, b) AS SELECT 1 a, 1 b, 1 c, 1 d")
+      assert(getTableColumns("t1") == Seq("a", "c", "d", "b"))
+
+      Seq((1, 1, 1, 1)).toDF("a", "b", "c", "d").write.partitionBy("d", "b").saveAsTable("t2")
+      assert(getTableColumns("t2") == Seq("a", "c", "d", "b"))
+
+      withTempPath { path =>
+        val dataPath = new File(new File(path, "d=1"), "b=1").getCanonicalPath
+        Seq(1 -> 1).toDF("a", "c").write.save(dataPath)
+
+        sql(s"CREATE TABLE t3 USING parquet LOCATION '${path.getCanonicalPath}'")
+        assert(getTableColumns("t3") == Seq("a", "c", "d", "b"))
+      }
+
+      sql("CREATE TABLE t4(a int, b int, c int, d int) USING hive PARTITIONED BY (d, b)")
+      assert(getTableColumns("t4") == Seq("a", "c", "d", "b"))
+
+      // TODO: add test for creating partitioned hive serde table as select, once we support it.
+    }
+  }
 }

From 68f639e468333faa9070cca639b3b491585b2e39 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Sat, 21 Jan 2017 11:14:43 +0800
Subject: [PATCH 2/2] address comments

---
 .../org/apache/spark/sql/execution/datasources/rules.scala | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 09d872dcea7a8..c84533779475e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -202,6 +202,9 @@ case class AnalyzeCreateTable(sparkSession: SparkSession) extends Rule[LogicalPl
     //   * reorder table schema or output of query plan, to put partition columns at the end.
     case c @ CreateTable(tableDesc, _, query) =>
       if (query.isDefined) {
+        assert(tableDesc.schema.isEmpty,
+          "Schema may not be specified in a Create Table As Select (CTAS) statement")
+
        val qe = sparkSession.sessionState.executePlan(query.get)
         qe.assertAnalyzed()
         val analyzedQuery = qe.analyzed
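
For illustration only, and not part of the patches above: a minimal standalone sketch of the behavior this change establishes, mirroring the t and t1 cases from the new HiveDDLSuite test. The object name, master setting, and expected println output are assumptions for a local run against a build that includes this change.

// Illustrative sketch (assumed names and settings), not part of the patch.
import org.apache.spark.sql.SparkSession

object PartitionColumnOrderExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("partition-columns-at-end")
      .getOrCreate()

    // Data source table with an explicit schema: the partition columns (d, b)
    // are placed at the end of the table schema.
    spark.sql("CREATE TABLE t(a int, b int, c int, d int) USING parquet PARTITIONED BY (d, b)")
    println(spark.table("t").schema.fieldNames.mkString(", "))   // expected: a, c, d, b

    // CTAS: the query output is reordered the same way before the table is written.
    spark.sql("CREATE TABLE t1 USING parquet PARTITIONED BY (d, b) AS SELECT 1 a, 1 b, 1 c, 1 d")
    println(spark.table("t1").schema.fieldNames.mkString(", "))  // expected: a, c, d, b

    spark.stop()
  }
}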