SparkOptimizer.scala
@@ -40,11 +40,11 @@ class SparkOptimizer(
     SchemaPruning,
     GroupBasedRowLevelOperationScanPlanning,
     V1Writes,
+    PushVariantIntoScan,
Reviewer: PushVariantIntoScan now runs before PruneFileSourcePartitions, which I think is for v1 sources. Does this reordering matter, or was PushVariantIntoScan only at the end of the list before because it was simply added later as a new rule?

Contributor Author: I don't think variant columns will ever be used in the partition schema. Schema transformations by PushVariantIntoScan shouldn't affect partition pruning in v1 sources.
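
As a minimal sketch of that reasoning (assuming an fsRelation: HadoopFsRelation and a rewritten newDataSchema: StructType in scope; illustrative, not code from this PR): the rule only replaces dataSchema, and partitionSchema, which v1 partition pruning evaluates its filters against, is carried over unchanged by copy.

// `fsRelation` and `newDataSchema` are assumed to be in scope. copy() only
// swaps dataSchema; partitionSchema, the input that PruneFileSourcePartitions
// matches partition filters against, is carried over unchanged.
val rewritten = fsRelation.copy(dataSchema = newDataSchema)(fsRelation.sparkSession)
assert(rewritten.partitionSchema == fsRelation.partitionSchema)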

     V2ScanRelationPushDown,
     V2ScanPartitioningAndOrdering,
     V2Writes,
-    PruneFileSourcePartitions,
-    PushVariantIntoScan)
+    PruneFileSourcePartitions)

   override def preCBORules: Seq[Rule[LogicalPlan]] =
     Seq(OptimizeMetadataOnlyDeleteFromTable)
PushVariantIntoScan.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._

@@ -279,6 +280,8 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
         relation @ LogicalRelationWithTable(
           hadoopFsRelation@HadoopFsRelation(_, _, _, _, _: ParquetFileFormat, _), _)) =>
         rewritePlan(p, projectList, filters, relation, hadoopFsRelation)
+      case p@PhysicalOperation(projectList, filters, relation: DataSourceV2Relation) =>
Contributor: Is there any code we can share between the v1 rewritePlan and the v2 rewriteV2RelationPlan?

Contributor Author: Yes, there’s shared logic. I intentionally left the v1 rewritePlan unchanged in this PR to keep the diff small and easier to review. After this merges, I’ll do a small follow-up to have v1 rewritePlan reuse the common code. If you prefer, I can fold that refactor into this PR.

Contributor: It's actually harder to review this way, as I can't tell what the key difference is between the v1 and v2 versions in the current PR...

Contributor Author: Sorry for the confusion. I have updated the code.

The logic for transforming variant columns to struct is identical between DSv1 and DSv2. Now they both use the same helper methods (collectAndRewriteVariants, buildAttributeMap, buildFilterAndProject).

The only difference is how the transformed schema is communicated to the data source. DSv1 stores the new schema in HadoopFsRelation.dataSchema and the file source reads this field directly; DSv2 has no schema field to update. The schema is communicated later when V2ScanRelationPushDown calls pruneColumns.
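
To make that hand-off concrete, here is a minimal connector-side sketch (the ExampleVariantScanBuilder name is hypothetical and not part of this PR; ScanBuilder, SupportsPushDownRequiredColumns, and pruneColumns are the actual DSv2 interfaces):

import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.types.StructType

// Hypothetical ScanBuilder showing where the rewritten schema arrives in DSv2.
class ExampleVariantScanBuilder(fullSchema: StructType)
  extends ScanBuilder with SupportsPushDownRequiredColumns {

  private var requiredSchema: StructType = fullSchema

  // V2ScanRelationPushDown calls pruneColumns during optimization. Once
  // PushVariantIntoScan has run, a variant column arrives here already
  // rewritten as a struct whose fields carry VariantMetadata (the requested
  // path and target type), so the scan can read only those fields.
  override def pruneColumns(requiredSchema: StructType): Unit = {
    this.requiredSchema = requiredSchema
  }

  override def build(): Scan = new Scan {
    override def readSchema(): StructType = requiredSchema
  }
}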

+        rewriteV2RelationPlan(p, projectList, filters, relation)
     }
   }

@@ -288,23 +291,91 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
       filters: Seq[Expression],
       relation: LogicalRelation,
       hadoopFsRelation: HadoopFsRelation): LogicalPlan = {
-    val variants = new VariantInRelation

     val schemaAttributes = relation.resolve(hadoopFsRelation.dataSchema,
       hadoopFsRelation.sparkSession.sessionState.analyzer.resolver)
-    val defaultValues = ResolveDefaultColumns.existenceDefaultValues(StructType(
-      schemaAttributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))))
-    for ((a, defaultValue) <- schemaAttributes.zip(defaultValues)) {
-      variants.addVariantFields(a.exprId, a.dataType, defaultValue, Nil)
-    }

+    // Collect variant fields from the relation output
+    val variants = collectAndRewriteVariants(schemaAttributes)
     if (variants.mapping.isEmpty) return originalPlan

+    // Collect requested fields from projections and filters
     projectList.foreach(variants.collectRequestedFields)
     filters.foreach(variants.collectRequestedFields)
     // `collectRequestedFields` may have removed all variant columns.
     if (variants.mapping.forall(_._2.isEmpty)) return originalPlan

+    // Build attribute map with rewritten types
+    val attributeMap = buildAttributeMap(schemaAttributes, variants)
+
+    // Build new schema with variant types replaced by struct types
+    val newFields = schemaAttributes.map { a =>
+      val dataType = attributeMap(a.exprId).dataType
+      StructField(a.name, dataType, a.nullable, a.metadata)
+    }
+    // Update relation output attributes with new types
+    val newOutput = relation.output.map(a => attributeMap.getOrElse(a.exprId, a))
+
+    // Update HadoopFsRelation's data schema so the file source reads the struct columns
+    val newHadoopFsRelation = hadoopFsRelation.copy(dataSchema = StructType(newFields))(
+      hadoopFsRelation.sparkSession)
+    val newRelation = relation.copy(relation = newHadoopFsRelation, output = newOutput.toIndexedSeq)
+
+    // Build filter and project with rewritten expressions
+    buildFilterAndProject(newRelation, projectList, filters, variants, attributeMap)
+  }

+  private def rewriteV2RelationPlan(
+      originalPlan: LogicalPlan,
+      projectList: Seq[NamedExpression],
+      filters: Seq[Expression],
+      relation: DataSourceV2Relation): LogicalPlan = {

+    // Collect variant fields from the relation output
+    val variants = collectAndRewriteVariants(relation.output)
+    if (variants.mapping.isEmpty) return originalPlan

+    // Collect requested fields from projections and filters
+    projectList.foreach(variants.collectRequestedFields)
+    filters.foreach(variants.collectRequestedFields)
+    // `collectRequestedFields` may have removed all variant columns.
+    if (variants.mapping.forall(_._2.isEmpty)) return originalPlan

-    val attributeMap = schemaAttributes.map { a =>
+    // Build attribute map with rewritten types
+    val attributeMap = buildAttributeMap(relation.output, variants)

+    // Update relation output attributes with new types
+    // Note: DSv2 doesn't need to update the schema in the relation itself. The schema will be
+    // communicated to the data source later via V2ScanRelationPushDown.pruneColumns() API.
+    val newOutput = relation.output.map(a => attributeMap.getOrElse(a.exprId, a))
+    val newRelation = relation.copy(output = newOutput.toIndexedSeq)

+    // Build filter and project with rewritten expressions
+    buildFilterAndProject(newRelation, projectList, filters, variants, attributeMap)
+  }

+  /**
+   * Collect variant fields and return initialized VariantInRelation.
+   */
+  private def collectAndRewriteVariants(
+      schemaAttributes: Seq[Attribute]): VariantInRelation = {
+    val variants = new VariantInRelation
+    val defaultValues = ResolveDefaultColumns.existenceDefaultValues(StructType(
+      schemaAttributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))))

+    for ((a, defaultValue) <- schemaAttributes.zip(defaultValues)) {
+      variants.addVariantFields(a.exprId, a.dataType, defaultValue, Nil)
+    }

+    variants
+  }

+  /**
+   * Build attribute map with rewritten variant types.
+   */
+  private def buildAttributeMap(
+      schemaAttributes: Seq[Attribute],
+      variants: VariantInRelation): Map[ExprId, AttributeReference] = {
+    schemaAttributes.map { a =>
       if (variants.mapping.get(a.exprId).exists(_.nonEmpty)) {
         val newType = variants.rewriteType(a.exprId, a.dataType, Nil)
         val newAttr = AttributeReference(a.name, newType, a.nullable, a.metadata)(
@@ -316,21 +387,24 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
         (a.exprId, a.asInstanceOf[AttributeReference])
       }
     }.toMap
-    val newFields = schemaAttributes.map { a =>
-      val dataType = attributeMap(a.exprId).dataType
-      StructField(a.name, dataType, a.nullable, a.metadata)
-    }
-    val newOutput = relation.output.map(a => attributeMap.getOrElse(a.exprId, a))
+  }

-    val newHadoopFsRelation = hadoopFsRelation.copy(dataSchema = StructType(newFields))(
-      hadoopFsRelation.sparkSession)
-    val newRelation = relation.copy(relation = newHadoopFsRelation, output = newOutput.toIndexedSeq)
+  /**
+   * Build the final Project(Filter(relation)) plan with rewritten expressions.
+   */
+  private def buildFilterAndProject(
+      relation: LogicalPlan,
+      projectList: Seq[NamedExpression],
+      filters: Seq[Expression],
+      variants: VariantInRelation,
+      attributeMap: Map[ExprId, AttributeReference]): LogicalPlan = {

     val withFilter = if (filters.nonEmpty) {
-      Filter(filters.map(variants.rewriteExpr(_, attributeMap)).reduce(And), newRelation)
+      Filter(filters.map(variants.rewriteExpr(_, attributeMap)).reduce(And), relation)
     } else {
-      newRelation
+      relation
     }

     val newProjectList = projectList.map { e =>
       val rewritten = variants.rewriteExpr(e, attributeMap)
       rewritten match {
@@ -341,6 +415,7 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
         case _ => Alias(rewritten, e.name)(e.exprId, e.qualifier)
       }
     }

     Project(newProjectList, withFilter)
   }
 }
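
To make the end-to-end effect concrete, here is a hypothetical before/after of the optimized plan for a DSv2 table t(id bigint, v variant) (field names and plan shapes simplified for illustration, not actual EXPLAIN output; the real rule records each requested path in the field's VariantMetadata):

  SELECT id, variant_get(v, '$.a', 'int') AS a
  FROM t
  WHERE variant_get(v, '$.b', 'string') = 'x'

Before PushVariantIntoScan:
  Project [id, variant_get(v, $.a) AS a]
  +- Filter (variant_get(v, $.b) = 'x')
     +- RelationV2 [id bigint, v variant]

After PushVariantIntoScan, with v rewritten to a struct of the requested paths:
  Project [id, v.`$.a` AS a]
  +- Filter (v.`$.b` = 'x')
     +- RelationV2 [id bigint, v struct<`$.a` int, `$.b` string>]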
VariantV2ReadSuite.scala (new file)
@@ -0,0 +1,148 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.datasources.v2

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.execution.datasources.VariantMetadata
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructType, VariantType}

class VariantV2ReadSuite extends QueryTest with SharedSparkSession {

  private val testCatalogClass = "org.apache.spark.sql.connector.catalog.InMemoryTableCatalog"

  private def withV2Catalog(f: => Unit): Unit = {
    withSQLConf(
      SQLConf.DEFAULT_CATALOG.key -> "testcat",
      s"spark.sql.catalog.testcat" -> testCatalogClass,
      SQLConf.USE_V1_SOURCE_LIST.key -> "",
      SQLConf.PUSH_VARIANT_INTO_SCAN.key -> "true",
      SQLConf.VARIANT_ALLOW_READING_SHREDDED.key -> "true") {
      f
    }
  }

test("DSV2: push variant_get fields") {
withV2Catalog {
sql("DROP TABLE IF EXISTS testcat.ns.users")
sql(
"""CREATE TABLE testcat.ns.users (
| id bigint,
| name string,
| v variant,
| vd variant default parse_json('1')
|) USING parquet""".stripMargin)

val out = sql(
"""
|SELECT
| id,
| variant_get(v, '$.username', 'string') as username,
| variant_get(v, '$.age', 'int') as age
|FROM testcat.ns.users
|WHERE variant_get(v, '$.status', 'string') = 'active'
|""".stripMargin)

checkAnswer(out, Seq.empty)

// Verify variant column rewrite
val optimized = out.queryExecution.optimizedPlan
val relOutput = optimized.collectFirst {
case s: DataSourceV2ScanRelation => s.output
}.getOrElse(fail("Expected DSv2 relation in optimized plan"))

val vAttr = relOutput.find(_.name == "v").getOrElse(fail("Missing 'v' column"))
vAttr.dataType match {
case s: StructType =>
assert(s.fields.length == 3,
s"Expected 3 fields (username, age, status), got ${s.fields.length}")
assert(s.fields.forall(_.metadata.contains(VariantMetadata.METADATA_KEY)),
"All fields should have VariantMetadata")

val paths = s.fields.map(f => VariantMetadata.fromMetadata(f.metadata).path).toSet
assert(paths == Set("$.username", "$.age", "$.status"),
s"Expected username, age, status paths, got: $paths")

val fieldTypes = s.fields.map(_.dataType).toSet
assert(fieldTypes.contains(StringType), "Expected StringType for string fields")
assert(fieldTypes.contains(IntegerType), "Expected IntegerType for age")

case other =>
fail(s"Expected StructType for 'v', got: $other")
}

// Verify variant with default value is NOT rewritten
relOutput.find(_.name == "vd").foreach { vdAttr =>
assert(vdAttr.dataType == VariantType,
"Variant column with default value should not be rewritten")
}
}
}

test("DSV2: nested column pruning for variant struct") {
withV2Catalog {
sql("DROP TABLE IF EXISTS testcat.ns.users2")
sql(
"""CREATE TABLE testcat.ns.users2 (
| id bigint,
| name string,
| v variant
|) USING parquet""".stripMargin)

val out = sql(
"""
|SELECT id, variant_get(v, '$.username', 'string') as username
|FROM testcat.ns.users2
|""".stripMargin)

checkAnswer(out, Seq.empty)

val scan = out.queryExecution.executedPlan.collectFirst {
case b: BatchScanExec => b.scan
}.getOrElse(fail("Expected BatchScanExec in physical plan"))

val readSchema = scan.readSchema()

// Verify 'v' field exists and is a struct
val vField = readSchema.fields.find(_.name == "v").getOrElse(
fail("Expected 'v' field in read schema")
)

vField.dataType match {
case s: StructType =>
assert(s.fields.length == 1,
"Expected only 1 field ($.username) in pruned schema, got " + s.fields.length + ": " +
s.fields.map(f => VariantMetadata.fromMetadata(f.metadata).path).mkString(", "))

val field = s.fields(0)
assert(field.metadata.contains(VariantMetadata.METADATA_KEY),
"Field should have VariantMetadata")

val metadata = VariantMetadata.fromMetadata(field.metadata)
assert(metadata.path == "$.username",
"Expected path '$.username', got '" + metadata.path + "'")
assert(field.dataType == StringType,
s"Expected StringType, got ${field.dataType}")

case other =>
fail(s"Expected StructType for 'v' after rewrite and pruning, got: $other")
}
}
}
}
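
Assuming the suite lands under sql/core as its package suggests, it should be runnable with the usual Spark test invocation, e.g. build/sbt "sql/testOnly *VariantV2ReadSuite".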