diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
index 9b8d48c3f3a8..789f1b0ca000 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
@@ -182,7 +182,13 @@ case class DataSourceV2ScanRelation(
       relation = this.relation.copy(
         output = this.relation.output.map(QueryPlan.normalizeExpressions(_, this.relation.output))
       ),
-      output = this.output.map(QueryPlan.normalizeExpressions(_, this.output))
+      output = this.output.map(QueryPlan.normalizeExpressions(_, this.output)),
+      keyGroupedPartitioning = keyGroupedPartitioning.map(
+        _.map(QueryPlan.normalizeExpressions(_, output))
+      ),
+      ordering = ordering.map(
+        _.map(o => o.copy(child = QueryPlan.normalizeExpressions(o.child, output)))
+      )
     )
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
index 01fa2b13b86f..a09b7e0827c4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
@@ -21,6 +21,8 @@ import java.io.File
 import java.util
 import java.util.OptionalLong
 
+import scala.jdk.CollectionConverters._
+
 import test.org.apache.spark.sql.connector._
 
 import org.apache.spark.SparkUnsupportedOperationException
@@ -37,7 +39,7 @@ import org.apache.spark.sql.connector.read.Scan.ColumnarSupportMode
 import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.execution.SortExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation}
+import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation, V2ScanPartitioningAndOrdering}
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
 import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
@@ -1008,6 +1010,52 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
       "Canonicalized DataSourceV2ScanRelation instances should be equal")
   }
 
+  test("SPARK-54163: scan canonicalization for partitioning and ordering aware data source") {
+    val options = new CaseInsensitiveStringMap(Map(
+      "partitionKeys" -> "i",
+      "orderKeys" -> "i,j"
+    ).asJava)
+    val table = new OrderAndPartitionAwareDataSource().getTable(options)
+
+    def createDsv2ScanRelation(): DataSourceV2ScanRelation = {
+      val relation = DataSourceV2Relation.create(table, None, None, options)
+      val scan = relation.table.asReadable.newScanBuilder(relation.options).build()
+      val scanRelation = DataSourceV2ScanRelation(relation, scan, relation.output)
+      // Attach partitioning and ordering information to the DataSourceV2ScanRelation
+      V2ScanPartitioningAndOrdering.apply(scanRelation).asInstanceOf[DataSourceV2ScanRelation]
+    }
+
+    // Create two DataSourceV2ScanRelation instances representing scans of the same table
+    val scanRelation1 = createDsv2ScanRelation()
+    val scanRelation2 = createDsv2ScanRelation()
+
+    // Assert that the scan relations have partitioning and ordering
+    assert(scanRelation1.keyGroupedPartitioning.isDefined &&
+      scanRelation1.keyGroupedPartitioning.get.nonEmpty,
+      "DataSourceV2ScanRelation should have key grouped partitioning")
+    assert(scanRelation1.ordering.isDefined && scanRelation1.ordering.get.nonEmpty,
+      "DataSourceV2ScanRelation should have ordering")
+
+    // The two instances should not be equal, as they have different attribute IDs
+    assert(scanRelation1 != scanRelation2,
+      "Two created DataSourceV2ScanRelation instances should not be the same")
+    assert(scanRelation1.output.map(_.exprId).toSet != scanRelation2.output.map(_.exprId).toSet,
+      "Output attributes should have different expression IDs before canonicalization")
+    assert(scanRelation1.relation.output.map(_.exprId).toSet !=
+      scanRelation2.relation.output.map(_.exprId).toSet,
+      "Relation output attributes should have different expression IDs before canonicalization")
+    assert(scanRelation1.keyGroupedPartitioning.get.flatMap(_.references.map(_.exprId)).toSet !=
+      scanRelation2.keyGroupedPartitioning.get.flatMap(_.references.map(_.exprId)).toSet,
+      "Partitioning columns should have different expression IDs before canonicalization")
+    assert(scanRelation1.ordering.get.flatMap(_.references.map(_.exprId)).toSet !=
+      scanRelation2.ordering.get.flatMap(_.references.map(_.exprId)).toSet,
+      "Ordering columns should have different expression IDs before canonicalization")
+
+    // After canonicalization, the two instances should be equal
+    assert(scanRelation1.canonicalized == scanRelation2.canonicalized,
+      "Canonicalized DataSourceV2ScanRelation instances should be equal")
+  }
+
   test("SPARK-53809: check mergeScalarSubqueries is effective for DataSourceV2ScanRelation") {
     val df = spark.read.format(classOf[SimpleDataSourceV2].getName).load()
     df.createOrReplaceTempView("df")
@@ -1052,6 +1100,64 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
     // Verify the query produces correct results
     checkAnswer(query, Row(9, 0))
   }
+
+  test(
+    "SPARK-54163: check mergeScalarSubqueries is effective for OrderAndPartitionAwareDataSource"
+  ) {
+    withSQLConf(SQLConf.V2_BUCKETING_ENABLED.key -> "true") {
+      val options = Map(
+        "partitionKeys" -> "i",
+        "orderKeys" -> "i,j"
+      )
+
+      // Create the OrderAndPartitionAwareDataSource DataFrame
+      val df = spark.read
+        .format(classOf[OrderAndPartitionAwareDataSource].getName)
+        .options(options)
+        .load()
+      df.createOrReplaceTempView("df")
+
+      val query = sql("select (select max(i) from df) as max_i, (select min(i) from df) as min_i")
+      val optimizedPlan = query.queryExecution.optimizedPlan
+
+      // Check that optimizedPlan merged the scalar subqueries `select max(i), min(i) from df`
+      val sub1 = optimizedPlan.asInstanceOf[Project].projectList.head.collect {
+        case s: ScalarSubquery => s
+      }
+      val sub2 = optimizedPlan.asInstanceOf[Project].projectList(1).collect {
+        case s: ScalarSubquery => s
+      }
+
+      // Both subqueries should reference the same merged plan `select max(i), min(i) from df`
+      assert(sub1.nonEmpty && sub2.nonEmpty, "Both scalar subqueries should exist")
+      assert(sub1.head.plan == sub2.head.plan,
+        "Both subqueries should reference the same merged plan")
+
+      // Extract the aggregate from the merged plan referenced by sub1
+      val agg = sub1.head.plan.collect {
+        case a: Aggregate => a
+      }.head
+
+      // Check that the aggregate contains both max(i) and min(i)
+      val aggFunctionSet = agg.aggregateExpressions.flatMap { expr =>
+        expr.collect {
+          case ae: org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression =>
+            ae.aggregateFunction
+        }
+      }.toSet
+
+      assert(aggFunctionSet.size == 2, "Aggregate should contain exactly two aggregate functions")
+      assert(aggFunctionSet
+        .exists(_.isInstanceOf[org.apache.spark.sql.catalyst.expressions.aggregate.Max]),
+        "Aggregate should contain max(i)")
+      assert(aggFunctionSet
+        .exists(_.isInstanceOf[org.apache.spark.sql.catalyst.expressions.aggregate.Min]),
+        "Aggregate should contain min(i)")
+
+      // Verify the query produces correct results
+      checkAnswer(query, Row(4, 1))
+    }
+  }
 }
 
 case class RangeInputPartition(start: Int, end: Int) extends InputPartition
@@ -1093,6 +1199,18 @@ abstract class SimpleScanBuilder extends ScanBuilder
   override def readSchema(): StructType = TestingV2Source.schema
 
   override def createReaderFactory(): PartitionReaderFactory = SimpleReaderFactory
+
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case s: Scan =>
+        this.readSchema() == s.readSchema()
+      case _ => false
+    }
+  }
+
+  override def hashCode(): Int = {
+    this.readSchema().hashCode()
+  }
 }
 
 trait TestingV2Source extends TableProvider {
@@ -1157,18 +1275,6 @@ class SimpleDataSourceV2 extends TestingV2Source {
     override def planInputPartitions(): Array[InputPartition] = {
       Array(RangeInputPartition(0, 5), RangeInputPartition(5, 10))
     }
-
-    override def equals(obj: Any): Boolean = {
-      obj match {
-        case s: Scan =>
-          this.readSchema() == s.readSchema()
-        case _ => false
-      }
-    }
-
-    override def hashCode(): Int = {
-      this.readSchema().hashCode()
-    }
   }
 
   override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable {