[SPARK-34331][SQL] Speed up DS v2 metadata col resolution
### What changes were proposed in this pull request?

This is a follow-up of #28027

#28027 added a DS v2 API that allows data sources to produce metadata/hidden columns that are only visible when explicitly selected. The way we integrate this API into Spark is:
1. The v2 relation gets normal output and metadata output from the data source, and the metadata output is excluded from the plan output by default.
2. Column resolution can resolve `UnresolvedAttribute` against metadata columns, even if the child plan doesn't output them.
3. An analyzer rule searches the query plan for nodes with missing inputs. If such a node is found, the rule transforms its sub-plan and updates the v2 relation to include the metadata output.
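
To make these steps concrete, here is a minimal, hypothetical sketch (not part of the PR): it assumes a SparkSession `spark` with a v2 catalog and a table `testcat.t` whose data source exposes a hidden `_partition` metadata column, similar to the in-memory test table touched below.

```scala
// Hypothetical v2 table `testcat.t` with data columns (id, data) and a
// hidden metadata column `_partition` provided by the data source.

spark.sql("SELECT * FROM testcat.t").schema
// -> id, data             (step 1: metadata output is excluded by default)

spark.sql("SELECT id, data, _partition FROM testcat.t").schema
// -> id, data, _partition (step 2: the reference resolves even though the
//    relation does not output it; step 3: AddMetadataColumns rewrites the
//    v2 relation to include `_partition`)
```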

The analyzer rule in step 3 causes a perf regression, even for queries that do not read v2 tables at all. The rule calculates `QueryPlan.inputSet` (which builds an `AttributeSet` from the outputs of all children) and `QueryPlan.missingInput` (which does a set exclusion and creates a new `AttributeSet`) for every node in the query plan. In our benchmark, TPCDS query compilation time increases by more than 10%.

This PR proposes a simple improvement: add a special metadata entry to each metadata attribute, which lets us quickly check whether a plan needs metadata columns added. Instead of calculating `QueryPlan.missingInput`, we only check the references of the plan and see whether any attribute carries the special metadata entry.
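
Condensed, illustrative sketch of the new check (it mirrors the code in the diff below; the inlined `"__metadata_col"` key comes from `DataSourceV2Implicits`):

```scala
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types.{MetadataBuilder, StringType, StructField}

// The data source's metadata columns become fields whose column metadata
// carries a boolean entry under the key "__metadata_col".
val tagged = StructField("_partition", StringType, nullable = false,
  new MetadataBuilder().putBoolean("__metadata_col", true).build())

// The rule now only scans a node's expressions for that entry, which is much
// cheaper than computing `inputSet`/`missingInput` for every plan node.
def hasMetadataCol(plan: LogicalPlan): Boolean =
  plan.expressions.exists(_.find {
    case a: Attribute =>
      a.metadata.contains("__metadata_col") && a.metadata.getBoolean("__metadata_col")
    case _ => false
  }.isDefined)
```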

This PR also fixes a bug: we should not change the final output schema of the plan if metadata columns are only used in operators such as filter or sort.
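
A small sketch of the fixed behavior (the table name is hypothetical; the new test added below covers exactly this case):

```scala
// A metadata column that is referenced only in a filter must not leak into
// the final schema of the query.
val df = spark.table("testcat.t").filter("index = 0")
assert(df.schema.fieldNames.toSeq == Seq("id", "data"))
// Internally, AddMetadataColumns now wraps the rewritten plan in
// Project(originalOutput, newPlan) whenever the output would otherwise change.
```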

### Why are the changes needed?

Fix a perf regression in SQL query compilation, and fix a bug.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Ran `org.apache.spark.sql.TPCDSQuerySuite`. Before this PR, `AddMetadataColumns` was the 4th most expensive rule ranked by total running time:
```
=== Metrics of Analyzer/Optimizer Rules ===
Total number of runs: 407641
Total time: 47.257239779 seconds

Rule                                  Effective Time / Total Time                     Effective Runs / Total Runs

OptimizeSubqueries                      4157690003 / 8485444626                         49 / 2778
Analyzer$ResolveAggregateFunctions      1238968711 / 3369351761                         49 / 2141
ColumnPruning                           660038236 / 2924755292                          338 / 6391
Analyzer$AddMetadataColumns             0 / 2918352992                                  0 / 2151
```
After this PR:
```
Analyzer$AddMetadataColumns             0 / 122885629                                   0 / 2151
```
The rule is now more than 20 times faster, and its running time is negligible compared to the total compilation time.

This PR also adds new tests to verify the bug fix.

Closes #31440 from cloud-fan/metadata-col.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
cloud-fan committed Feb 5, 2021
1 parent ee11a8f commit 989eb68
Showing 4 changed files with 77 additions and 19 deletions.

Changes to the `AddMetadataColumns` rule in `Analyzer`:

```diff
@@ -965,11 +965,37 @@ class Analyzer(override val catalogManager: CatalogManager)
    * columns are not accidentally selected by *.
    */
   object AddMetadataColumns extends Rule[LogicalPlan] {
+    import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
+
+    private def hasMetadataCol(plan: LogicalPlan): Boolean = {
+      plan.expressions.exists(_.find {
+        case a: Attribute => a.isMetadataCol
+        case _ => false
+      }.isDefined)
+    }
+
+    private def addMetadataCol(plan: LogicalPlan): LogicalPlan = plan match {
+      case r: DataSourceV2Relation => r.withMetadataColumns()
+      case _ => plan.withNewChildren(plan.children.map(addMetadataCol))
+    }
+
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
-      case node if node.resolved && node.children.nonEmpty && node.missingInput.nonEmpty =>
-        node resolveOperatorsUp {
-          case rel: DataSourceV2Relation =>
-            rel.withMetadataColumns()
-        }
+      case node if node.children.nonEmpty && node.resolved && hasMetadataCol(node) =>
+        val inputAttrs = AttributeSet(node.children.flatMap(_.output))
+        val metaCols = node.expressions.flatMap(_.collect {
+          case a: Attribute if a.isMetadataCol && !inputAttrs.contains(a) => a
+        })
+        if (metaCols.isEmpty) {
+          node
+        } else {
+          val newNode = addMetadataCol(node)
+          // We should not change the output schema of the plan. We should project away the extra
+          // metadata columns if necessary.
+          if (newNode.sameOutput(node)) {
+            newNode
+          } else {
+            Project(node.output, newNode)
+          }
+        }
     }
   }
```

Changes to `DataSourceV2Implicits`:

```diff
@@ -21,12 +21,14 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.{PartitionSpec, ResolvedPartitionSpec, UnresolvedPartitionSpec}
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
 import org.apache.spark.sql.connector.catalog.{MetadataColumn, SupportsAtomicPartitionManagement, SupportsDelete, SupportsPartitionManagement, SupportsRead, SupportsWrite, Table, TableCapability}
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.{MetadataBuilder, StructField, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 object DataSourceV2Implicits {
+  private val METADATA_COL_ATTR_KEY = "__metadata_col"
+
   implicit class TableHelper(table: Table) {
     def asReadable: SupportsRead = {
       table match {
@@ -83,7 +85,8 @@ object DataSourceV2Implicits {
   implicit class MetadataColumnsHelper(metadata: Array[MetadataColumn]) {
     def asStruct: StructType = {
       val fields = metadata.map { metaCol =>
-        val field = StructField(metaCol.name, metaCol.dataType, metaCol.isNullable)
+        val fieldMeta = new MetadataBuilder().putBoolean(METADATA_COL_ATTR_KEY, true).build()
+        val field = StructField(metaCol.name, metaCol.dataType, metaCol.isNullable, fieldMeta)
         Option(metaCol.comment).map(field.withComment).getOrElse(field)
       }
       StructType(fields)
@@ -92,6 +95,11 @@
     def toAttributes: Seq[AttributeReference] = asStruct.toAttributes
   }
 
+  implicit class MetadataColumnHelper(attr: Attribute) {
+    def isMetadataCol: Boolean = attr.metadata.contains(METADATA_COL_ATTR_KEY) &&
+      attr.metadata.getBoolean(METADATA_COL_ATTR_KEY)
+  }
+
   implicit class OptionsHelper(options: Map[String, String]) {
     def asOptions: CaseInsensitiveStringMap = {
       new CaseInsensitiveStringMap(options.asJava)
```

Changes to the `InMemoryTable` test source:

```diff
@@ -36,7 +36,7 @@ import org.apache.spark.sql.connector.read._
 import org.apache.spark.sql.connector.write._
 import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
 import org.apache.spark.sql.sources.{And, EqualNullSafe, EqualTo, Filter, IsNotNull, IsNull}
-import org.apache.spark.sql.types.{DataType, DateType, StringType, StructField, StructType, TimestampType}
+import org.apache.spark.sql.types.{DataType, DateType, IntegerType, StringType, StructField, StructType, TimestampType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -61,7 +61,7 @@ class InMemoryTable(
 
   private object IndexColumn extends MetadataColumn {
     override def name: String = "index"
-    override def dataType: DataType = StringType
+    override def dataType: DataType = IntegerType
     override def comment: String = "Metadata column used to conflict with a data column"
   }
```

Changes to `DataSourceV2SQLSuite`:

```diff
@@ -142,7 +142,7 @@ class DataSourceV2SQLSuite
       Array("Part 0", "id", ""),
       Array("", "", ""),
       Array("# Metadata Columns", "", ""),
-      Array("index", "string", "Metadata column used to conflict with a data column"),
+      Array("index", "int", "Metadata column used to conflict with a data column"),
       Array("_partition", "string", "Partition key used to store the row"),
       Array("", "", ""),
       Array("# Detailed Table Information", "", ""),
@@ -2443,9 +2443,12 @@
         "PARTITIONED BY (bucket(4, id), id)")
       sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
 
-      checkAnswer(
-        spark.sql(s"SELECT id, data, _partition FROM $t1"),
-        Seq(Row(1, "a", "3/1"), Row(2, "b", "0/2"), Row(3, "c", "1/3")))
+      val sqlQuery = spark.sql(s"SELECT id, data, index, _partition FROM $t1")
+      val dfQuery = spark.table(t1).select("id", "data", "index", "_partition")
+
+      Seq(sqlQuery, dfQuery).foreach { query =>
+        checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+      }
     }
   }
 
@@ -2456,9 +2459,12 @@
         "PARTITIONED BY (bucket(4, index), index)")
       sql(s"INSERT INTO $t1 VALUES (3, 'c'), (2, 'b'), (1, 'a')")
 
-      checkAnswer(
-        spark.sql(s"SELECT index, data, _partition FROM $t1"),
-        Seq(Row(3, "c", "1/3"), Row(2, "b", "0/2"), Row(1, "a", "3/1")))
+      val sqlQuery = spark.sql(s"SELECT index, data, _partition FROM $t1")
+      val dfQuery = spark.table(t1).select("index", "data", "_partition")
+
+      Seq(sqlQuery, dfQuery).foreach { query =>
+        checkAnswer(query, Seq(Row(3, "c", "1/3"), Row(2, "b", "0/2"), Row(1, "a", "3/1")))
+      }
     }
   }
 
@@ -2469,9 +2475,27 @@
         "PARTITIONED BY (bucket(4, id), id)")
       sql(s"INSERT INTO $t1 VALUES (3, 'c'), (2, 'b'), (1, 'a')")
 
-      checkAnswer(
-        spark.sql(s"SELECT * FROM $t1"),
-        Seq(Row(3, "c"), Row(2, "b"), Row(1, "a")))
+      val sqlQuery = spark.sql(s"SELECT * FROM $t1")
+      val dfQuery = spark.table(t1)
+
+      Seq(sqlQuery, dfQuery).foreach { query =>
+        checkAnswer(query, Seq(Row(3, "c"), Row(2, "b"), Row(1, "a")))
+      }
+    }
+  }
+
+  test("SPARK-31255: metadata column should only be produced when necessary") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
+        "PARTITIONED BY (bucket(4, id), id)")
+
+      val sqlQuery = spark.sql(s"SELECT * FROM $t1 WHERE index = 0")
+      val dfQuery = spark.table(t1).filter("index = 0")
+
+      Seq(sqlQuery, dfQuery).foreach { query =>
+        assert(query.schema.fieldNames.toSeq == Seq("id", "data"))
+      }
     }
   }
```
