
Commit e17df6e

aokolnychyi authored and cloud-fan committed
[SPARK-51290][SQL] Enable filling default values in DSv2 writes
### What changes were proposed in this pull request?

This PR enables filling default values in DSv2 writes.

### Why are the changes needed?

These changes are needed for proper support of default values in DSv2 connectors.

### Does this PR introduce _any_ user-facing change?

Users will be able to omit columns with default values. There is no impact on existing jobs.

### How was this patch tested?

This patch comes with tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #50044 from aokolnychyi/spark-51290.

Authored-by: Anton Okolnychyi <aokolnychyi@apache.org>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 74293cc

10 files changed: 153 additions & 46 deletions
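Before the per-file diffs, a quick illustration of the user-facing change (a minimal sketch: the `testcat` catalog, `foo` provider, and table name are borrowed from the new tests at the bottom of this commit):

```scala
// Declare a column with a default, then append a DataFrame that omits it.
sql("CREATE TABLE testcat.ns1.ns2.tbl (id INT, dep STRING DEFAULT 'hr') USING foo")

Seq(1, 2).toDF("id").writeTo("testcat.ns1.ns2.tbl").append()

sql("SELECT * FROM testcat.ns1.ns2.tbl").show()
// +---+---+
// | id|dep|
// +---+---+
// |  1| hr|
// |  2| hr|
// +---+---+
```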

sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala

Lines changed: 4 additions & 0 deletions

@@ -214,6 +214,10 @@ case class StructField(
     }
   }
 
+  private[sql] def hasExistenceDefaultValue: Boolean = {
+    metadata.contains(EXISTS_DEFAULT_COLUMN_METADATA_KEY)
+  }
+
   private def getDDLDefault = getCurrentDefaultValue()
     .map(" DEFAULT " + _)
     .getOrElse("")
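For context, a hedged sketch of a field carrying an existence default; the literal metadata key strings below are assumptions based on Spark's `EXISTS_DEFAULT`/`CURRENT_DEFAULT` column metadata convention:

```scala
import org.apache.spark.sql.types._

// Assumed keys: EXISTS_DEFAULT records the value backfilled for rows written
// before the column existed; CURRENT_DEFAULT applies to new writes that omit it.
val meta = new MetadataBuilder()
  .putString("EXISTS_DEFAULT", "'initial-text'")
  .putString("CURRENT_DEFAULT", "'initial-text'")
  .build()
val txt = StructField("txt", StringType, nullable = true, metadata = meta)
// txt.hasExistenceDefaultValue now returns true (private[sql], shown for illustration).
```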

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 2 additions & 1 deletion

@@ -3534,7 +3534,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
         TableOutputResolver.suitableForByNameCheck(v2Write.isByName,
           expected = v2Write.table.output, queryOutput = v2Write.query.output)
         val projection = TableOutputResolver.resolveOutputColumns(
-          v2Write.table.name, v2Write.table.output, v2Write.query, v2Write.isByName, conf)
+          v2Write.table.name, v2Write.table.output, v2Write.query, v2Write.isByName, conf,
+          supportColDefaultValue = true)
         if (projection != v2Write.query) {
           val cleanedTable = v2Write.table match {
             case r: DataSourceV2Relation =>
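Conceptually, flipping `supportColDefaultValue` to true lets `TableOutputResolver` pad a DSv2 query that omits a defaulted column instead of rejecting it. A rough sketch of the effect (plan shapes assumed, not the resolver's literal output):

```scala
// Table:  tbl(id INT, txt STRING DEFAULT 'initial-text')
// Query:  SELECT id FROM source
//
// Before: analysis fails with INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA for `txt`.
// After:  the written projection is padded with the column default, roughly:
//   Project(query.output :+ Alias(Literal("initial-text"), "txt")(), query)
```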

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala

Lines changed: 0 additions & 1 deletion

@@ -80,7 +80,6 @@ object TableOutputResolver extends SQLConfHelper with Logging {
       query: LogicalPlan,
       byName: Boolean,
       conf: SQLConf,
-      // TODO: Only DS v1 writing will set it to true. We should enable in for DS v2 as well.
       supportColDefaultValue: Boolean = false): LogicalPlan = {
 
     if (expected.size < query.output.size) {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala

Lines changed: 17 additions & 10 deletions

@@ -40,6 +40,7 @@ import org.apache.spark.sql.internal.connector.V1Function
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.util.ArrayImplicits._
+import org.apache.spark.util.Utils
 
 /**
  * This object contains fields to help process DEFAULT columns.
@@ -120,7 +121,11 @@ object ResolveDefaultColumns extends QueryErrorsBase
       schema.exists(_.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY))) {
       val keywords: Array[String] = SQLConf.get.getConf(SQLConf.DEFAULT_COLUMN_ALLOWED_PROVIDERS)
         .toLowerCase().split(",").map(_.trim)
-      val allowedTableProviders: Array[String] = keywords.map(_.stripSuffix("*"))
+      val allowedTableProviders: Array[String] = if (Utils.isTesting) {
+        "in-memory" +: keywords.map(_.stripSuffix("*"))
+      } else {
+        keywords.map(_.stripSuffix("*"))
+      }
       val addColumnExistingTableBannedProviders: Array[String] =
         keywords.filter(_.endsWith("*")).map(_.stripSuffix("*"))
       val givenTableProvider: String = tableProvider.getOrElse("").toLowerCase()
@@ -459,15 +464,17 @@ object ResolveDefaultColumns extends QueryErrorsBase
    * Any type suitable for assigning into a row using the InternalRow.update method.
    */
   def getExistenceDefaultValues(schema: StructType): Array[Any] = {
-    schema.fields.map { field: StructField =>
-      val defaultValue: Option[String] = field.getExistenceDefaultValue()
-      defaultValue.map { _: String =>
-        val expr = analyzeExistenceDefaultValue(field)
-
-        // The expression should be a literal value by this point, possibly wrapped in a cast
-        // function. This is enforced by the execution of commands that assign default values.
-        expr.eval()
-      }.orNull
+    schema.fields.map(getExistenceDefaultValue)
+  }
+
+  def getExistenceDefaultValue(field: StructField): Any = {
+    if (field.hasExistenceDefaultValue) {
+      val expr = analyzeExistenceDefaultValue(field)
+      // The expression should be a literal value by this point, possibly wrapped in a cast
+      // function. This is enforced by the execution of commands that assign default values.
+      expr.eval()
+    } else {
+      null
     }
   }
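A hedged usage sketch of the refactored helpers; the schema and default literal below are assumptions (in practice the metadata is written by the DDL commands):

```scala
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
import org.apache.spark.sql.types._

val schema = StructType(Seq(
  StructField("id", IntegerType),
  StructField("txt", StringType, nullable = true,
    metadata = new MetadataBuilder().putString("EXISTS_DEFAULT", "'initial-text'").build())))

// One evaluated literal per field: null where no existence default is set.
val defaults: Array[Any] = ResolveDefaultColumns.getExistenceDefaultValues(schema)
// defaults(0) == null; defaults(1) is the UTF8String "initial-text",
// ready to be assigned via InternalRow.update.
```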

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala

Lines changed: 24 additions & 18 deletions

@@ -420,12 +420,14 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest {
 
     val parsedPlan = byName(table, query)
 
-    assertNotResolved(parsedPlan)
-    assertAnalysisErrorCondition(
-      parsedPlan,
-      expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
-      expectedMessageParameters = Map("tableName" -> "`table-name`", "colName" -> "`x`")
-    )
+    withSQLConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES.key -> "false") {
+      assertNotResolved(parsedPlan)
+      assertAnalysisErrorCondition(
+        parsedPlan,
+        expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
+        expectedMessageParameters = Map("tableName" -> "`table-name`", "colName" -> "`x`")
+      )
+    }
   }
 
   test("byName: case sensitive column resolution") {
@@ -435,12 +437,14 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest {
 
     val parsedPlan = byName(table, query)
 
-    assertNotResolved(parsedPlan)
-    assertAnalysisErrorCondition(
-      parsedPlan,
-      expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
-      expectedMessageParameters = Map("tableName" -> "`table-name`", "colName" -> "`x`")
-    )
+    withSQLConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES.key -> "false") {
+      assertNotResolved(parsedPlan)
+      assertAnalysisErrorCondition(
+        parsedPlan,
+        expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
+        expectedMessageParameters = Map("tableName" -> "`table-name`", "colName" -> "`x`")
+      )
+    }
   }
 
   test("byName: case insensitive column resolution") {
@@ -513,12 +517,14 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest {
 
     val parsedPlan = byName(table, query)
 
-    assertNotResolved(parsedPlan)
-    assertAnalysisErrorCondition(
-      parsedPlan,
-      expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
-      expectedMessageParameters = Map("tableName" -> "`table-name`", "colName" -> "`x`")
-    )
+    withSQLConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES.key -> "false") {
+      assertNotResolved(parsedPlan)
+      assertAnalysisErrorCondition(
+        parsedPlan,
+        expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
+        expectedMessageParameters = Map("tableName" -> "`table-name`", "colName" -> "`x`")
+      )
+    }
   }
 
   test("byName: insert safe cast") {

sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala

Lines changed: 18 additions & 7 deletions

@@ -28,7 +28,7 @@ import com.google.common.base.Objects
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow, MetadataStructFieldWithLogicalName}
-import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils}
+import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils, ResolveDefaultColumns}
 import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
 import org.apache.spark.sql.connector.expressions._
 import org.apache.spark.sql.connector.metric.{CustomMetric, CustomSumMetric, CustomTaskMetric}
@@ -141,7 +141,8 @@ abstract class InMemoryBaseTable(
       schema: StructType,
       row: InternalRow): (Any, DataType) = {
     val index = schema.fieldIndex(fieldNames(0))
-    val value = row.toSeq(schema).apply(index)
+    val field = schema(index)
+    val value = row.get(index, field.dataType)
     if (fieldNames.length > 1) {
       (value, schema(index).dataType) match {
         case (row: InternalRow, nestedSchema: StructType) =>
@@ -400,18 +401,23 @@ abstract class InMemoryBaseTable(
     val sizeInBytes = numRows * rowSizeInBytes
 
     val numOfCols = tableSchema.fields.length
-    val dataTypes = tableSchema.fields.map(_.dataType)
-    val colValueSets = new Array[util.HashSet[Object]](numOfCols)
+    val colValueSets = new Array[util.HashSet[Any]](numOfCols)
     val numOfNulls = new Array[Long](numOfCols)
     for (i <- 0 until numOfCols) {
-      colValueSets(i) = new util.HashSet[Object]
+      colValueSets(i) = new util.HashSet[Any]
     }
 
     inputPartitions.foreach(inputPartition =>
       inputPartition.rows.foreach(row =>
         for (i <- 0 until numOfCols) {
-          colValueSets(i).add(row.get(i, dataTypes(i)))
-          if (row.isNullAt(i)) {
+          val field = tableSchema(i)
+          val colValue = if (i < row.numFields) {
+            row.get(i, field.dataType)
+          } else {
+            ResolveDefaultColumns.getExistenceDefaultValue(field)
+          }
+          colValueSets(i).add(colValue)
+          if (colValue == null) {
            numOfNulls(i) += 1
          }
        }
@@ -718,6 +724,11 @@ private class BufferedRowsReader(
       schema: StructType,
       row: InternalRow): Any = {
     val index = schema.fieldIndex(field.name)
+
+    if (index >= row.numFields) {
+      return ResolveDefaultColumns.getExistenceDefaultValue(field)
+    }
+
     field.dataType match {
       case StructType(fields) =>
         if (row.isNullAt(index)) {
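The read-side pattern both hunks above implement, as a standalone sketch (the helper name is ours):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
import org.apache.spark.sql.types.StructField

// Rows persisted before an ALTER TABLE ... ADD COLUMN are shorter than the
// current read schema; instead of indexing past the end of such a row, fall
// back to the column's existence default.
def valueAt(row: InternalRow, index: Int, field: StructField): Any =
  if (index >= row.numFields) ResolveDefaultColumns.getExistenceDefaultValue(field)
  else row.get(index, field.dataType)
```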

sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala

Lines changed: 5 additions & 1 deletion

@@ -128,7 +128,11 @@ class BasicInMemoryTableCatalog extends TableCatalog {
   override def alterTable(ident: Identifier, changes: TableChange*): Table = {
     val table = loadTable(ident).asInstanceOf[InMemoryTable]
     val properties = CatalogV2Util.applyPropertiesChanges(table.properties, changes)
-    val schema = CatalogV2Util.applySchemaChanges(table.schema, changes, None, "ALTER TABLE")
+    val schema = CatalogV2Util.applySchemaChanges(
+      table.schema,
+      changes,
+      tableProvider = Some("in-memory"),
+      statementType = "ALTER TABLE")
     val finalPartitioning = CatalogV2Util.applyClusterByChanges(table.partitioning, schema, changes)
 
     // fail if the last column in the schema was dropped
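Passing `tableProvider = Some("in-memory")` ties into the `Utils.isTesting` change above: the in-memory provider is now on the default-column allowlist in tests, so `applySchemaChanges` accepts DDL such as:

```scala
// From the new DataSourceV2DataFrameSuite test below; previously the provider
// check would have rejected a DEFAULT clause for this catalog.
sql(s"ALTER TABLE $tableName ADD COLUMN txt STRING DEFAULT 'initial-text'")
```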

sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala

Lines changed: 6 additions & 6 deletions

@@ -146,8 +146,8 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
       exception = intercept[AnalysisException] {
         spark.table("source").withColumnRenamed("data", "d").writeTo("testcat.table_name").append()
       },
-      condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
-      parameters = Map("tableName" -> "`testcat`.`table_name`", "colName" -> "`data`")
+      condition = "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_COLUMNS",
+      parameters = Map("tableName" -> "`testcat`.`table_name`", "extraColumns" -> "`d`")
     )
 
     checkAnswer(
@@ -251,8 +251,8 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
         spark.table("source").withColumnRenamed("data", "d")
           .writeTo("testcat.table_name").overwrite(lit(true))
       },
-      condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
-      parameters = Map("tableName" -> "`testcat`.`table_name`", "colName" -> "`data`")
+      condition = "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_COLUMNS",
+      parameters = Map("tableName" -> "`testcat`.`table_name`", "extraColumns" -> "`d`")
     )
 
     checkAnswer(
@@ -356,8 +356,8 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
         spark.table("source").withColumnRenamed("data", "d")
           .writeTo("testcat.table_name").overwritePartitions()
       },
-      condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
-      parameters = Map("tableName" -> "`testcat`.`table_name`", "colName" -> "`data`")
+      condition = "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_COLUMNS",
+      parameters = Map("tableName" -> "`testcat`.`table_name`", "extraColumns" -> "`d`")
     )
 
     checkAnswer(
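The expected error shifts because the missing `data` column can now be filled (with its default, or null), so the first problem the analyzer reports is the unmatched renamed column. A reproducing sketch, taken from the updated test:

```scala
// `source` has columns (id, data); renaming `data` to `d` leaves `d` with no
// match in `testcat.table_name`, while `data` itself is no longer "missing".
val ex = intercept[AnalysisException] {
  spark.table("source").withColumnRenamed("data", "d")
    .writeTo("testcat.table_name").append()
}
// error condition: INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_COLUMNS, extraColumns = `d`
```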

sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala

Lines changed: 2 additions & 2 deletions

@@ -218,10 +218,10 @@ trait SQLInsertTestSuite extends QueryTest with SQLTestUtils with AdaptiveSparkP
         processInsert("t1", df, overwrite = false, byName = true)
       },
       v1ErrorClass = "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_COLUMNS",
-      v2ErrorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
+      v2ErrorClass = "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_COLUMNS",
       v1Parameters = Map("tableName" -> "`spark_catalog`.`default`.`t1`",
         "extraColumns" -> "`x1`"),
-      v2Parameters = Map("tableName" -> "`testcat`.`t1`", "colName" -> "`c1`")
+      v2Parameters = Map("tableName" -> "`testcat`.`t1`", "extraColumns" -> "`x1`")
     )
     val df2 = Seq((3, 2, 1, 0)).toDF(Seq("c3", "c2", "c1", "c0"): _*)
     checkV1AndV2Error(

sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala

Lines changed: 75 additions & 0 deletions

@@ -263,4 +263,79 @@ class DataSourceV2DataFrameSuite
       spark.listenerManager.unregister(listener)
     }
   }
+
+  test("add columns with default values") {
+    val tableName = "testcat.ns1.ns2.tbl"
+    withTable(tableName) {
+      sql(s"CREATE TABLE $tableName (id INT, dep STRING) USING foo")
+
+      val df1 = Seq((1, "hr")).toDF("id", "dep")
+      df1.writeTo(tableName).append()
+
+      sql(s"ALTER TABLE $tableName ADD COLUMN txt STRING DEFAULT 'initial-text'")
+
+      val df2 = Seq((2, "hr"), (3, "software")).toDF("id", "dep")
+      df2.writeTo(tableName).append()
+
+      sql(s"ALTER TABLE $tableName ALTER COLUMN txt SET DEFAULT 'new-text'")
+
+      val df3 = Seq((4, "hr"), (5, "hr")).toDF("id", "dep")
+      df3.writeTo(tableName).append()
+
+      val df4 = Seq((6, "hr", null), (7, "hr", "explicit-text")).toDF("id", "dep", "txt")
+      df4.writeTo(tableName).append()
+
+      sql(s"ALTER TABLE $tableName ALTER COLUMN txt DROP DEFAULT")
+
+      val df5 = Seq((8, "hr"), (9, "hr")).toDF("id", "dep")
+      df5.writeTo(tableName).append()
+
+      checkAnswer(
+        sql(s"SELECT * FROM $tableName"),
+        Seq(
+          Row(1, "hr", "initial-text"),
+          Row(2, "hr", "initial-text"),
+          Row(3, "software", "initial-text"),
+          Row(4, "hr", "new-text"),
+          Row(5, "hr", "new-text"),
+          Row(6, "hr", null),
+          Row(7, "hr", "explicit-text"),
+          Row(8, "hr", null),
+          Row(9, "hr", null)))
+    }
+  }
+
+  test("create/replace table with default values") {
+    val tableName = "testcat.ns1.ns2.tbl"
+    withTable(tableName) {
+      sql(s"CREATE TABLE $tableName (id INT, dep STRING DEFAULT 'hr') USING foo")
+
+      val df1 = Seq(1, 2).toDF("id")
+      df1.writeTo(tableName).append()
+
+      sql(s"ALTER TABLE $tableName ALTER COLUMN dep SET DEFAULT 'it'")
+
+      val df2 = Seq(3, 4).toDF("id")
+      df2.writeTo(tableName).append()
+
+      checkAnswer(
+        sql(s"SELECT * FROM $tableName"),
+        Seq(
+          Row(1, "hr"),
+          Row(2, "hr"),
+          Row(3, "it"),
+          Row(4, "it")))
+
+      sql(s"REPLACE TABLE $tableName (id INT, dep STRING DEFAULT 'unknown') USING foo")
+
+      val df3 = Seq(1, 2).toDF("id")
+      df3.writeTo(tableName).append()
+
+      checkAnswer(
+        sql(s"SELECT * FROM $tableName"),
+        Seq(
+          Row(1, "unknown"),
+          Row(2, "unknown")))
+    }
+  }
 }
