[SPARK-32976][SQL] Support column list in INSERT statement
### What changes were proposed in this pull request?

#### JIRA expectations
```
   INSERT currently does not support named column lists.

   INSERT INTO <table> (col1, col2,…) VALUES( 'val1', 'val2', … )
   Note, we assume the column list contains all the column names. Issue an exception if the list is not complete. The column order could be different from the column order defined in the table definition.
```
#### Implementation
In this PR, we add an optional column list to the `INSERT OVERWRITE/INTO` statements:
```
  /**
   * {{{
   *   INSERT OVERWRITE TABLE tableIdentifier [partitionSpec [IF NOT EXISTS]]? [identifierList] ...
   *   INSERT INTO [TABLE] tableIdentifier [partitionSpec]  [identifierList] ...
   * }}}
   */
```
The column list names all of the expected columns, in the explicit order in which they should be written to the target table. **In particular**, the current implementation assumes the column list contains all of the column names; the statement fails when the list is incomplete.
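
For illustration, here is a minimal usage sketch of the new syntax (the table name, columns, and `USING parquet` provider are hypothetical, not taken from this PR):
```
// Hypothetical example, assuming an active SparkSession `spark`.
spark.sql("CREATE TABLE t (c1 INT, c2 STRING) USING parquet")

// The column list may reorder the table's columns:
spark.sql("INSERT INTO t (c2, c1) VALUES ('a', 1)")

// But it must be complete; omitting a column fails analysis:
// spark.sql("INSERT INTO t (c2) VALUES ('a')")   // throws AnalysisException
```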

In the **Analyzer**, we add a new rule, `ResolveUserSpecifiedColumns`, that resolves the column list before the statement is transformed into a v1 or v2 command. Analysis fails here if the list contains any field that does not belong to the target table.

Then, for v2 commands, e.g. `AppendData`, we use the resolved column list and the output of the target table to resolve the output of the source query in the `ResolveOutputRelation` rule. If the list contains duplicate columns, we fail. If the list is not empty but its size does not match that of the target table, we fail. If no other exception occurs, we use the column list to map the output of the source query onto the output of the target table. The column list is then set to `Nil`, so the statement will not hit the rule again once it is resolved.
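
As a concrete illustration of that mapping, here is a standalone sketch that mirrors the reordering logic of the new rule on plain strings (the real rule, shown in the diff below, operates on Catalyst attributes; all names here are made up):
```
// Standalone sketch of the column-list reordering; names are illustrative.
val tableCols = Seq("a", "b", "c")   // target table output order
val userCols  = Seq("c", "a", "b")   // user-specified column list
val queryCols = Seq("x", "y", "z")   // source query output order

// Pair each user-specified column with the query output at the same position,
// then re-emit the query columns in the table's own column order.
val nameToQueryCol = userCols.zip(queryCols).toMap
val reordered = tableCols.flatMap(nameToQueryCol.get)   // Seq("y", "z", "x")
```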

For v1 commands, all of these checks happen in the `PreprocessTableInsertion` rule.

### Why are the changes needed?
New feature support.

### Does this PR introduce _any_ user-facing change?

Yes. `INSERT INTO/OVERWRITE TABLE` now supports specifying a column list.

### How was this patch tested?

New tests, including the column-list parsing cases in `DDLParserSuite`.

Closes #29893 from yaooqinn/SPARK-32976.

Authored-by: Kent Yao <yaooqinn@hotmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
yaooqinn authored and cloud-fan committed Nov 30, 2020
1 parent 4851453 commit 2da7259
Showing 16 changed files with 396 additions and 34 deletions.
@@ -332,8 +332,8 @@ query
;

insertInto
: INSERT OVERWRITE TABLE? multipartIdentifier (partitionSpec (IF NOT EXISTS)?)? #insertOverwriteTable
| INSERT INTO TABLE? multipartIdentifier partitionSpec? (IF NOT EXISTS)? #insertIntoTable
: INSERT OVERWRITE TABLE? multipartIdentifier (partitionSpec (IF NOT EXISTS)?)? identifierList? #insertOverwriteTable
| INSERT INTO TABLE? multipartIdentifier partitionSpec? (IF NOT EXISTS)? identifierList? #insertIntoTable
| INSERT OVERWRITE LOCAL? DIRECTORY path=STRING rowFormat? createFileFormat? #insertOverwriteHiveDir
| INSERT OVERWRITE LOCAL? DIRECTORY (path=STRING)? tableProvider (OPTIONS options=tablePropertyList)? #insertOverwriteDir
;
@@ -49,7 +49,7 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{PartitionOverwriteMode, StoreAssignmentPolicy}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.sql.util.{CaseInsensitiveStringMap, SchemaUtils}
import org.apache.spark.util.Utils

/**
@@ -218,6 +218,7 @@ class Analyzer(override val catalogManager: CatalogManager)
ResolveTableValuedFunctions ::
ResolveNamespace(catalogManager) ::
new ResolveCatalogs(catalogManager) ::
ResolveUserSpecifiedColumns ::
ResolveInsertInto ::
ResolveRelations ::
ResolveTables ::
@@ -846,7 +847,7 @@ class Analyzer(override val catalogManager: CatalogManager)
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case u @ UnresolvedRelation(ident, _, isStreaming) =>
lookupTempView(ident, isStreaming).getOrElse(u)
case i @ InsertIntoStatement(UnresolvedRelation(ident, _, false), _, _, _, _) =>
case i @ InsertIntoStatement(UnresolvedRelation(ident, _, false), _, _, _, _, _) =>
lookupTempView(ident)
.map(view => i.copy(table = view))
.getOrElse(i)
@@ -961,7 +962,7 @@ class Analyzer(override val catalogManager: CatalogManager)
.map(ResolvedTable(catalog.asTableCatalog, ident, _))
.getOrElse(u)

case i @ InsertIntoStatement(u @ UnresolvedRelation(_, _, false), _, _, _, _)
case i @ InsertIntoStatement(u @ UnresolvedRelation(_, _, false), _, _, _, _, _)
if i.query.resolved =>
lookupV2Relation(u.multipartIdentifier, u.options, false)
.map(v2Relation => i.copy(table = v2Relation))
@@ -1045,7 +1046,7 @@ class Analyzer(override val catalogManager: CatalogManager)
}

def apply(plan: LogicalPlan): LogicalPlan = ResolveTempViews(plan).resolveOperatorsUp {
case i @ InsertIntoStatement(table, _, _, _, _) if i.query.resolved =>
case i @ InsertIntoStatement(table, _, _, _, _, _) if i.query.resolved =>
val relation = table match {
case u @ UnresolvedRelation(_, _, false) =>
lookupRelation(u.multipartIdentifier, u.options, false).getOrElse(u)
@@ -1160,7 +1161,8 @@ class Analyzer(override val catalogManager: CatalogManager)

object ResolveInsertInto extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case i @ InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _) if i.query.resolved =>
case i @ InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _, _)
if i.query.resolved && i.userSpecifiedCols.isEmpty =>
// ifPartitionNotExists is append with validation, but validation is not supported
if (i.ifPartitionNotExists) {
throw QueryCompilationErrors.unsupportedIfNotExistsError(r.table.name)
@@ -3107,6 +3109,46 @@ class Analyzer(override val catalogManager: CatalogManager)
}
}

object ResolveUserSpecifiedColumns extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case i: InsertIntoStatement if i.table.resolved && i.query.resolved &&
i.userSpecifiedCols.nonEmpty =>
val resolved = resolveUserSpecifiedColumns(i)
val projection = addColumnListOnQuery(i.table.output, resolved, i.query)
i.copy(userSpecifiedCols = Nil, query = projection)
}

private def resolveUserSpecifiedColumns(i: InsertIntoStatement): Seq[NamedExpression] = {
SchemaUtils.checkColumnNameDuplication(
i.userSpecifiedCols, "in the column list", resolver)

i.userSpecifiedCols.map { col =>
i.table.resolve(Seq(col), resolver)
.getOrElse(i.table.failAnalysis(s"Cannot resolve column name $col"))
}
}

private def addColumnListOnQuery(
tableOutput: Seq[Attribute],
cols: Seq[NamedExpression],
query: LogicalPlan): LogicalPlan = {
if (cols.size != query.output.size) {
query.failAnalysis(
s"Cannot write to table due to mismatched user specified column size(${cols.size}) and" +
s" data column size(${query.output.size})")
}
val nameToQueryExpr = cols.zip(query.output).toMap
// Static partition columns in the table output should not appear in the column list;
// they will be handled in another rule, ResolveInsertInto.
val reordered = tableOutput.flatMap { nameToQueryExpr.get(_).orElse(None) }
if (reordered == query.output) {
query
} else {
Project(reordered, query)
}
}
}

private def validateStoreAssignmentPolicy(): Unit = {
// SPARK-28730: LEGACY store assignment policy is disallowed in data source v2.
if (conf.storeAssignmentPolicy == StoreAssignmentPolicy.LEGACY) {
@@ -108,7 +108,7 @@ trait CheckAnalysis extends PredicateHelper {
case u: UnresolvedRelation =>
u.failAnalysis(s"Table or view not found: ${u.multipartIdentifier.quoted}")

case InsertIntoStatement(u: UnresolvedRelation, _, _, _, _) =>
case InsertIntoStatement(u: UnresolvedRelation, _, _, _, _, _) =>
failAnalysis(s"Table not found: ${u.multipartIdentifier.quoted}")

// TODO (SPARK-27484): handle streaming write commands when we have them.
@@ -431,7 +431,7 @@ package object dsl {
partition: Map[String, Option[String]] = Map.empty,
overwrite: Boolean = false,
ifPartitionNotExists: Boolean = false): LogicalPlan =
InsertIntoStatement(table, partition, logicalPlan, overwrite, ifPartitionNotExists)
InsertIntoStatement(table, partition, Nil, logicalPlan, overwrite, ifPartitionNotExists)

def as(alias: String): LogicalPlan = SubqueryAlias(alias, logicalPlan)

@@ -243,9 +243,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg

/**
* Parameters used for writing query to a table:
* (multipartIdentifier, partitionKeys, ifPartitionNotExists).
* (multipartIdentifier, tableColumnList, partitionKeys, ifPartitionNotExists).
*/
type InsertTableParams = (Seq[String], Map[String, Option[String]], Boolean)
type InsertTableParams = (Seq[String], Seq[String], Map[String, Option[String]], Boolean)

/**
* Parameters used for writing query to a directory: (isLocal, CatalogStorageFormat, provider).
@@ -255,8 +255,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
/**
* Add an
* {{{
* INSERT OVERWRITE TABLE tableIdentifier [partitionSpec [IF NOT EXISTS]]?
* INSERT INTO [TABLE] tableIdentifier [partitionSpec]
* INSERT OVERWRITE TABLE tableIdentifier [partitionSpec [IF NOT EXISTS]]? [identifierList]
* INSERT INTO [TABLE] tableIdentifier [partitionSpec] [identifierList]
* INSERT OVERWRITE [LOCAL] DIRECTORY STRING [rowFormat] [createFileFormat]
* INSERT OVERWRITE [LOCAL] DIRECTORY [STRING] tableProvider [OPTIONS tablePropertyList]
* }}}
@@ -267,18 +267,20 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
ctx match {
case table: InsertIntoTableContext =>
val (tableIdent, partition, ifPartitionNotExists) = visitInsertIntoTable(table)
val (tableIdent, cols, partition, ifPartitionNotExists) = visitInsertIntoTable(table)
InsertIntoStatement(
UnresolvedRelation(tableIdent),
partition,
cols,
query,
overwrite = false,
ifPartitionNotExists)
case table: InsertOverwriteTableContext =>
val (tableIdent, partition, ifPartitionNotExists) = visitInsertOverwriteTable(table)
val (tableIdent, cols, partition, ifPartitionNotExists) = visitInsertOverwriteTable(table)
InsertIntoStatement(
UnresolvedRelation(tableIdent),
partition,
cols,
query,
overwrite = true,
ifPartitionNotExists)
@@ -299,13 +301,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
override def visitInsertIntoTable(
ctx: InsertIntoTableContext): InsertTableParams = withOrigin(ctx) {
val tableIdent = visitMultipartIdentifier(ctx.multipartIdentifier)
val cols = Option(ctx.identifierList()).map(visitIdentifierList).getOrElse(Nil)
val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty)

if (ctx.EXISTS != null) {
operationNotAllowed("INSERT INTO ... IF NOT EXISTS", ctx)
}

(tableIdent, partitionKeys, false)
(tableIdent, cols, partitionKeys, false)
}

/**
@@ -315,6 +318,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
ctx: InsertOverwriteTableContext): InsertTableParams = withOrigin(ctx) {
assert(ctx.OVERWRITE() != null)
val tableIdent = visitMultipartIdentifier(ctx.multipartIdentifier)
val cols = Option(ctx.identifierList()).map(visitIdentifierList).getOrElse(Nil)
val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty)

val dynamicPartitionKeys: Map[String, Option[String]] = partitionKeys.filter(_._2.isEmpty)
Expand All @@ -323,7 +327,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
dynamicPartitionKeys.keys.mkString(", "), ctx)
}

(tableIdent, partitionKeys, ctx.EXISTS() != null)
(tableIdent, cols, partitionKeys, ctx.EXISTS() != null)
}

/**
@@ -357,6 +357,7 @@ case class DropViewStatement(
* An INSERT INTO statement, as parsed from SQL.
*
* @param table the logical plan representing the table.
* @param userSpecifiedCols the user specified list of columns that belong to the table.
* @param query the logical plan representing data to write to.
* @param overwrite overwrite existing table or partitions.
* @param partitionSpec a map from the partition key to the partition value (optional).
Expand All @@ -371,6 +372,7 @@ case class DropViewStatement(
case class InsertIntoStatement(
table: LogicalPlan,
partitionSpec: Map[String, Option[String]],
userSpecifiedCols: Seq[String],
query: LogicalPlan,
overwrite: Boolean,
ifPartitionNotExists: Boolean) extends ParsedStatement {
@@ -1172,6 +1172,22 @@ class DDLParserSuite extends AnalysisTest {
InsertIntoStatement(
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
Map.empty,
Nil,
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
overwrite = false, ifPartitionNotExists = false))
}
}

test("insert table: basic append with a column list") {
Seq(
"INSERT INTO TABLE testcat.ns1.ns2.tbl (a, b) SELECT * FROM source",
"INSERT INTO testcat.ns1.ns2.tbl (a, b) SELECT * FROM source"
).foreach { sql =>
parseCompare(sql,
InsertIntoStatement(
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
Map.empty,
Seq("a", "b"),
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
overwrite = false, ifPartitionNotExists = false))
}
@@ -1182,6 +1198,7 @@ class DDLParserSuite extends AnalysisTest {
InsertIntoStatement(
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
Map.empty,
Nil,
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("testcat2", "db", "tbl"))),
overwrite = false, ifPartitionNotExists = false))
}
@@ -1196,6 +1213,22 @@ class DDLParserSuite extends AnalysisTest {
InsertIntoStatement(
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
Map("p1" -> Some("3"), "p2" -> None),
Nil,
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
overwrite = false, ifPartitionNotExists = false))
}

test("insert table: append with partition and a column list") {
parseCompare(
"""
|INSERT INTO testcat.ns1.ns2.tbl
|PARTITION (p1 = 3, p2) (a, b)
|SELECT * FROM source
""".stripMargin,
InsertIntoStatement(
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
Map("p1" -> Some("3"), "p2" -> None),
Seq("a", "b"),
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
overwrite = false, ifPartitionNotExists = false))
}
@@ -1209,6 +1242,22 @@ class DDLParserSuite extends AnalysisTest {
InsertIntoStatement(
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
Map.empty,
Nil,
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
overwrite = true, ifPartitionNotExists = false))
}
}

test("insert table: overwrite with column list") {
Seq(
"INSERT OVERWRITE TABLE testcat.ns1.ns2.tbl (a, b) SELECT * FROM source",
"INSERT OVERWRITE testcat.ns1.ns2.tbl (a, b) SELECT * FROM source"
).foreach { sql =>
parseCompare(sql,
InsertIntoStatement(
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
Map.empty,
Seq("a", "b"),
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
overwrite = true, ifPartitionNotExists = false))
}
@@ -1224,6 +1273,22 @@ class DDLParserSuite extends AnalysisTest {
InsertIntoStatement(
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
Map("p1" -> Some("3"), "p2" -> None),
Nil,
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
overwrite = true, ifPartitionNotExists = false))
}

test("insert table: overwrite with partition and column list") {
parseCompare(
"""
|INSERT OVERWRITE TABLE testcat.ns1.ns2.tbl
|PARTITION (p1 = 3, p2) (a, b)
|SELECT * FROM source
""".stripMargin,
InsertIntoStatement(
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
Map("p1" -> Some("3"), "p2" -> None),
Seq("a", "b"),
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
overwrite = true, ifPartitionNotExists = false))
}
@@ -1238,6 +1303,7 @@ class DDLParserSuite extends AnalysisTest {
InsertIntoStatement(
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
Map("p1" -> Some("3")),
Nil,
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
overwrite = true, ifPartitionNotExists = true))
}
@@ -295,7 +295,7 @@ class PlanParserSuite extends AnalysisTest {
partition: Map[String, Option[String]],
overwrite: Boolean = false,
ifPartitionNotExists: Boolean = false): LogicalPlan =
InsertIntoStatement(table("s"), partition, plan, overwrite, ifPartitionNotExists)
InsertIntoStatement(table("s"), partition, Nil, plan, overwrite, ifPartitionNotExists)

// Single inserts
assertEqual(s"insert overwrite table s $sql",
@@ -713,7 +713,7 @@ class PlanParserSuite extends AnalysisTest {
comparePlans(
parsePlan(
"INSERT INTO s SELECT /*+ REPARTITION(100), COALESCE(500), COALESCE(10) */ * FROM t"),
InsertIntoStatement(table("s"), Map.empty,
InsertIntoStatement(table("s"), Map.empty, Nil,
UnresolvedHint("REPARTITION", Seq(Literal(100)),
UnresolvedHint("COALESCE", Seq(Literal(500)),
UnresolvedHint("COALESCE", Seq(Literal(10)),
@@ -536,6 +536,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
InsertIntoStatement(
table = UnresolvedRelation(tableIdent),
partitionSpec = Map.empty[String, Option[String]],
Nil,
query = df.logicalPlan,
overwrite = mode == SaveMode.Overwrite,
ifPartitionNotExists = false)
@@ -156,7 +156,7 @@ object DataSourceAnalysis extends Rule[LogicalPlan] with CastSupport {
CreateDataSourceTableAsSelectCommand(tableDesc, mode, query, query.output.map(_.name))

case InsertIntoStatement(l @ LogicalRelation(_: InsertableRelation, _, _, _),
parts, query, overwrite, false) if parts.isEmpty =>
parts, _, query, overwrite, false) if parts.isEmpty =>
InsertIntoDataSourceCommand(l, query, overwrite)

case InsertIntoDir(_, storage, provider, query, overwrite)
Expand All @@ -168,7 +168,7 @@ object DataSourceAnalysis extends Rule[LogicalPlan] with CastSupport {
InsertIntoDataSourceDirCommand(storage, provider.get, query, overwrite)

case i @ InsertIntoStatement(
l @ LogicalRelation(t: HadoopFsRelation, _, table, _), parts, query, overwrite, _) =>
l @ LogicalRelation(t: HadoopFsRelation, _, table, _), parts, _, query, overwrite, _) =>
// If the InsertIntoTable command is for a partitioned HadoopFsRelation and
// the user has specified static partitions, we add a Project operator on top of the query
// to include those constant column values in the query result.
@@ -276,11 +276,11 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]


override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case i @ InsertIntoStatement(UnresolvedCatalogRelation(tableMeta, options, false), _, _, _, _)
if DDLUtils.isDatasourceTable(tableMeta) =>
case i @ InsertIntoStatement(UnresolvedCatalogRelation(tableMeta, options, false),
_, _, _, _, _) if DDLUtils.isDatasourceTable(tableMeta) =>
i.copy(table = readDataSourceTable(tableMeta, options))

case i @ InsertIntoStatement(UnresolvedCatalogRelation(tableMeta, _, false), _, _, _, _) =>
case i @ InsertIntoStatement(UnresolvedCatalogRelation(tableMeta, _, false), _, _, _, _, _) =>
i.copy(table = DDLUtils.readHiveTable(tableMeta))

case UnresolvedCatalogRelation(tableMeta, options, false)
