From 354a4198da03e456f3d6a2d8bf0bbe0f1f2c8319 Mon Sep 17 00:00:00 2001 From: pavle-martinovic_data Date: Thu, 27 Mar 2025 10:55:19 +0100 Subject: [PATCH] Refactor to merge PushProjectionThroughLimit and PushProjectionThroughOffset --- .../sql/catalyst/optimizer/Optimizer.scala | 3 +- ...PushProjectionThroughLimitAndOffset.scala} | 14 +-- .../PushProjectionThroughOffset.scala | 35 -------- ...rojectionThroughLimitAndOffsetSuite.scala} | 62 ++++++++++++- .../PushProjectionThroughLimitSuite.scala | 90 ------------------- .../spark/sql/execution/SparkOptimizer.scala | 2 +- .../python/ExtractPythonUDFsSuite.scala | 4 +- 7 files changed, 71 insertions(+), 139 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/{PushProjectionThroughLimit.scala => PushProjectionThroughLimitAndOffset.scala} (75%) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughOffset.scala rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/{PushProjectionThroughOffsetSuite.scala => PushProjectionThroughLimitAndOffsetSuite.scala} (66%) delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughLimitSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 7a57c17879951..3727b3ea19ed9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -100,8 +100,7 @@ abstract class Optimizer(catalogManager: CatalogManager) Seq( // Operator push down PushProjectionThroughUnion, - PushProjectionThroughLimit, - PushProjectionThroughOffset, + PushProjectionThroughLimitAndOffset, ReorderJoin, EliminateOuterJoin, PushDownPredicates, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughLimit.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughLimitAndOffset.scala similarity index 75% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughLimit.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughLimitAndOffset.scala index 6280cc5e42c9f..e329251c36083 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughLimit.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughLimitAndOffset.scala @@ -17,16 +17,16 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.plans.logical.{GlobalLimit, LocalLimit, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{GlobalLimit, LocalLimit, LogicalPlan, Offset, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreePattern.{LIMIT, PROJECT} +import org.apache.spark.sql.catalyst.trees.TreePattern.{LIMIT, OFFSET, PROJECT} /** * Pushes Project operator through Limit operator. */ -object PushProjectionThroughLimit extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( - _.containsAllPatterns(PROJECT, LIMIT)) { +object PushProjectionThroughLimitAndOffset extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(treeBits => + treeBits.containsPattern(PROJECT) && treeBits.containsAnyPattern(OFFSET, LIMIT)) { case p @ Project(projectList, limit @ LocalLimit(_, child)) if projectList.forall(_.deterministic) => @@ -35,5 +35,9 @@ object PushProjectionThroughLimit extends Rule[LogicalPlan] { case p @ Project(projectList, g @ GlobalLimit(_, limit @ LocalLimit(_, child))) if projectList.forall(_.deterministic) => g.copy(child = limit.copy(child = p.copy(projectList, child))) + + case p @ Project(projectList, offset @ Offset(_, child)) + if projectList.forall(_.deterministic) => + offset.copy(child = p.copy(projectList, child)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughOffset.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughOffset.scala deleted file mode 100644 index 498b0131b1b62..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughOffset.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.optimizer - -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Offset, Project} -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreePattern.{OFFSET, PROJECT} - -/** - * Pushes Project operator through Offset operator. - */ -object PushProjectionThroughOffset extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( - _.containsAllPatterns(PROJECT, OFFSET)) { - - case p @ Project(projectList, offset @ Offset(_, child)) - if projectList.forall(_.deterministic) => - offset.copy(child = p.copy(projectList, child)) - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughOffsetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughLimitAndOffsetSuite.scala similarity index 66% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughOffsetSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughLimitAndOffsetSuite.scala index 7b0a5b0fd2e0b..9a57630ebc13b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughOffsetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughLimitAndOffsetSuite.scala @@ -25,17 +25,71 @@ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor -class PushProjectionThroughOffsetSuite extends PlanTest { - +class PushProjectionThroughLimitAndOffsetSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Optimizer Batch", FixedPoint(100), - PushProjectionThroughLimit, - PushProjectionThroughOffset, + PushProjectionThroughLimitAndOffset, EliminateLimits, LimitPushDown) :: Nil } + test("SPARK-40501: push projection through limit") { + val testRelation = LocalRelation.fromExternalRows( + Seq("a".attr.int, "b".attr.int, "c".attr.int), + 1.to(20).map(_ => Row(1, 2, 3))) + + val query1 = testRelation + .limit(10) + .select(Symbol("a"), Symbol("b"), 'c') + .limit(15).analyze + val optimized1 = Optimize.execute(query1) + val expected1 = testRelation + .select(Symbol("a"), Symbol("b"), 'c') + .limit(10).analyze + comparePlans(optimized1, expected1) + + val query2 = testRelation + .sortBy($"a".asc) + .limit(10) + .select(Symbol("a"), Symbol("b"), 'c') + .limit(15).analyze + val optimized2 = Optimize.execute(query2) + val expected2 = testRelation + .sortBy($"a".asc) + .select(Symbol("a"), Symbol("b"), 'c') + .limit(10).analyze + comparePlans(optimized2, expected2) + + val query3 = testRelation + .limit(10) + .select(Symbol("a"), Symbol("b"), 'c') + .limit(20) + .select(Symbol("a")) + .limit(15).analyze + val optimized3 = Optimize.execute(query3) + val expected3 = testRelation + .select(Symbol("a"), Symbol("b"), 'c') + .select(Symbol("a")) + .limit(10).analyze + comparePlans(optimized3, expected3) + + val query4 = testRelation + .sortBy($"a".asc) + .limit(10) + .select(Symbol("a"), Symbol("b"), 'c') + .limit(20) + .select(Symbol("a")) + .limit(15).analyze + val optimized4 = Optimize.execute(query4) + val expected4 = testRelation + .sortBy($"a".asc) + .select(Symbol("a"), Symbol("b"), 'c') + .select(Symbol("a")) + .limit(10).analyze + comparePlans(optimized4, expected4) + } + test("push projection through offset") { val testRelation = LocalRelation.fromExternalRows( Seq("a".attr.int, "b".attr.int, "c".attr.int), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughLimitSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughLimitSuite.scala deleted file mode 100644 index 9af73158ee732..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughLimitSuite.scala +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.optimizer - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} -import org.apache.spark.sql.catalyst.rules.RuleExecutor - -class PushProjectionThroughLimitSuite extends PlanTest { - object Optimize extends RuleExecutor[LogicalPlan] { - val batches = Batch("Optimizer Batch", - FixedPoint(100), - PushProjectionThroughLimit, - EliminateLimits) :: Nil - } - - test("SPARK-40501: push projection through limit") { - val testRelation = LocalRelation.fromExternalRows( - Seq("a".attr.int, "b".attr.int, "c".attr.int), - 1.to(20).map(_ => Row(1, 2, 3))) - - val query1 = testRelation - .limit(10) - .select(Symbol("a"), Symbol("b"), 'c') - .limit(15).analyze - val optimized1 = Optimize.execute(query1) - val expected1 = testRelation - .select(Symbol("a"), Symbol("b"), 'c') - .limit(10).analyze - comparePlans(optimized1, expected1) - - val query2 = testRelation - .sortBy($"a".asc) - .limit(10) - .select(Symbol("a"), Symbol("b"), 'c') - .limit(15).analyze - val optimized2 = Optimize.execute(query2) - val expected2 = testRelation - .sortBy($"a".asc) - .select(Symbol("a"), Symbol("b"), 'c') - .limit(10).analyze - comparePlans(optimized2, expected2) - - val query3 = testRelation - .limit(10) - .select(Symbol("a"), Symbol("b"), 'c') - .limit(20) - .select(Symbol("a")) - .limit(15).analyze - val optimized3 = Optimize.execute(query3) - val expected3 = testRelation - .select(Symbol("a"), Symbol("b"), 'c') - .select(Symbol("a")) - .limit(10).analyze - comparePlans(optimized3, expected3) - - val query4 = testRelation - .sortBy($"a".asc) - .limit(10) - .select(Symbol("a"), Symbol("b"), 'c') - .limit(20) - .select(Symbol("a")) - .limit(15).analyze - val optimized4 = Optimize.execute(query4) - val expected4 = testRelation - .sortBy($"a".asc) - .select(Symbol("a"), Symbol("b"), 'c') - .select(Symbol("a")) - .limit(10).analyze - comparePlans(optimized4, expected4) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index a51870cfd7fdd..60bde20fe235c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -87,7 +87,7 @@ class SparkOptimizer( ColumnPruning, LimitPushDown, PushPredicateThroughNonJoin, - PushProjectionThroughLimit, + PushProjectionThroughLimitAndOffset, RemoveNoopOperators), Batch("Infer window group limit", Once, InferWindowGroupLimit, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala index 0ab8691801d7f..79e157c4db6ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala @@ -198,8 +198,8 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSparkSession { test("Infers LocalLimit for Python evaluator") { val df = Seq(("Hello", 4), ("World", 8)).toDF("a", "b") - // Check that PushProjectionThroughLimit brings GlobalLimit - LocalLimit to the top (for - // CollectLimit) and that LimitPushDown keeps LocalLimit under UDF. + // Check that PushProjectionThroughLimitAndOffset brings GlobalLimit - LocalLimit to the top + // (for CollectLimit) and that LimitPushDown keeps LocalLimit under UDF. val df2 = df.limit(1).select(batchedPythonUDF(col("b"))) assert(df2.queryExecution.optimizedPlan match { case Limit(_, _) => true