[SPARK-40000][SQL] Update INSERTs without user-specified fields to not automatically add default values

### What changes were proposed in this pull request?

Update INSERTs without user-specified fields to not automatically add default values.

For example, with the new behavior, this `INSERT INTO` command will fail with an error message reporting that the table has two columns but the command only inserted one:

```
CREATE TABLE t (a INT DEFAULT 1, b INT DEFAULT 2) USING PARQUET;
INSERT INTO t VALUES (42);
```

For INSERTs with user-specified fields, the command may now provide fewer field/value pairs than the number of columns in the target table. The analyzer assigns the default value for each remaining column: either NULL, or the explicit DEFAULT value assigned to the column by a previous command.

For example, with the new behavior, this `INSERT INTO` command will succeed, inserting the new row `(42, 2)` into the target table:

```
CREATE TABLE t (a INT DEFAULT 1, b INT DEFAULT 2) USING PARQUET;
INSERT INTO t (a) VALUES (42);
```
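
When an omitted column has no explicit DEFAULT, the analyzer fills it with NULL instead. A minimal sketch of that case (not taken from this PR's tests):

```
CREATE TABLE t (a INT, b INT) USING PARQUET;
-- b has no explicit DEFAULT, so the analyzer fills it with NULL and the new row is (42, NULL):
INSERT INTO t (a) VALUES (42);
```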

To implement this behavior, this PR adds the following config, which is true by default:

`spark.sql.defaultColumn.addMissingValuesForInsertsWithExplicitColumns`

To restore the previous behavior, in which `INSERT INTO` commands with fewer user-specified fields than the number of columns in the target table return an error, set this new config to false, as in the sketch below.
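
A minimal sketch of opting back into the strict behavior, reusing the example table from above:

```
SET spark.sql.defaultColumn.addMissingValuesForInsertsWithExplicitColumns=false;
CREATE TABLE t (a INT DEFAULT 1, b INT DEFAULT 2) USING PARQUET;
-- With the config disabled, this fails: the column list covers only one of the two
-- columns and no default values are filled in for the remainder.
INSERT INTO t (a) VALUES (42);
```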

### Why are the changes needed?

After reviewing the desired SQL semantics, it is preferable to be strict and require that the number of inserted values exactly match the number of columns in the target table, guarding against accidental mistakes (see the sketch below).
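
As an illustrative sketch of the kind of accident this strictness guards against (not from this PR, and assuming the data source supports `ALTER TABLE ... ADD COLUMNS` with a `DEFAULT`): an `INSERT` that lists no columns now fails after the table gains a column, instead of silently filling the new column:

```
CREATE TABLE t (a INT) USING PARQUET;
INSERT INTO t VALUES (42);                    -- OK: one value for the one column
ALTER TABLE t ADD COLUMNS (b INT DEFAULT 2);
-- Previously this would silently insert (42, 2); with this change it fails because
-- t now has two columns but only one value is supplied:
INSERT INTO t VALUES (42);
```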

### Does this PR introduce _any_ user-facing change?

Yes, see above.

### How was this patch tested?

Updated unit test coverage.

Closes #37430 from dtenedor/insert-fewer-columns.

Lead-authored-by: Daniel Tenedorio <daniel.tenedorio@databricks.com>
Co-authored-by: Gengliang Wang <gengliang@apache.org>
Signed-off-by: Gengliang Wang <gengliang@apache.org>
dtenedor and gengliangwang committed Aug 16, 2022
1 parent 2511584 commit 13c1b59
Showing 6 changed files with 194 additions and 178 deletions.
1 change: 1 addition & 0 deletions docs/sql-migration-guide.md
@@ -26,6 +26,7 @@ license: |

- Since Spark 3.4, Number or Number(\*) from Teradata will be treated as Decimal(38,18). In Spark 3.3 or earlier, Number or Number(\*) from Teradata will be treated as Decimal(38, 0), in which case the fractional part will be removed.
- Since Spark 3.4, v1 database, table, permanent view and function identifier will include 'spark_catalog' as the catalog name if database is defined, e.g. a table identifier will be: `spark_catalog.default.t`. To restore the legacy behavior, set `spark.sql.legacy.v1IdentifierNoCatalog` to `true`.
- Since Spark 3.4, `INSERT INTO` commands will now support user-specified column lists comprising fewer columns than present in the target table (for example, `INSERT INTO t (a, b) VALUES (1, 2)` where table `t` has three columns). In this case, Spark will insert `NULL` into the remaining columns in the row, or the explicit `DEFAULT` value if assigned to the column. To revert to the previous behavior, please set `spark.sql.defaultColumn.addMissingValuesForInsertsWithExplicitColumns` to false.
- Since Spark 3.4, when ANSI SQL mode(configuration `spark.sql.ansi.enabled`) is on, Spark SQL always returns NULL result on getting a map value with a non-existing key. In Spark 3.3 or earlier, there will be an error.

## Upgrading from Spark SQL 3.2 to 3.3
@@ -108,7 +108,7 @@ case class ResolveDefaultColumns(catalog: SessionCatalog) extends Rule[LogicalPl
val regenerated: InsertIntoStatement =
regenerateUserSpecifiedCols(i, schema)
val expanded: LogicalPlan =
addMissingDefaultValuesForInsertFromInlineTable(node, schema)
addMissingDefaultValuesForInsertFromInlineTable(node, schema, i.userSpecifiedCols.length)
val replaced: Option[LogicalPlan] =
replaceExplicitDefaultValuesForInputOfInsertInto(schema, expanded)
replaced.map { r: LogicalPlan =>
@@ -132,7 +132,7 @@ case class ResolveDefaultColumns(catalog: SessionCatalog) extends Rule[LogicalPl
val regenerated: InsertIntoStatement = regenerateUserSpecifiedCols(i, schema)
val project: Project = i.query.asInstanceOf[Project]
val expanded: Project =
addMissingDefaultValuesForInsertFromProject(project, schema)
addMissingDefaultValuesForInsertFromProject(project, schema, i.userSpecifiedCols.length)
val replaced: Option[LogicalPlan] =
replaceExplicitDefaultValuesForInputOfInsertInto(schema, expanded)
replaced.map { r =>
@@ -265,14 +265,15 @@ case class ResolveDefaultColumns(catalog: SessionCatalog) extends Rule[LogicalPl
*/
private def addMissingDefaultValuesForInsertFromInlineTable(
node: LogicalPlan,
insertTableSchemaWithoutPartitionColumns: StructType): LogicalPlan = {
insertTableSchemaWithoutPartitionColumns: StructType,
numUserSpecifiedFields: Int): LogicalPlan = {
val numQueryOutputs: Int = node match {
case table: UnresolvedInlineTable => table.rows(0).size
case local: LocalRelation => local.data(0).numFields
}
val schema = insertTableSchemaWithoutPartitionColumns
val newDefaultExpressions: Seq[Expression] =
getDefaultExpressionsForInsert(numQueryOutputs, schema)
getDefaultExpressionsForInsert(numQueryOutputs, schema, numUserSpecifiedFields, node)
val newNames: Seq[String] = schema.fields.drop(numQueryOutputs).map { _.name }
node match {
case _ if newDefaultExpressions.isEmpty => node
@@ -298,11 +299,12 @@ case class ResolveDefaultColumns(catalog: SessionCatalog) extends Rule[LogicalPl
*/
private def addMissingDefaultValuesForInsertFromProject(
project: Project,
insertTableSchemaWithoutPartitionColumns: StructType): Project = {
insertTableSchemaWithoutPartitionColumns: StructType,
numUserSpecifiedFields: Int): Project = {
val numQueryOutputs: Int = project.projectList.size
val schema = insertTableSchemaWithoutPartitionColumns
val newDefaultExpressions: Seq[Expression] =
getDefaultExpressionsForInsert(numQueryOutputs, schema)
getDefaultExpressionsForInsert(numQueryOutputs, schema, numUserSpecifiedFields, project)
val newAliases: Seq[NamedExpression] =
newDefaultExpressions.zip(schema.fields).map {
case (expr, field) => Alias(expr, field.name)()
@@ -315,20 +317,19 @@ case class ResolveDefaultColumns(catalog: SessionCatalog) extends Rule[LogicalPl
*/
private def getDefaultExpressionsForInsert(
numQueryOutputs: Int,
schema: StructType): Seq[Expression] = {
val remainingFields: Seq[StructField] = schema.fields.drop(numQueryOutputs)
val numDefaultExpressionsToAdd = getStructFieldsForDefaultExpressions(remainingFields).size
Seq.fill(numDefaultExpressionsToAdd)(UnresolvedAttribute(CURRENT_DEFAULT_COLUMN_NAME))
}

/**
* This is a helper for the getDefaultExpressionsForInsert methods above.
*/
private def getStructFieldsForDefaultExpressions(fields: Seq[StructField]): Seq[StructField] = {
if (SQLConf.get.useNullsForMissingDefaultColumnValues) {
fields
schema: StructType,
numUserSpecifiedFields: Int,
treeNode: LogicalPlan): Seq[Expression] = {
if (numUserSpecifiedFields > 0 && numUserSpecifiedFields != numQueryOutputs) {
throw QueryCompilationErrors.writeTableWithMismatchedColumnsError(
numUserSpecifiedFields, numQueryOutputs, treeNode)
}
if (numUserSpecifiedFields > 0 && SQLConf.get.addMissingValuesForInsertsWithExplicitColumns) {
val remainingFields: Seq[StructField] = schema.fields.drop(numQueryOutputs)
val numDefaultExpressionsToAdd = remainingFields.size
Seq.fill(numDefaultExpressionsToAdd)(UnresolvedAttribute(CURRENT_DEFAULT_COLUMN_NAME))
} else {
fields.takeWhile(_.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY))
Seq.empty[Expression]
}
}

@@ -487,8 +488,7 @@ case class ResolveDefaultColumns(catalog: SessionCatalog) extends Rule[LogicalPl
schema.fields.filter {
field => !userSpecifiedColNames.contains(field.name)
}
Some(StructType(userSpecifiedFields ++
getStructFieldsForDefaultExpressions(nonUserSpecifiedFields)))
Some(StructType(userSpecifiedFields ++ nonUserSpecifiedFields))
}

/**
@@ -2924,6 +2924,18 @@ object SQLConf {
.stringConf
.createWithDefault("csv,json,orc,parquet")

val ADD_MISSING_DEFAULT_COLUMN_VALUES_FOR_INSERTS_WITH_EXPLICIT_COLUMNS =
buildConf("spark.sql.defaultColumn.addMissingValuesForInsertsWithExplicitColumns")
.internal()
.doc("When true, allow INSERT INTO commands with explicit columns (such as " +
"INSERT INTO t(a, b)) to specify fewer columns than the target table; the analyzer will " +
"assign default values for remaining columns (either NULL, or otherwise the explicit " +
"DEFAULT value associated with the column from a previous command). Otherwise, if " +
"false, return an error.")
.version("3.4.0")
.booleanConf
.createWithDefault(true)

val JSON_GENERATOR_WRITE_NULL_IF_WITH_DEFAULT_VALUE =
buildConf("spark.sql.jsonGenerator.writeNullIfWithDefaultValue")
.internal()
@@ -2936,17 +2948,6 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES =
buildConf("spark.sql.defaultColumn.useNullsForMissingDefaultValues")
.internal()
.doc("When true, and DEFAULT columns are enabled, allow column definitions lacking " +
"explicit default values to behave as if they had specified DEFAULT NULL instead. " +
"For example, this allows most INSERT INTO statements to specify only a prefix of the " +
"columns in the target table, and the remaining columns will receive NULL values.")
.version("3.4.0")
.booleanConf
.createWithDefault(false)

val ENFORCE_RESERVED_KEYWORDS = buildConf("spark.sql.ansi.enforceReservedKeywords")
.doc(s"When true and '${ANSI_ENABLED.key}' is true, the Spark SQL parser enforces the ANSI " +
"reserved keywords and forbids SQL queries that use reserved keywords as alias names " +
@@ -4530,12 +4531,12 @@ class SQLConf extends Serializable with Logging {

def defaultColumnAllowedProviders: String = getConf(SQLConf.DEFAULT_COLUMN_ALLOWED_PROVIDERS)

def addMissingValuesForInsertsWithExplicitColumns: Boolean =
getConf(SQLConf.ADD_MISSING_DEFAULT_COLUMN_VALUES_FOR_INSERTS_WITH_EXPLICIT_COLUMNS)

def jsonWriteNullIfWithDefaultValue: Boolean =
getConf(JSON_GENERATOR_WRITE_NULL_IF_WITH_DEFAULT_VALUE)

def useNullsForMissingDefaultColumnValues: Boolean =
getConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES)

def enforceReservedKeywords: Boolean = ansiEnabled && getConf(ENFORCE_RESERVED_KEYWORDS)

def timestampType: AtomicType = getConf(TIMESTAMP_TYPE) match {
@@ -175,45 +175,48 @@ trait SQLInsertTestSuite extends QueryTest with SQLTestUtils {
test("insert with column list - mismatched column list size") {
val msgs = Seq("Cannot write to table due to mismatched user specified column size",
"expected 3 columns but found")
def test: Unit = {
withSQLConf(SQLConf.ENABLE_DEFAULT_COLUMNS.key -> "false",
SQLConf.ENABLE_DEFAULT_COLUMNS.key -> "true") {
withTable("t1") {
val cols = Seq("c1", "c2", "c3")
createTable("t1", cols, Seq("int", "long", "string"))
val e1 = intercept[AnalysisException](sql(s"INSERT INTO t1 (c1, c2) values(1, 2, 3)"))
assert(e1.getMessage.contains(msgs(0)) || e1.getMessage.contains(msgs(1)))
val e2 = intercept[AnalysisException](sql(s"INSERT INTO t1 (c1, c2, c3) values(1, 2)"))
assert(e2.getMessage.contains(msgs(0)) || e2.getMessage.contains(msgs(1)))
Seq(
"INSERT INTO t1 (c1, c2) values(1, 2, 3)",
"INSERT INTO t1 (c1, c2) select 1, 2, 3",
"INSERT INTO t1 (c1, c2, c3) values(1, 2)",
"INSERT INTO t1 (c1, c2, c3) select 1, 2"
).foreach { query =>
val e = intercept[AnalysisException](sql(query))
assert(e.getMessage.contains(msgs(0)) || e.getMessage.contains(msgs(1)))
}
}
}
withSQLConf(SQLConf.ENABLE_DEFAULT_COLUMNS.key -> "false") {
test
}
withSQLConf(SQLConf.ENABLE_DEFAULT_COLUMNS.key -> "true") {
test
}
}

test("insert with column list - mismatched target table out size after rewritten query") {
val v2Msg = "expected 2 columns but found"
val v2Msg = "Cannot write to table due to mismatched user specified column size"
val cols = Seq("c1", "c2", "c3", "c4")

withTable("t1") {
createTable("t1", cols, Seq.fill(4)("int"))
val e1 = intercept[AnalysisException](sql(s"INSERT INTO t1 (c1) values(1)"))
assert(e1.getMessage.contains("target table has 4 column(s) but the inserted data has 1") ||
e1.getMessage.contains("expected 4 columns but found 1") ||
e1.getMessage.contains("not enough data columns") ||
e1.getMessage.contains(v2Msg))
}
withSQLConf(
SQLConf.ADD_MISSING_DEFAULT_COLUMN_VALUES_FOR_INSERTS_WITH_EXPLICIT_COLUMNS.key -> "false") {
withTable("t1") {
createTable("t1", cols, Seq.fill(4)("int"))
val e1 = intercept[AnalysisException](sql(s"INSERT INTO t1 (c1) values(1)"))
assert(e1.getMessage.contains("target table has 4 column(s) but the inserted data has 1") ||
e1.getMessage.contains("expected 4 columns but found 1") ||
e1.getMessage.contains("not enough data columns") ||
e1.getMessage.contains(v2Msg))
}

withTable("t1") {
createTable("t1", cols, Seq.fill(4)("int"), cols.takeRight(2))
val e1 = intercept[AnalysisException] {
sql(s"INSERT INTO t1 partition(c3=3, c4=4) (c1) values(1)")
withTable("t1") {
createTable("t1", cols, Seq.fill(4)("int"), cols.takeRight(2))
val e1 = intercept[AnalysisException] {
sql(s"INSERT INTO t1 partition(c3=3, c4=4) (c1) values(1)")
}
assert(e1.getMessage.contains("target table has 4 column(s) but the inserted data has 3") ||
e1.getMessage.contains("not enough data columns") ||
e1.getMessage.contains(v2Msg))
}
assert(e1.getMessage.contains("target table has 4 column(s) but the inserted data has 3") ||
e1.getMessage.contains("not enough data columns") ||
e1.getMessage.contains(v2Msg))
}
}
