[SPARK-32168][SQL] Fix hidden partitioning correctness bug in SQL overwrite

### What changes were proposed in this pull request?

When converting an `INSERT OVERWRITE` query to a v2 overwrite plan, Spark attempts to detect when a dynamic overwrite and a static overwrite will produce the same result, so that it can use the static overwrite. This detection is incorrect when the table has hidden partitions, such as `days(ts)`: the two overwrites are treated as equivalent when they are not.

This updates the analyzer rule `ResolveInsertInto` to always use a dynamic overwrite when the partition overwrite mode is dynamic, and a static overwrite when the mode is static. This avoids the problem by no longer trying to determine whether the two plans are equivalent; the rule simply builds the plan that corresponds to the configured partition overwrite mode.
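
As a rough sketch of what this means from the user's side (not code from the change itself; the catalog, namespace, table, and view names are illustrative, and an active `SparkSession` with a v2 catalog is assumed):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().getOrCreate()

// Dynamic mode: the command is planned as OverwritePartitionsDynamic, so only the
// partitions present in the query output are replaced.
spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
spark.sql("INSERT OVERWRITE TABLE testcat.ns.events SELECT ts, data FROM staging")

// Static mode: the command is planned as OverwriteByExpression, which deletes the rows
// matching the static partition values (all rows, if no static values are given) and
// then appends the query output.
spark.conf.set("spark.sql.sources.partitionOverwriteMode", "static")
spark.sql("INSERT OVERWRITE TABLE testcat.ns.events SELECT ts, data FROM staging")
```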

### Why are the changes needed?

This is a correctness bug. If a table has hidden partitions, an overwrite drops all of the existing values for those partitions instead of dynamically overwriting only the changed partitions.

This only affects SQL commands (not `DataFrameWriter`) writing to tables that have hidden partitions. It is also only a problem when the partition overwrite mode is dynamic.
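
A minimal reproduction sketch, modeled on the new `DataSourceV2SQLSuite` test added below; it reuses the session from the sketch above and assumes a v2 catalog named `testcat` that supports the `days` transform (the `foo` format matches the test suites' in-memory provider):

```scala
// Dynamic overwrite mode is required to hit the bug.
spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")

spark.sql("CREATE TABLE testcat.ns.tbl (ts timestamp, data string) " +
  "USING foo PARTITIONED BY (days(ts))")
spark.sql("INSERT INTO testcat.ns.tbl VALUES " +
  "(CAST(date_add('2020-01-01', 2) AS timestamp), 'dummy'), " +  // 2020-01-03
  "(CAST(date_add('2020-01-01', 4) AS timestamp), 'keep')")      // 2020-01-05

// Overwrites the day partitions for 2020-01-02 .. 2020-01-04.
spark.sql("INSERT OVERWRITE TABLE testcat.ns.tbl VALUES " +
  "(CAST(date_add('2020-01-01', 1) AS timestamp), 'a'), " +
  "(CAST(date_add('2020-01-01', 2) AS timestamp), 'b'), " +
  "(CAST(date_add('2020-01-01', 3) AS timestamp), 'c')")

// With the fix, the 2020-01-05 row ('keep') survives because its day partition is not
// present in the new data. Before the fix, the command was planned as a static
// overwrite and that row was dropped as well.
spark.sql("SELECT * FROM testcat.ns.tbl ORDER BY ts").show()
```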

### Does this PR introduce _any_ user-facing change?

Yes, it fixes the correctness bug detailed above.

### How was this patch tested?

* This updates the in-memory table to support a hidden partition transform, `days`, and adds a test case to `DataSourceV2SQLSuite` in which the table uses this hidden partition function (the `days` partition key computation is sketched after this list). This test fails without the fix to `ResolveInsertInto`.
* This updates the test case `InsertInto: overwrite - multiple static partitions - dynamic mode` in `InsertIntoTests`. The result of the SQL command is unchanged, but the command is now planned as a dynamic overwrite, so the test now uses `dynamicOverwriteTest`.
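
For reference, a tiny standalone sketch of the `days` partition key computation used by the updated in-memory table (mirroring the `DaysTransform` branch of `getKey` in the diff below; the helper name here is hypothetical):

```scala
import java.time.Instant
import java.time.temporal.ChronoUnit

// A days(ts) partition key is the number of whole days between the Unix epoch and ts.
def daysKey(ts: Instant): Long = ChronoUnit.DAYS.between(Instant.EPOCH, ts)

// 2020-01-03 and 2020-01-05 land in different day partitions (18264 vs 18266),
// which is why only the former is replaced by the overwrite in the sketch above.
daysKey(Instant.parse("2020-01-03T00:00:00Z"))  // 18264
daysKey(Instant.parse("2020-01-05T00:00:00Z"))  // 18266
```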

Closes #28993 from rdblue/fix-insert-overwrite-v2-conversion.

Authored-by: Ryan Blue <blue@apache.org>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
(cherry picked from commit 3bb1ac5)
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
rdblue authored and dongjoon-hyun committed Jul 8, 2020
1 parent ac2c6cd commit 30e3fcb
Showing 6 changed files with 107 additions and 37 deletions.
@@ -1041,12 +1041,10 @@ class Analyzer(

val staticPartitions = i.partitionSpec.filter(_._2.isDefined).mapValues(_.get)
val query = addStaticPartitionColumns(r, i.query, staticPartitions)
val dynamicPartitionOverwrite = partCols.size > staticPartitions.size &&
conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC

if (!i.overwrite) {
AppendData.byPosition(r, query)
} else if (dynamicPartitionOverwrite) {
} else if (conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC) {
OverwritePartitionsDynamic.byPosition(r, query)
} else {
OverwriteByExpression.byPosition(r, query, staticDeleteExpression(r, staticPartitions))
@@ -17,6 +17,8 @@

package org.apache.spark.sql.connector

import java.time.{Instant, ZoneId}
import java.time.temporal.ChronoUnit
import java.util

import scala.collection.JavaConverters._
@@ -25,12 +25,13 @@ import scala.collection.mutable
import org.scalatest.Assertions._

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.expressions.{IdentityTransform, NamedReference, Transform}
import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, HoursTransform, IdentityTransform, MonthsTransform, Transform, YearsTransform}
import org.apache.spark.sql.connector.read._
import org.apache.spark.sql.connector.write._
import org.apache.spark.sql.sources.{And, EqualTo, Filter, IsNotNull}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{DataType, DateType, StructType, TimestampType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap

/**
@@ -46,10 +46,15 @@ class InMemoryTable(
private val allowUnsupportedTransforms =
properties.getOrDefault("allow-unsupported-transforms", "false").toBoolean

partitioning.foreach { t =>
if (!t.isInstanceOf[IdentityTransform] && !allowUnsupportedTransforms) {
throw new IllegalArgumentException(s"Transform $t must be IdentityTransform")
}
partitioning.foreach {
case _: IdentityTransform =>
case _: YearsTransform =>
case _: MonthsTransform =>
case _: DaysTransform =>
case _: HoursTransform =>
case _: BucketTransform =>
case t if !allowUnsupportedTransforms =>
throw new IllegalArgumentException(s"Transform $t is not a supported transform")
}

// The key `Seq[Any]` is the partition values.
@@ -66,8 +74,14 @@
}
}

private val UTC = ZoneId.of("UTC")
private val EPOCH_LOCAL_DATE = Instant.EPOCH.atZone(UTC).toLocalDate

private def getKey(row: InternalRow): Seq[Any] = {
def extractor(fieldNames: Array[String], schema: StructType, row: InternalRow): Any = {
def extractor(
fieldNames: Array[String],
schema: StructType,
row: InternalRow): (Any, DataType) = {
val index = schema.fieldIndex(fieldNames(0))
val value = row.toSeq(schema).apply(index)
if (fieldNames.length > 1) {
@@ -78,10 +92,44 @@
throw new IllegalArgumentException(s"Unsupported type, ${dataType.simpleString}")
}
} else {
value
(value, schema(index).dataType)
}
}
partCols.map(fieldNames => extractor(fieldNames, schema, row))

partitioning.map {
case IdentityTransform(ref) =>
extractor(ref.fieldNames, schema, row)._1
case YearsTransform(ref) =>
extractor(ref.fieldNames, schema, row) match {
case (days: Int, DateType) =>
ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days))
case (micros: Long, TimestampType) =>
val localDate = DateTimeUtils.microsToInstant(micros).atZone(UTC).toLocalDate
ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate)
}
case MonthsTransform(ref) =>
extractor(ref.fieldNames, schema, row) match {
case (days: Int, DateType) =>
ChronoUnit.MONTHS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days))
case (micros: Long, TimestampType) =>
val localDate = DateTimeUtils.microsToInstant(micros).atZone(UTC).toLocalDate
ChronoUnit.MONTHS.between(EPOCH_LOCAL_DATE, localDate)
}
case DaysTransform(ref) =>
extractor(ref.fieldNames, schema, row) match {
case (days, DateType) =>
days
case (micros: Long, TimestampType) =>
ChronoUnit.DAYS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros))
}
case HoursTransform(ref) =>
extractor(ref.fieldNames, schema, row) match {
case (micros: Long, TimestampType) =>
ChronoUnit.HOURS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros))
}
case BucketTransform(numBuckets, ref) =>
(extractor(ref.fieldNames, schema, row).hashCode() & Integer.MAX_VALUE) % numBuckets
}
}

def withData(data: Array[BufferedRows]): InMemoryTable = dataMap.synchronized {
@@ -40,7 +40,7 @@ case class BatchScanExec(

override def hashCode(): Int = batch.hashCode()

override lazy val partitions: Seq[InputPartition] = batch.planInputPartitions()
@transient override lazy val partitions: Seq[InputPartition] = batch.planInputPartitions()

override lazy val readerFactory: PartitionReaderFactory = batch.createReaderFactory()

@@ -336,7 +336,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
spark.table("source")
.withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp"))
.writeTo("testcat.table_name")
.tableProperty("allow-unsupported-transforms", "true")
.partitionedBy(years($"ts"))
.create()

@@ -350,7 +349,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
spark.table("source")
.withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp"))
.writeTo("testcat.table_name")
.tableProperty("allow-unsupported-transforms", "true")
.partitionedBy(months($"ts"))
.create()

@@ -364,7 +362,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
spark.table("source")
.withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp"))
.writeTo("testcat.table_name")
.tableProperty("allow-unsupported-transforms", "true")
.partitionedBy(days($"ts"))
.create()

@@ -378,7 +375,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
spark.table("source")
.withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp"))
.writeTo("testcat.table_name")
.tableProperty("allow-unsupported-transforms", "true")
.partitionedBy(hours($"ts"))
.create()

@@ -391,7 +387,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
test("Create: partitioned by bucket(4, id)") {
spark.table("source")
.writeTo("testcat.table_name")
.tableProperty("allow-unsupported-transforms", "true")
.partitionedBy(bucket(4, $"id"))
.create()

@@ -596,7 +591,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
lit("2019-09-02 07:00:00.000000").cast("timestamp") as "modified",
lit("America/Los_Angeles") as "timezone"))
.writeTo("testcat.table_name")
.tableProperty("allow-unsupported-transforms", "true")
.partitionedBy(
years($"ts.created"), months($"ts.created"), days($"ts.created"), hours($"ts.created"),
years($"ts.modified"), months($"ts.modified"), days($"ts.modified"), hours($"ts.modified")
@@ -624,7 +618,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
lit("2019-09-02 07:00:00.000000").cast("timestamp") as "modified",
lit("America/Los_Angeles") as "timezone"))
.writeTo("testcat.table_name")
.tableProperty("allow-unsupported-transforms", "true")
.partitionedBy(bucket(4, $"ts.timezone"))
.create()

@@ -17,6 +17,9 @@

package org.apache.spark.sql.connector

import java.sql.Timestamp
import java.time.LocalDate

import scala.collection.JavaConverters._

import org.apache.spark.SparkException
@@ -27,7 +30,7 @@ import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME
import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION
import org.apache.spark.sql.internal.SQLConf.{PARTITION_OVERWRITE_MODE, PartitionOverwriteMode, V2_SESSION_CATALOG_IMPLEMENTATION}
import org.apache.spark.sql.internal.connector.SimpleTableProvider
import org.apache.spark.sql.sources.SimpleScanSource
import org.apache.spark.sql.types.{BooleanType, LongType, StringType, StructField, StructType}
@@ -1630,7 +1633,6 @@ class DataSourceV2SQLSuite
"""
|CREATE TABLE testcat.t (id int, `a.b` string) USING foo
|CLUSTERED BY (`a.b`) INTO 4 BUCKETS
|OPTIONS ('allow-unsupported-transforms'=true)
""".stripMargin)

val testCatalog = catalog("testcat").asTableCatalog.asInstanceOf[InMemoryTableCatalog]
@@ -2476,6 +2478,38 @@ class DataSourceV2SQLSuite
}
}

test("SPARK-32168: INSERT OVERWRITE - hidden days partition - dynamic mode") {
def testTimestamp(daysOffset: Int): Timestamp = {
Timestamp.valueOf(LocalDate.of(2020, 1, 1 + daysOffset).atStartOfDay())
}

withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) {
val t1 = s"${catalogAndNamespace}tbl"
withTable(t1) {
val df = spark.createDataFrame(Seq(
(testTimestamp(1), "a"),
(testTimestamp(2), "b"),
(testTimestamp(3), "c"))).toDF("ts", "data")
df.createOrReplaceTempView("source_view")

sql(s"CREATE TABLE $t1 (ts timestamp, data string) " +
s"USING $v2Format PARTITIONED BY (days(ts))")
sql(s"INSERT INTO $t1 VALUES " +
s"(CAST(date_add('2020-01-01', 2) AS timestamp), 'dummy'), " +
s"(CAST(date_add('2020-01-01', 4) AS timestamp), 'keep')")
sql(s"INSERT OVERWRITE TABLE $t1 SELECT ts, data FROM source_view")

val expected = spark.createDataFrame(Seq(
(testTimestamp(1), "a"),
(testTimestamp(2), "b"),
(testTimestamp(3), "c"),
(testTimestamp(4), "keep"))).toDF("ts", "data")

verifyTable(t1, expected)
}
}
}

private def testV1Command(sqlCommand: String, sqlParams: String): Unit = {
val e = intercept[AnalysisException] {
sql(s"$sqlCommand $sqlParams")
@@ -446,21 +446,18 @@ trait InsertIntoSQLOnlyTests
}
}

test("InsertInto: overwrite - multiple static partitions - dynamic mode") {
// Since all partitions are provided statically, this should be supported by everyone
withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) {
val t1 = s"${catalogAndNamespace}tbl"
withTableAndData(t1) { view =>
sql(s"CREATE TABLE $t1 (id bigint, data string, p int) " +
s"USING $v2Format PARTITIONED BY (id, p)")
sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)")
sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id = 2, p = 2) SELECT data FROM $view")
verifyTable(t1, Seq(
(2, "a", 2),
(2, "b", 2),
(2, "c", 2),
(4, "keep", 2)).toDF("id", "data", "p"))
}
dynamicOverwriteTest("InsertInto: overwrite - multiple static partitions - dynamic mode") {
val t1 = s"${catalogAndNamespace}tbl"
withTableAndData(t1) { view =>
sql(s"CREATE TABLE $t1 (id bigint, data string, p int) " +
s"USING $v2Format PARTITIONED BY (id, p)")
sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)")
sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id = 2, p = 2) SELECT data FROM $view")
verifyTable(t1, Seq(
(2, "a", 2),
(2, "b", 2),
(2, "c", 2),
(4, "keep", 2)).toDF("id", "data", "p"))
}
}
