diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 8edb59f49282e..9699d8a2563fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -40,11 +40,11 @@ class SparkOptimizer(
     SchemaPruning,
     GroupBasedRowLevelOperationScanPlanning,
     V1Writes,
+    PushVariantIntoScan,
     V2ScanRelationPushDown,
     V2ScanPartitioningAndOrdering,
     V2Writes,
-    PruneFileSourcePartitions,
-    PushVariantIntoScan)
+    PruneFileSourcePartitions)
 
   override def preCBORules: Seq[Rule[LogicalPlan]] =
     Seq(OptimizeMetadataOnlyDeleteFromTable)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
index 5960cf8c38ced..6ce53e3367c41 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -279,6 +280,8 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
         relation @ LogicalRelationWithTable(
           hadoopFsRelation@HadoopFsRelation(_, _, _, _, _: ParquetFileFormat, _), _)) =>
         rewritePlan(p, projectList, filters, relation, hadoopFsRelation)
+      case p@PhysicalOperation(projectList, filters, relation: DataSourceV2Relation) =>
+        rewriteV2RelationPlan(p, projectList, filters, relation)
     }
   }
 
@@ -288,23 +291,91 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
       filters: Seq[Expression],
       relation: LogicalRelation,
       hadoopFsRelation: HadoopFsRelation): LogicalPlan = {
-    val variants = new VariantInRelation
-
     val schemaAttributes = relation.resolve(hadoopFsRelation.dataSchema,
       hadoopFsRelation.sparkSession.sessionState.analyzer.resolver)
-    val defaultValues = ResolveDefaultColumns.existenceDefaultValues(StructType(
-      schemaAttributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))))
-    for ((a, defaultValue) <- schemaAttributes.zip(defaultValues)) {
-      variants.addVariantFields(a.exprId, a.dataType, defaultValue, Nil)
+
+    // Collect variant fields from the relation output
+    val variants = collectAndRewriteVariants(schemaAttributes)
+    if (variants.mapping.isEmpty) return originalPlan
+
+    // Collect requested fields from projections and filters
+    projectList.foreach(variants.collectRequestedFields)
+    filters.foreach(variants.collectRequestedFields)
+    // `collectRequestedFields` may have removed all variant columns.
+    if (variants.mapping.forall(_._2.isEmpty)) return originalPlan
+
+    // Build attribute map with rewritten types
+    val attributeMap = buildAttributeMap(schemaAttributes, variants)
+
+    // Build new schema with variant types replaced by struct types
+    val newFields = schemaAttributes.map { a =>
+      val dataType = attributeMap(a.exprId).dataType
+      StructField(a.name, dataType, a.nullable, a.metadata)
     }
+    // Update relation output attributes with new types
+    val newOutput = relation.output.map(a => attributeMap.getOrElse(a.exprId, a))
+
+    // Update HadoopFsRelation's data schema so the file source reads the struct columns
+    val newHadoopFsRelation = hadoopFsRelation.copy(dataSchema = StructType(newFields))(
+      hadoopFsRelation.sparkSession)
+    val newRelation = relation.copy(relation = newHadoopFsRelation, output = newOutput.toIndexedSeq)
+
+    // Build filter and project with rewritten expressions
+    buildFilterAndProject(newRelation, projectList, filters, variants, attributeMap)
+  }
+
+  private def rewriteV2RelationPlan(
+      originalPlan: LogicalPlan,
+      projectList: Seq[NamedExpression],
+      filters: Seq[Expression],
+      relation: DataSourceV2Relation): LogicalPlan = {
+
+    // Collect variant fields from the relation output
+    val variants = collectAndRewriteVariants(relation.output)
     if (variants.mapping.isEmpty) return originalPlan
 
+    // Collect requested fields from projections and filters
     projectList.foreach(variants.collectRequestedFields)
     filters.foreach(variants.collectRequestedFields)
     // `collectRequestedFields` may have removed all variant columns.
     if (variants.mapping.forall(_._2.isEmpty)) return originalPlan
 
-    val attributeMap = schemaAttributes.map { a =>
+    // Build attribute map with rewritten types
+    val attributeMap = buildAttributeMap(relation.output, variants)
+
+    // Update relation output attributes with new types
+    // Note: DSv2 doesn't need to update the schema in the relation itself. The schema will be
+    // communicated to the data source later via V2ScanRelationPushDown.pruneColumns() API.
+    val newOutput = relation.output.map(a => attributeMap.getOrElse(a.exprId, a))
+    val newRelation = relation.copy(output = newOutput.toIndexedSeq)
+
+    // Build filter and project with rewritten expressions
+    buildFilterAndProject(newRelation, projectList, filters, variants, attributeMap)
+  }
+
+  /**
+   * Collect variant fields and return initialized VariantInRelation.
+   */
+  private def collectAndRewriteVariants(
+      schemaAttributes: Seq[Attribute]): VariantInRelation = {
+    val variants = new VariantInRelation
+    val defaultValues = ResolveDefaultColumns.existenceDefaultValues(StructType(
+      schemaAttributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))))
+
+    for ((a, defaultValue) <- schemaAttributes.zip(defaultValues)) {
+      variants.addVariantFields(a.exprId, a.dataType, defaultValue, Nil)
+    }
+
+    variants
+  }
+
+  /**
+   * Build attribute map with rewritten variant types.
+   */
+  private def buildAttributeMap(
+      schemaAttributes: Seq[Attribute],
+      variants: VariantInRelation): Map[ExprId, AttributeReference] = {
+    schemaAttributes.map { a =>
       if (variants.mapping.get(a.exprId).exists(_.nonEmpty)) {
         val newType = variants.rewriteType(a.exprId, a.dataType, Nil)
         val newAttr = AttributeReference(a.name, newType, a.nullable, a.metadata)(
@@ -316,21 +387,24 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
         (a.exprId, a.asInstanceOf[AttributeReference])
       }
     }.toMap
-    val newFields = schemaAttributes.map { a =>
-      val dataType = attributeMap(a.exprId).dataType
-      StructField(a.name, dataType, a.nullable, a.metadata)
-    }
-    val newOutput = relation.output.map(a => attributeMap.getOrElse(a.exprId, a))
+  }
 
-    val newHadoopFsRelation = hadoopFsRelation.copy(dataSchema = StructType(newFields))(
-      hadoopFsRelation.sparkSession)
-    val newRelation = relation.copy(relation = newHadoopFsRelation, output = newOutput.toIndexedSeq)
+  /**
+   * Build the final Project(Filter(relation)) plan with rewritten expressions.
+   */
+  private def buildFilterAndProject(
+      relation: LogicalPlan,
+      projectList: Seq[NamedExpression],
+      filters: Seq[Expression],
+      variants: VariantInRelation,
+      attributeMap: Map[ExprId, AttributeReference]): LogicalPlan = {
 
     val withFilter = if (filters.nonEmpty) {
-      Filter(filters.map(variants.rewriteExpr(_, attributeMap)).reduce(And), newRelation)
+      Filter(filters.map(variants.rewriteExpr(_, attributeMap)).reduce(And), relation)
     } else {
-      newRelation
+      relation
     }
+
     val newProjectList = projectList.map { e =>
       val rewritten = variants.rewriteExpr(e, attributeMap)
       rewritten match {
@@ -341,6 +415,7 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
         case _ => Alias(rewritten, e.name)(e.exprId, e.qualifier)
       }
     }
+
     Project(newProjectList, withFilter)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/VariantV2ReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/VariantV2ReadSuite.scala
new file mode 100644
index 0000000000000..a6521dfe76da1
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/VariantV2ReadSuite.scala
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.execution.datasources.VariantMetadata +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{IntegerType, StringType, StructType, VariantType} + +class VariantV2ReadSuite extends QueryTest with SharedSparkSession { + + private val testCatalogClass = "org.apache.spark.sql.connector.catalog.InMemoryTableCatalog" + + private def withV2Catalog(f: => Unit): Unit = { + withSQLConf( + SQLConf.DEFAULT_CATALOG.key -> "testcat", + s"spark.sql.catalog.testcat" -> testCatalogClass, + SQLConf.USE_V1_SOURCE_LIST.key -> "", + SQLConf.PUSH_VARIANT_INTO_SCAN.key -> "true", + SQLConf.VARIANT_ALLOW_READING_SHREDDED.key -> "true") { + f + } + } + + test("DSV2: push variant_get fields") { + withV2Catalog { + sql("DROP TABLE IF EXISTS testcat.ns.users") + sql( + """CREATE TABLE testcat.ns.users ( + | id bigint, + | name string, + | v variant, + | vd variant default parse_json('1') + |) USING parquet""".stripMargin) + + val out = sql( + """ + |SELECT + | id, + | variant_get(v, '$.username', 'string') as username, + | variant_get(v, '$.age', 'int') as age + |FROM testcat.ns.users + |WHERE variant_get(v, '$.status', 'string') = 'active' + |""".stripMargin) + + checkAnswer(out, Seq.empty) + + // Verify variant column rewrite + val optimized = out.queryExecution.optimizedPlan + val relOutput = optimized.collectFirst { + case s: DataSourceV2ScanRelation => s.output + }.getOrElse(fail("Expected DSv2 relation in optimized plan")) + + val vAttr = relOutput.find(_.name == "v").getOrElse(fail("Missing 'v' column")) + vAttr.dataType match { + case s: StructType => + assert(s.fields.length == 3, + s"Expected 3 fields (username, age, status), got ${s.fields.length}") + assert(s.fields.forall(_.metadata.contains(VariantMetadata.METADATA_KEY)), + "All fields should have VariantMetadata") + + val paths = s.fields.map(f => VariantMetadata.fromMetadata(f.metadata).path).toSet + assert(paths == Set("$.username", "$.age", "$.status"), + s"Expected username, age, status paths, got: $paths") + + val fieldTypes = s.fields.map(_.dataType).toSet + assert(fieldTypes.contains(StringType), "Expected StringType for string fields") + assert(fieldTypes.contains(IntegerType), "Expected IntegerType for age") + + case other => + fail(s"Expected StructType for 'v', got: $other") + } + + // Verify variant with default value is NOT rewritten + relOutput.find(_.name == "vd").foreach { vdAttr => + assert(vdAttr.dataType == VariantType, + "Variant column with default value should not be rewritten") + } + } + } + + test("DSV2: nested column pruning for variant struct") { + withV2Catalog { + sql("DROP TABLE IF EXISTS testcat.ns.users2") + sql( + """CREATE TABLE testcat.ns.users2 ( + | id bigint, + | name string, + | v variant + |) USING parquet""".stripMargin) + + val out = sql( + """ + |SELECT id, variant_get(v, '$.username', 'string') as username + |FROM testcat.ns.users2 + |""".stripMargin) + + checkAnswer(out, Seq.empty) + + val scan = out.queryExecution.executedPlan.collectFirst { + case b: BatchScanExec => b.scan + }.getOrElse(fail("Expected BatchScanExec in physical plan")) + + val readSchema = scan.readSchema() + + // Verify 'v' field exists and is a struct + val vField = readSchema.fields.find(_.name == "v").getOrElse( + fail("Expected 'v' field in read schema") + ) + + vField.dataType match { + case s: StructType => + 
+          assert(s.fields.length == 1,
+            "Expected only 1 field ($.username) in pruned schema, got " + s.fields.length + ": " +
+              s.fields.map(f => VariantMetadata.fromMetadata(f.metadata).path).mkString(", "))
+
+          val field = s.fields(0)
+          assert(field.metadata.contains(VariantMetadata.METADATA_KEY),
+            "Field should have VariantMetadata")
+
+          val metadata = VariantMetadata.fromMetadata(field.metadata)
+          assert(metadata.path == "$.username",
+            "Expected path '$.username', got '" + metadata.path + "'")
+          assert(field.dataType == StringType,
+            s"Expected StringType, got ${field.dataType}")
+
+        case other =>
+          fail(s"Expected StructType for 'v' after rewrite and pruning, got: $other")
+      }
+    }
+  }
+}