[SPARK-40000][SQL] Update INSERTs without user-specified fields to not automatically add default values #37430

Closed
1 change: 1 addition & 0 deletions docs/sql-migration-guide.md
@@ -26,6 +26,7 @@ license: |

- Since Spark 3.4, Number or Number(\*) from Teradata will be treated as Decimal(38,18). In Spark 3.3 or earlier, Number or Number(\*) from Teradata will be treated as Decimal(38, 0), in which case the fractional part will be removed.
- Since Spark 3.4, v1 database, table, permanent view and function identifier will include 'spark_catalog' as the catalog name if database is defined, e.g. a table identifier will be: `spark_catalog.default.t`. To restore the legacy behavior, set `spark.sql.legacy.v1IdentifierNoCatalog` to `true`.
- Since Spark 3.4, `INSERT INTO` commands will now support user-specified column lists comprising fewer columns than present in the target table (for example, `INSERT INTO t (a, b) VALUES (1, 2)` where table `t` has three columns). In this case, Spark will insert `NULL` into the remaining columns in the row, or the explicit `DEFAULT` value if assigned to the column. To revert to the previous behavior, please set `spark.sql.defaultColumn.addMissingValuesForInsertsWithExplicitColumns` to false.
- Since Spark 3.4, when ANSI SQL mode (configuration `spark.sql.ansi.enabled`) is on, Spark SQL always returns a NULL result on getting a map value with a non-existent key. In Spark 3.3 or earlier, an error is thrown.

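To make the new 3.4 entry above concrete, here is a minimal sketch of the behavior it describes, assuming a hypothetical three-column table `t` with one declared default and a local SparkSession (names and values are illustrative only, not taken from this PR):

    import org.apache.spark.sql.SparkSession

    // Hypothetical local session and table, used only to illustrate the migration-guide entry above.
    val spark = SparkSession.builder().master("local[*]").appName("insert-defaults-demo").getOrCreate()

    spark.sql("CREATE TABLE t (a INT, b INT, c INT DEFAULT 42) USING parquet")
    // The column list names only a prefix of t's columns; since Spark 3.4 the analyzer
    // fills the remaining column c with its declared DEFAULT (42 here), or NULL if none is declared.
    spark.sql("INSERT INTO t (a, b) VALUES (1, 2)")
    spark.sql("SELECT * FROM t").show()   // 1, 2, 42

Setting `spark.sql.defaultColumn.addMissingValuesForInsertsWithExplicitColumns` to false, as the entry notes, makes the same INSERT fail analysis instead; that configuration is added in SQLConf further down in this diff.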
## Upgrading from Spark SQL 3.2 to 3.3
@@ -108,7 +108,7 @@ case class ResolveDefaultColumns(catalog: SessionCatalog) extends Rule[LogicalPlan] {
val regenerated: InsertIntoStatement =
regenerateUserSpecifiedCols(i, schema)
val expanded: LogicalPlan =
addMissingDefaultValuesForInsertFromInlineTable(node, schema)
addMissingDefaultValuesForInsertFromInlineTable(node, schema, i.userSpecifiedCols.length)
val replaced: Option[LogicalPlan] =
replaceExplicitDefaultValuesForInputOfInsertInto(schema, expanded)
replaced.map { r: LogicalPlan =>
@@ -132,7 +132,7 @@ case class ResolveDefaultColumns(catalog: SessionCatalog) extends Rule[LogicalPlan] {
val regenerated: InsertIntoStatement = regenerateUserSpecifiedCols(i, schema)
val project: Project = i.query.asInstanceOf[Project]
val expanded: Project =
addMissingDefaultValuesForInsertFromProject(project, schema)
addMissingDefaultValuesForInsertFromProject(project, schema, i.userSpecifiedCols.length)
val replaced: Option[LogicalPlan] =
replaceExplicitDefaultValuesForInputOfInsertInto(schema, expanded)
replaced.map { r =>
@@ -265,14 +265,15 @@ case class ResolveDefaultColumns(catalog: SessionCatalog) extends Rule[LogicalPlan] {
*/
private def addMissingDefaultValuesForInsertFromInlineTable(
node: LogicalPlan,
insertTableSchemaWithoutPartitionColumns: StructType): LogicalPlan = {
insertTableSchemaWithoutPartitionColumns: StructType,
numUserSpecifiedFields: Int): LogicalPlan = {
val numQueryOutputs: Int = node match {
case table: UnresolvedInlineTable => table.rows(0).size
case local: LocalRelation => local.data(0).numFields
}
val schema = insertTableSchemaWithoutPartitionColumns
val newDefaultExpressions: Seq[Expression] =
getDefaultExpressionsForInsert(numQueryOutputs, schema)
getDefaultExpressionsForInsert(numQueryOutputs, schema, numUserSpecifiedFields, node)
val newNames: Seq[String] = schema.fields.drop(numQueryOutputs).map { _.name }
node match {
case _ if newDefaultExpressions.isEmpty => node
@@ -298,11 +299,12 @@ case class ResolveDefaultColumns(catalog: SessionCatalog) extends Rule[LogicalPlan] {
*/
private def addMissingDefaultValuesForInsertFromProject(
project: Project,
insertTableSchemaWithoutPartitionColumns: StructType): Project = {
insertTableSchemaWithoutPartitionColumns: StructType,
numUserSpecifiedFields: Int): Project = {
val numQueryOutputs: Int = project.projectList.size
val schema = insertTableSchemaWithoutPartitionColumns
val newDefaultExpressions: Seq[Expression] =
getDefaultExpressionsForInsert(numQueryOutputs, schema)
getDefaultExpressionsForInsert(numQueryOutputs, schema, numUserSpecifiedFields, project)
val newAliases: Seq[NamedExpression] =
newDefaultExpressions.zip(schema.fields).map {
case (expr, field) => Alias(expr, field.name)()
@@ -315,20 +317,19 @@ case class ResolveDefaultColumns(catalog: SessionCatalog) extends Rule[LogicalPlan] {
*/
private def getDefaultExpressionsForInsert(
numQueryOutputs: Int,
schema: StructType): Seq[Expression] = {
val remainingFields: Seq[StructField] = schema.fields.drop(numQueryOutputs)
val numDefaultExpressionsToAdd = getStructFieldsForDefaultExpressions(remainingFields).size
Seq.fill(numDefaultExpressionsToAdd)(UnresolvedAttribute(CURRENT_DEFAULT_COLUMN_NAME))
}

/**
* This is a helper for the getDefaultExpressionsForInsert methods above.
*/
private def getStructFieldsForDefaultExpressions(fields: Seq[StructField]): Seq[StructField] = {
if (SQLConf.get.useNullsForMissingDefaultColumnValues) {
fields
schema: StructType,
numUserSpecifiedFields: Int,
treeNode: LogicalPlan): Seq[Expression] = {
if (numUserSpecifiedFields > 0 && numUserSpecifiedFields != numQueryOutputs) {
throw QueryCompilationErrors.writeTableWithMismatchedColumnsError(
numUserSpecifiedFields, numQueryOutputs, treeNode)
}
if (numUserSpecifiedFields > 0 && SQLConf.get.addMissingValuesForInsertsWithExplicitColumns) {
val remainingFields: Seq[StructField] = schema.fields.drop(numQueryOutputs)
val numDefaultExpressionsToAdd = remainingFields.size
Seq.fill(numDefaultExpressionsToAdd)(UnresolvedAttribute(CURRENT_DEFAULT_COLUMN_NAME))
} else {
fields.takeWhile(_.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY))
Seq.empty[Expression]
}
}

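Read together, the rewritten getDefaultExpressionsForInsert above does two things: it rejects an explicit column list whose length differs from the query output, and it pads the remaining columns only when a column list was given and the new flag is on. A plain-Scala paraphrase, using strings in place of the Catalyst StructField and Expression types, may help; it is a sketch of the control flow, not the rule itself:

    // Simplified stand-in for the padding decision above; targetColumnNames plays the
    // role of the insert schema, and "DEFAULT" stands in for the unresolved placeholder
    // that the real rule appends and resolves later.
    def defaultsToAppend(
        numQueryOutputs: Int,
        targetColumnNames: Seq[String],
        numUserSpecifiedFields: Int,
        fillMissingWithDefaults: Boolean): Seq[String] = {
      // An explicit column list must line up with the query output exactly.
      if (numUserSpecifiedFields > 0 && numUserSpecifiedFields != numQueryOutputs) {
        throw new IllegalArgumentException(
          s"column list has $numUserSpecifiedFields names but the query produces $numQueryOutputs values")
      }
      if (numUserSpecifiedFields > 0 && fillMissingWithDefaults) {
        // Pad every remaining target column with a DEFAULT placeholder, resolved later
        // to the column's declared default or NULL.
        targetColumnNames.drop(numQueryOutputs).map(_ => "DEFAULT")
      } else {
        // No column list, or padding disabled: add nothing and let later checks report any mismatch.
        Seq.empty
      }
    }

For example, defaultsToAppend(2, Seq("c1", "c2", "c3"), numUserSpecifiedFields = 2, fillMissingWithDefaults = true) returns Seq("DEFAULT"), while numUserSpecifiedFields = 0 returns nothing, which is the behavioral change named in the PR title.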
@@ -487,8 +488,7 @@ case class ResolveDefaultColumns(catalog: SessionCatalog) extends Rule[LogicalPlan] {
schema.fields.filter {
field => !userSpecifiedColNames.contains(field.name)
}
Some(StructType(userSpecifiedFields ++
getStructFieldsForDefaultExpressions(nonUserSpecifiedFields)))
Some(StructType(userSpecifiedFields ++ nonUserSpecifiedFields))
}

/**
@@ -2924,6 +2924,18 @@ object SQLConf {
.stringConf
.createWithDefault("csv,json,orc,parquet")

val ADD_MISSING_DEFAULT_COLUMN_VALUES_FOR_INSERTS_WITH_EXPLICIT_COLUMNS =
buildConf("spark.sql.defaultColumn.addMissingValuesForInsertsWithExplicitColumns")
.internal()
.doc("When true, allow INSERT INTO commands with explicit columns (such as " +
"INSERT INTO t(a, b)) to specify fewer columns than the target table; the analyzer will " +
"assign default values for remaining columns (either NULL, or otherwise the explicit " +
"DEFAULT value associated with the column from a previous command). Otherwise, if " +
"false, return an error.")
.version("3.4.0")
.booleanConf
.createWithDefault(true)

val JSON_GENERATOR_WRITE_NULL_IF_WITH_DEFAULT_VALUE =
buildConf("spark.sql.jsonGenerator.writeNullIfWithDefaultValue")
.internal()
@@ -2936,17 +2948,6 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES =
buildConf("spark.sql.defaultColumn.useNullsForMissingDefaultValues")
.internal()
.doc("When true, and DEFAULT columns are enabled, allow column definitions lacking " +
"explicit default values to behave as if they had specified DEFAULT NULL instead. " +
"For example, this allows most INSERT INTO statements to specify only a prefix of the " +
"columns in the target table, and the remaining columns will receive NULL values.")
.version("3.4.0")
.booleanConf
.createWithDefault(false)

val ENFORCE_RESERVED_KEYWORDS = buildConf("spark.sql.ansi.enforceReservedKeywords")
.doc(s"When true and '${ANSI_ENABLED.key}' is true, the Spark SQL parser enforces the ANSI " +
"reserved keywords and forbids SQL queries that use reserved keywords as alias names " +
@@ -4530,12 +4531,12 @@ class SQLConf extends Serializable with Logging {

def defaultColumnAllowedProviders: String = getConf(SQLConf.DEFAULT_COLUMN_ALLOWED_PROVIDERS)

def addMissingValuesForInsertsWithExplicitColumns: Boolean =
getConf(SQLConf.ADD_MISSING_DEFAULT_COLUMN_VALUES_FOR_INSERTS_WITH_EXPLICIT_COLUMNS)

def jsonWriteNullIfWithDefaultValue: Boolean =
getConf(JSON_GENERATOR_WRITE_NULL_IF_WITH_DEFAULT_VALUE)

def useNullsForMissingDefaultColumnValues: Boolean =
getConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES)

def enforceReservedKeywords: Boolean = ansiEnabled && getConf(ENFORCE_RESERVED_KEYWORDS)

def timestampType: AtomicType = getConf(TIMESTAMP_TYPE) match {
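The new flag added above is internal but can still be set on a session. A hedged sketch of both settings, continuing the hypothetical session from the first example (the table name `t2` is also made up):

    import org.apache.spark.sql.AnalysisException

    // `spark` is the SparkSession from the earlier sketch.
    // With the flag at its default (true), the omitted column is padded with NULL
    // (no DEFAULT is declared for c2).
    spark.sql("CREATE TABLE t2 (c1 INT, c2 INT) USING parquet")
    spark.sql("INSERT INTO t2 (c1) VALUES (1)")

    // Flipping the flag restores the pre-3.4 behavior: the same statement is rejected at analysis time.
    spark.conf.set("spark.sql.defaultColumn.addMissingValuesForInsertsWithExplicitColumns", "false")
    try {
      spark.sql("INSERT INTO t2 (c1) VALUES (1)")
    } catch {
      case e: AnalysisException => println(s"rejected as expected: ${e.getMessage}")
    }

Keeping the padding behind a flag that defaults to true preserves the 3.4 behavior described in the migration guide while leaving an escape hatch back to the pre-3.4 error.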
@@ -175,45 +175,48 @@ trait SQLInsertTestSuite extends QueryTest with SQLTestUtils {
test("insert with column list - mismatched column list size") {
val msgs = Seq("Cannot write to table due to mismatched user specified column size",
"expected 3 columns but found")
def test: Unit = {
withSQLConf(SQLConf.ENABLE_DEFAULT_COLUMNS.key -> "false",
SQLConf.ENABLE_DEFAULT_COLUMNS.key -> "true") {
withTable("t1") {
val cols = Seq("c1", "c2", "c3")
createTable("t1", cols, Seq("int", "long", "string"))
val e1 = intercept[AnalysisException](sql(s"INSERT INTO t1 (c1, c2) values(1, 2, 3)"))
assert(e1.getMessage.contains(msgs(0)) || e1.getMessage.contains(msgs(1)))
val e2 = intercept[AnalysisException](sql(s"INSERT INTO t1 (c1, c2, c3) values(1, 2)"))
assert(e2.getMessage.contains(msgs(0)) || e2.getMessage.contains(msgs(1)))
Seq(
"INSERT INTO t1 (c1, c2) values(1, 2, 3)",
"INSERT INTO t1 (c1, c2) select 1, 2, 3",
"INSERT INTO t1 (c1, c2, c3) values(1, 2)",
"INSERT INTO t1 (c1, c2, c3) select 1, 2"
).foreach { query =>
val e = intercept[AnalysisException](sql(query))
assert(e.getMessage.contains(msgs(0)) || e.getMessage.contains(msgs(1)))
}
}
}
withSQLConf(SQLConf.ENABLE_DEFAULT_COLUMNS.key -> "false") {
test
}
withSQLConf(SQLConf.ENABLE_DEFAULT_COLUMNS.key -> "true") {
test
}
}

test("insert with column list - mismatched target table out size after rewritten query") {
val v2Msg = "expected 2 columns but found"
val v2Msg = "Cannot write to table due to mismatched user specified column size"
val cols = Seq("c1", "c2", "c3", "c4")

withTable("t1") {
createTable("t1", cols, Seq.fill(4)("int"))
val e1 = intercept[AnalysisException](sql(s"INSERT INTO t1 (c1) values(1)"))
assert(e1.getMessage.contains("target table has 4 column(s) but the inserted data has 1") ||
e1.getMessage.contains("expected 4 columns but found 1") ||
e1.getMessage.contains("not enough data columns") ||
e1.getMessage.contains(v2Msg))
}
withSQLConf(
SQLConf.ADD_MISSING_DEFAULT_COLUMN_VALUES_FOR_INSERTS_WITH_EXPLICIT_COLUMNS.key -> "false") {
withTable("t1") {
createTable("t1", cols, Seq.fill(4)("int"))
val e1 = intercept[AnalysisException](sql(s"INSERT INTO t1 (c1) values(1)"))
assert(e1.getMessage.contains("target table has 4 column(s) but the inserted data has 1") ||
e1.getMessage.contains("expected 4 columns but found 1") ||
e1.getMessage.contains("not enough data columns") ||
e1.getMessage.contains(v2Msg))
}

withTable("t1") {
createTable("t1", cols, Seq.fill(4)("int"), cols.takeRight(2))
val e1 = intercept[AnalysisException] {
sql(s"INSERT INTO t1 partition(c3=3, c4=4) (c1) values(1)")
withTable("t1") {
createTable("t1", cols, Seq.fill(4)("int"), cols.takeRight(2))
val e1 = intercept[AnalysisException] {
sql(s"INSERT INTO t1 partition(c3=3, c4=4) (c1) values(1)")
}
assert(e1.getMessage.contains("target table has 4 column(s) but the inserted data has 3") ||
e1.getMessage.contains("not enough data columns") ||
e1.getMessage.contains(v2Msg))
}
assert(e1.getMessage.contains("target table has 4 column(s) but the inserted data has 3") ||
e1.getMessage.contains("not enough data columns") ||
e1.getMessage.contains(v2Msg))
}
}
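The partitioned case in the test above hinges on column arithmetic: static partition values count toward the target columns. A sketch of the same scenario outside the test harness, with a hypothetical table `p` and the flag switched off so the padding path is skipped:

    // `spark` is the SparkSession from the earlier sketches.
    spark.sql("CREATE TABLE p (c1 INT, c2 INT, c3 INT, c4 INT) USING parquet PARTITIONED BY (c3, c4)")
    spark.conf.set("spark.sql.defaultColumn.addMissingValuesForInsertsWithExplicitColumns", "false")
    // The two static partition values plus the single listed column cover only 3 of the
    // 4 target columns, so with padding disabled the analyzer rejects the statement
    // (e.g. "target table has 4 column(s) but the inserted data has 3").
    spark.sql("INSERT INTO p PARTITION (c3 = 3, c4 = 4) (c1) VALUES (1)")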

Expand Down