From cd8e0d7fdc299ebf81a96d585d174c76f5e9cbe9 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 5 Oct 2025 20:37:17 -0700 Subject: [PATCH 1/4] [SPARK-53805][SQL] Push Variant into DSv2 scan --- .../spark/sql/execution/SparkOptimizer.scala | 4 +- .../datasources/PushVariantIntoScan.scala | 101 ++++++++++++ .../datasources/v2/VariantV2ReadSuite.scala | 148 ++++++++++++++++++ 3 files changed, 251 insertions(+), 2 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/VariantV2ReadSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 8edb59f49282e..9699d8a2563fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -40,11 +40,11 @@ class SparkOptimizer( SchemaPruning, GroupBasedRowLevelOperationScanPlanning, V1Writes, + PushVariantIntoScan, V2ScanRelationPushDown, V2ScanPartitioningAndOrdering, V2Writes, - PruneFileSourcePartitions, - PushVariantIntoScan) + PruneFileSourcePartitions) override def preCBORules: Seq[Rule[LogicalPlan]] = Seq(OptimizeMetadataOnlyDeleteFromTable) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala index 5960cf8c38ced..c30b8adbdd05c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -279,6 +280,8 @@ object PushVariantIntoScan extends Rule[LogicalPlan] { relation @ LogicalRelationWithTable( hadoopFsRelation@HadoopFsRelation(_, _, _, _, _: ParquetFileFormat, _), _)) => rewritePlan(p, projectList, filters, relation, hadoopFsRelation) + case p@PhysicalOperation(projectList, filters, relation: DataSourceV2Relation) => + rewriteV2RelationPlan(p, projectList, filters, relation.output, relation) } } @@ -343,4 +346,102 @@ object PushVariantIntoScan extends Rule[LogicalPlan] { } Project(newProjectList, withFilter) } + + private def rewriteV2RelationPlan( + originalPlan: LogicalPlan, + projectList: Seq[NamedExpression], + filters: Seq[Expression], + relationOutput: Seq[AttributeReference], + relation: LogicalPlan): LogicalPlan = { + + // Collect variant fields from the relation output + val (variants, attributeMap) = collectAndRewriteVariants(relationOutput) + if (attributeMap.isEmpty) return originalPlan + + // Collect requested fields from projections and filters + projectList.foreach(variants.collectRequestedFields) + filters.foreach(variants.collectRequestedFields) + if (variants.mapping.forall(_._2.isEmpty)) return originalPlan + + // Build attribute map with rewritten types + val finalAttributeMap = buildAttributeMap(relationOutput, variants) + + // Rewrite the relation with new output + val newRelation = relation match { + case r: DataSourceV2Relation => + val 
newOutput = r.output.map(a => finalAttributeMap.getOrElse(a.exprId, a)) + r.copy(output = newOutput.toIndexedSeq) + case _ => return originalPlan + } + + // Build filter and project with rewritten expressions + buildFilterAndProject(newRelation, projectList, filters, variants, finalAttributeMap) + } + + /** + * Collect variant fields and return initialized VariantInRelation. + */ + private def collectAndRewriteVariants( + schemaAttributes: Seq[AttributeReference]): (VariantInRelation, Map[ExprId, Attribute]) = { + val variants = new VariantInRelation + val defaultValues = ResolveDefaultColumns.existenceDefaultValues(StructType( + schemaAttributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))) + + for ((a, defaultValue) <- schemaAttributes.zip(defaultValues)) { + variants.addVariantFields(a.exprId, a.dataType, defaultValue, Nil) + } + + val attributeMap = if (variants.mapping.isEmpty) { + Map.empty[ExprId, Attribute] + } else { + schemaAttributes.map(a => (a.exprId, a)).toMap + } + + (variants, attributeMap) + } + + /** + * Build attribute map with rewritten variant types. + */ + private def buildAttributeMap( + schemaAttributes: Seq[AttributeReference], + variants: VariantInRelation): Map[ExprId, AttributeReference] = { + schemaAttributes.map { a => + if (variants.mapping.get(a.exprId).exists(_.nonEmpty)) { + val newType = variants.rewriteType(a.exprId, a.dataType, Nil) + val newAttr = AttributeReference(a.name, newType, a.nullable, a.metadata)( + qualifier = a.qualifier) + (a.exprId, newAttr) + } else { + (a.exprId, a) + } + }.toMap + } + + /** + * Build the final Project(Filter(relation)) plan with rewritten expressions. + */ + private def buildFilterAndProject( + relation: LogicalPlan, + projectList: Seq[NamedExpression], + filters: Seq[Expression], + variants: VariantInRelation, + attributeMap: Map[ExprId, AttributeReference]): LogicalPlan = { + + val withFilter = if (filters.nonEmpty) { + Filter(filters.map(variants.rewriteExpr(_, attributeMap)).reduce(And), relation) + } else { + relation + } + + val newProjectList = projectList.map { e => + val rewritten = variants.rewriteExpr(e, attributeMap) + rewritten match { + case n: NamedExpression => n + case _ => Alias(rewritten, e.name)(e.exprId, e.qualifier) + } + } + + Project(newProjectList, withFilter) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/VariantV2ReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/VariantV2ReadSuite.scala new file mode 100644 index 0000000000000..67b85ebea9e9f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/VariantV2ReadSuite.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.execution.datasources.VariantMetadata +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{IntegerType, StringType, StructType, VariantType} + +class VariantV2ReadSuite extends QueryTest with SharedSparkSession { + + private val testCatalogClass = "org.apache.spark.sql.connector.catalog.InMemoryTableCatalog" + + private def withV2Catalog(f: => Unit): Unit = { + withSQLConf( + SQLConf.DEFAULT_CATALOG.key -> "testcat", + s"spark.sql.catalog.testcat" -> testCatalogClass, + SQLConf.USE_V1_SOURCE_LIST.key -> "", + SQLConf.PUSH_VARIANT_INTO_SCAN.key -> "true", + SQLConf.VARIANT_ALLOW_READING_SHREDDED.key -> "true") { + f + } + } + + test("DSV2: push variant_get fields") { + withV2Catalog { + sql("DROP TABLE IF EXISTS testcat.ns.users") + sql( + """CREATE TABLE testcat.ns.users ( + | id bigint, + | name string, + | v variant, + | vd variant default parse_json('1') + |) USING parquet""".stripMargin) + + val out = sql( + """ + |SELECT + | id, + | variant_get(v, '$.username', 'string') as username, + | variant_get(v, '$.age', 'int') as age + |FROM testcat.ns.users + |WHERE variant_get(v, '$.status', 'string') = 'active' + |""".stripMargin) + + checkAnswer(out, Seq.empty) + + // Verify variant column rewrite + val optimized = out.queryExecution.optimizedPlan + val relOutput = optimized.collectFirst { + case s: DataSourceV2ScanRelation => s.output + }.getOrElse(fail("Expected DSv2 relation in optimized plan")) + + val vAttr = relOutput.find(_.name == "v").getOrElse(fail("Missing 'v' column")) + vAttr.dataType match { + case s: StructType => + assert(s.fields.length == 3, + s"Expected 3 fields (username, age, status), got ${s.fields.length}") + assert(s.fields.forall(_.metadata.contains(VariantMetadata.METADATA_KEY)), + "All fields should have VariantMetadata") + + val paths = s.fields.map(f => VariantMetadata.fromMetadata(f.metadata).path).toSet + assert(paths == Set("$.username", "$.age", "$.status"), + s"Expected username, age, status paths, got: $paths") + + val fieldTypes = s.fields.map(_.dataType).toSet + assert(fieldTypes.contains(StringType), "Expected StringType for string fields") + assert(fieldTypes.contains(IntegerType), "Expected IntegerType for age") + + case other => + fail(s"Expected StructType for 'v', got: $other") + } + + // Verify variant with default value is NOT rewritten + relOutput.find(_.name == "vd").foreach { vdAttr => + assert(vdAttr.dataType == VariantType, + "Variant column with default value should not be rewritten") + } + } + } + + test("DSV2: nested column pruning for variant struct") { + withV2Catalog { + sql("DROP TABLE IF EXISTS testcat.ns.users2") + sql( + """CREATE TABLE testcat.ns.users2 ( + | id bigint, + | name string, + | v variant + |) USING parquet""".stripMargin) + + val out = sql( + """ + |SELECT id, variant_get(v, '$.username', 'string') as username + |FROM testcat.ns.users2 + |""".stripMargin) + + checkAnswer(out, Seq.empty) + + val scan = out.queryExecution.executedPlan.collectFirst { + case b: BatchScanExec => b.scan + }.getOrElse(fail("Expected BatchScanExec in physical plan")) + + val readSchema = scan.readSchema() + + // Verify 'v' field exists and is a struct + val vField = readSchema.fields.find(_.name == "v").getOrElse( + fail("Expected 'v' field in read schema") + ) + + vField.dataType match { + case s: StructType => + 
assert(s.fields.length == 1, + "Expected only 1 field ($.username) in pruned schema, got " + s.fields.length + ": " + + s.fields.map(f => VariantMetadata.fromMetadata(f.metadata).path).mkString(", ")) + + val field = s.fields(0) + assert(field.metadata.contains(VariantMetadata.METADATA_KEY), + "Field should have VariantMetadata") + + val metadata = VariantMetadata.fromMetadata(field.metadata) + assert(metadata.path == "$.username", + "Expected path '$.username', got '" + metadata.path + "'") + assert(field.dataType == StringType, + s"Expected StringType, got ${field.dataType}") + + case other => + fail(s"Expected StructType for 'v' after rewrite and pruning, got: $other") + } + } + } +} \ No newline at end of file From c8b9df565467a498df3a70fcad8e4cc6512cb1ab Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 5 Oct 2025 23:37:39 -0700 Subject: [PATCH 2/4] add new line at end of file --- .../spark/sql/execution/datasources/v2/VariantV2ReadSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/VariantV2ReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/VariantV2ReadSuite.scala index 67b85ebea9e9f..a6521dfe76da1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/VariantV2ReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/VariantV2ReadSuite.scala @@ -145,4 +145,4 @@ class VariantV2ReadSuite extends QueryTest with SharedSparkSession { } } } -} \ No newline at end of file +} From 2092fce4aefd3cfb8435a389cf871118535ed0bd Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 7 Oct 2025 16:19:43 -0700 Subject: [PATCH 3/4] address comments --- .../datasources/PushVariantIntoScan.scala | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala index c30b8adbdd05c..4a1ad63edcb8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala @@ -281,7 +281,7 @@ object PushVariantIntoScan extends Rule[LogicalPlan] { hadoopFsRelation@HadoopFsRelation(_, _, _, _, _: ParquetFileFormat, _), _)) => rewritePlan(p, projectList, filters, relation, hadoopFsRelation) case p@PhysicalOperation(projectList, filters, relation: DataSourceV2Relation) => - rewriteV2RelationPlan(p, projectList, filters, relation.output, relation) + rewriteV2RelationPlan(p, projectList, filters, relation) } } @@ -351,11 +351,10 @@ object PushVariantIntoScan extends Rule[LogicalPlan] { originalPlan: LogicalPlan, projectList: Seq[NamedExpression], filters: Seq[Expression], - relationOutput: Seq[AttributeReference], - relation: LogicalPlan): LogicalPlan = { + relation: DataSourceV2Relation): LogicalPlan = { // Collect variant fields from the relation output - val (variants, attributeMap) = collectAndRewriteVariants(relationOutput) + val (variants, attributeMap) = collectAndRewriteVariants(relation.output) if (attributeMap.isEmpty) return originalPlan // Collect requested fields from projections and filters @@ -364,15 +363,11 @@ object PushVariantIntoScan extends Rule[LogicalPlan] { if (variants.mapping.forall(_._2.isEmpty)) return originalPlan // Build attribute map with rewritten types - val 
finalAttributeMap = buildAttributeMap(relationOutput, variants) + val finalAttributeMap = buildAttributeMap(relation.output, variants) // Rewrite the relation with new output - val newRelation = relation match { - case r: DataSourceV2Relation => - val newOutput = r.output.map(a => finalAttributeMap.getOrElse(a.exprId, a)) - r.copy(output = newOutput.toIndexedSeq) - case _ => return originalPlan - } + val newOutput = relation.output.map(a => finalAttributeMap.getOrElse(a.exprId, a)) + val newRelation = relation.copy(output = newOutput.toIndexedSeq) // Build filter and project with rewritten expressions buildFilterAndProject(newRelation, projectList, filters, variants, finalAttributeMap) From 9bb25cc3a4bb1d9ae32aca6181e62d01197dcfad Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 9 Oct 2025 09:29:58 -0700 Subject: [PATCH 4/4] reuse common code --- .../datasources/PushVariantIntoScan.scala | 81 +++++++------------ 1 file changed, 30 insertions(+), 51 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala index 4a1ad63edcb8c..6ce53e3367c41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala @@ -291,60 +291,37 @@ object PushVariantIntoScan extends Rule[LogicalPlan] { filters: Seq[Expression], relation: LogicalRelation, hadoopFsRelation: HadoopFsRelation): LogicalPlan = { - val variants = new VariantInRelation - val schemaAttributes = relation.resolve(hadoopFsRelation.dataSchema, hadoopFsRelation.sparkSession.sessionState.analyzer.resolver) - val defaultValues = ResolveDefaultColumns.existenceDefaultValues(StructType( - schemaAttributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))) - for ((a, defaultValue) <- schemaAttributes.zip(defaultValues)) { - variants.addVariantFields(a.exprId, a.dataType, defaultValue, Nil) - } + + // Collect variant fields from the relation output + val variants = collectAndRewriteVariants(schemaAttributes) if (variants.mapping.isEmpty) return originalPlan + // Collect requested fields from projections and filters projectList.foreach(variants.collectRequestedFields) filters.foreach(variants.collectRequestedFields) // `collectRequestedFields` may have removed all variant columns. if (variants.mapping.forall(_._2.isEmpty)) return originalPlan - val attributeMap = schemaAttributes.map { a => - if (variants.mapping.get(a.exprId).exists(_.nonEmpty)) { - val newType = variants.rewriteType(a.exprId, a.dataType, Nil) - val newAttr = AttributeReference(a.name, newType, a.nullable, a.metadata)( - qualifier = a.qualifier) - (a.exprId, newAttr) - } else { - // `relation.resolve` actually returns `Seq[AttributeReference]`, although the return type - // is `Seq[Attribute]`. 
- (a.exprId, a.asInstanceOf[AttributeReference]) - } - }.toMap + // Build attribute map with rewritten types + val attributeMap = buildAttributeMap(schemaAttributes, variants) + + // Build new schema with variant types replaced by struct types val newFields = schemaAttributes.map { a => val dataType = attributeMap(a.exprId).dataType StructField(a.name, dataType, a.nullable, a.metadata) } + // Update relation output attributes with new types val newOutput = relation.output.map(a => attributeMap.getOrElse(a.exprId, a)) + // Update HadoopFsRelation's data schema so the file source reads the struct columns val newHadoopFsRelation = hadoopFsRelation.copy(dataSchema = StructType(newFields))( hadoopFsRelation.sparkSession) val newRelation = relation.copy(relation = newHadoopFsRelation, output = newOutput.toIndexedSeq) - val withFilter = if (filters.nonEmpty) { - Filter(filters.map(variants.rewriteExpr(_, attributeMap)).reduce(And), newRelation) - } else { - newRelation - } - val newProjectList = projectList.map { e => - val rewritten = variants.rewriteExpr(e, attributeMap) - rewritten match { - case n: NamedExpression => n - // This is when the variant column is directly selected. We replace the attribute reference - // with a struct access, which is not a `NamedExpression` that `Project` requires. We wrap - // it with an `Alias`. - case _ => Alias(rewritten, e.name)(e.exprId, e.qualifier) - } - } - Project(newProjectList, withFilter) + // Build filter and project with rewritten expressions + buildFilterAndProject(newRelation, projectList, filters, variants, attributeMap) } private def rewriteV2RelationPlan( @@ -354,30 +331,33 @@ object PushVariantIntoScan extends Rule[LogicalPlan] { relation: DataSourceV2Relation): LogicalPlan = { // Collect variant fields from the relation output - val (variants, attributeMap) = collectAndRewriteVariants(relation.output) - if (attributeMap.isEmpty) return originalPlan + val variants = collectAndRewriteVariants(relation.output) + if (variants.mapping.isEmpty) return originalPlan // Collect requested fields from projections and filters projectList.foreach(variants.collectRequestedFields) filters.foreach(variants.collectRequestedFields) + // `collectRequestedFields` may have removed all variant columns. if (variants.mapping.forall(_._2.isEmpty)) return originalPlan // Build attribute map with rewritten types - val finalAttributeMap = buildAttributeMap(relation.output, variants) + val attributeMap = buildAttributeMap(relation.output, variants) - // Rewrite the relation with new output - val newOutput = relation.output.map(a => finalAttributeMap.getOrElse(a.exprId, a)) + // Update relation output attributes with new types + // Note: DSv2 doesn't need to update the schema in the relation itself. The schema will be + // communicated to the data source later via V2ScanRelationPushDown.pruneColumns() API. + val newOutput = relation.output.map(a => attributeMap.getOrElse(a.exprId, a)) val newRelation = relation.copy(output = newOutput.toIndexedSeq) // Build filter and project with rewritten expressions - buildFilterAndProject(newRelation, projectList, filters, variants, finalAttributeMap) + buildFilterAndProject(newRelation, projectList, filters, variants, attributeMap) } /** * Collect variant fields and return initialized VariantInRelation. 
*/ private def collectAndRewriteVariants( - schemaAttributes: Seq[AttributeReference]): (VariantInRelation, Map[ExprId, Attribute]) = { + schemaAttributes: Seq[Attribute]): VariantInRelation = { val variants = new VariantInRelation val defaultValues = ResolveDefaultColumns.existenceDefaultValues(StructType( schemaAttributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))) @@ -386,20 +366,14 @@ object PushVariantIntoScan extends Rule[LogicalPlan] { variants.addVariantFields(a.exprId, a.dataType, defaultValue, Nil) } - val attributeMap = if (variants.mapping.isEmpty) { - Map.empty[ExprId, Attribute] - } else { - schemaAttributes.map(a => (a.exprId, a)).toMap - } - - (variants, attributeMap) + variants } /** * Build attribute map with rewritten variant types. */ private def buildAttributeMap( - schemaAttributes: Seq[AttributeReference], + schemaAttributes: Seq[Attribute], variants: VariantInRelation): Map[ExprId, AttributeReference] = { schemaAttributes.map { a => if (variants.mapping.get(a.exprId).exists(_.nonEmpty)) { @@ -408,7 +382,9 @@ object PushVariantIntoScan extends Rule[LogicalPlan] { qualifier = a.qualifier) (a.exprId, newAttr) } else { - (a.exprId, a) + // `relation.resolve` actually returns `Seq[AttributeReference]`, although the return type + // is `Seq[Attribute]`. + (a.exprId, a.asInstanceOf[AttributeReference]) } }.toMap } @@ -433,6 +409,9 @@ object PushVariantIntoScan extends Rule[LogicalPlan] { val rewritten = variants.rewriteExpr(e, attributeMap) rewritten match { case n: NamedExpression => n + // This is when the variant column is directly selected. We replace the attribute reference + // with a struct access, which is not a `NamedExpression` that `Project` requires. We wrap + // it with an `Alias`. case _ => Alias(rewritten, e.name)(e.exprId, e.qualifier) } }
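Not part of the patch series above — a minimal sketch of how downstream code might observe the effect of PushVariantIntoScan on a DSv2 scan, assuming an illustrative table testcat.ns.users with a variant column v (the object and method names here are hypothetical; the plan inspection mirrors the assertions in VariantV2ReadSuite):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.execution.datasources.VariantMetadata
    import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
    import org.apache.spark.sql.types.StructType

    object VariantPushdownSketch {
      // Returns the set of variant paths that the optimizer pushed into the scan.
      // After PushVariantIntoScan (and V2ScanRelationPushDown pruning), the variant
      // column `v` appears in the scan output as a struct whose fields each carry
      // VariantMetadata recording the requested path.
      def requestedPaths(spark: SparkSession): Set[String] = {
        val df = spark.sql(
          "SELECT variant_get(v, '$.username', 'string') FROM testcat.ns.users")
        val scanOutput = df.queryExecution.optimizedPlan.collectFirst {
          case s: DataSourceV2ScanRelation => s.output
        }.getOrElse(Seq.empty)
        scanOutput.find(_.name == "v").map(_.dataType) match {
          case Some(s: StructType) =>
            s.fields.map(f => VariantMetadata.fromMetadata(f.metadata).path).toSet
          case _ => Set.empty // column not rewritten (e.g. default value present)
        }
      }
    }

For the query above, requestedPaths would yield Set("$.username"), matching the pruned read schema checked in the "nested column pruning" test.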