diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownThroughWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownThroughWindow.scala
index 635434741b944..88f92262dcc20 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownThroughWindow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownThroughWindow.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentRow, DenseRank, IntegerLiteral, NamedExpression, NTile, Rank, RowFrame, RowNumber, SpecifiedWindowFrame, UnboundedPreceding, WindowExpression, WindowSpecDefinition}
+import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentRow, DenseRank, IntegerLiteral, NamedExpression, Rank, RowFrame, RowNumber, SpecifiedWindowFrame, UnboundedPreceding, WindowExpression, WindowSpecDefinition}
 import org.apache.spark.sql.catalyst.plans.logical.{Limit, LocalLimit, LogicalPlan, Project, Sort, Window}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern.{LIMIT, WINDOW}
@@ -33,8 +33,7 @@ object LimitPushDownThroughWindow extends Rule[LogicalPlan] {
   // The window frame of RankLike and RowNumberLike can only be UNBOUNDED PRECEDING to CURRENT ROW.
   private def supportsPushdownThroughWindow(
       windowExpressions: Seq[NamedExpression]): Boolean = windowExpressions.forall {
-    case Alias(WindowExpression(_: Rank | _: DenseRank | _: NTile | _: RowNumber,
-      WindowSpecDefinition(Nil, _,
+    case Alias(WindowExpression(_: Rank | _: DenseRank | _: RowNumber, WindowSpecDefinition(Nil, _,
       SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow))), _) => true
     case _ => false
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownThroughWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownThroughWindowSuite.scala
index b09d10b260174..99812d20bf55f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownThroughWindowSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownThroughWindowSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{CurrentRow, PercentRank, Rank, RowFrame, RowNumber, SpecifiedWindowFrame, UnboundedPreceding}
+import org.apache.spark.sql.catalyst.expressions.{CurrentRow, NTile, PercentRank, Rank, RowFrame, RowNumber, SpecifiedWindowFrame, UnboundedPreceding}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
@@ -198,4 +198,15 @@ class LimitPushdownThroughWindowSuite extends PlanTest {
       Optimize.execute(originalQuery.analyze),
       WithoutOptimize.execute(originalQuery.analyze))
   }
+
+  test("SPARK-40002: Should not push through ntile window function") {
+    val originalQuery = testRelation
+      .select(a, b, c,
+        windowExpr(new NTile(), windowSpec(Nil, c.desc :: Nil, windowFrame)).as("nt"))
+      .limit(2)
+
+    comparePlans(
+      Optimize.execute(originalQuery.analyze),
+      WithoutOptimize.execute(originalQuery.analyze))
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
index 78bb3180ff108..6e8a06db9d5d1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
@@ -1127,4 +1127,17 @@ class DataFrameWindowFunctionsSuite extends QueryTest
       )
     )
   }
+
+  test("SPARK-40002: ntile should apply before limit") {
+    val df = Seq.tabulate(101)(identity).toDF("id")
+    val w = Window.orderBy("id")
+    checkAnswer(
+      df.select($"id", ntile(10).over(w)).limit(3),
+      Seq(
+        Row(0, 1),
+        Row(1, 1),
+        Row(2, 1)
+      )
+    )
+  }
 }
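Why NTile is excluded while Rank, DenseRank, and RowNumber stay: those three assign each row a value determined only by the rows sorted at or before it, so keeping just the first `limit` rows below the Window cannot change any surviving value. ntile(k) instead buckets rows by the total row count of the partition, so a limit pushed below the Window shrinks that count and reassigns every bucket. Below is a minimal sketch of the bucket arithmetic, following the standard SQL NTILE definition rather than Spark's actual NTile expression (NtileSketch and ntileBucket are illustrative names, not Spark APIs):

object NtileSketch {
  // 1-based bucket of the i-th row (0-based) among n ordered rows split into
  // k tiles: the first (n % k) buckets each receive one extra row.
  def ntileBucket(i: Int, n: Int, k: Int): Int = {
    val base = n / k                  // minimum rows per bucket
    val extra = n % k                 // buckets that get base + 1 rows
    val cutoff = extra * (base + 1)   // rows covered by the larger buckets
    if (i < cutoff) i / (base + 1) + 1
    else (i - cutoff) / base + extra + 1
  }

  def main(args: Array[String]): Unit = {
    // Over all 101 rows, row index 2 lands in bucket 1 -- the Row(2, 1)
    // the new DataFrameWindowFunctionsSuite test expects.
    println(ntileBucket(2, 101, 10)) // 1
    // If the limit were pushed below the Window, ntile would see only 3 rows
    // and row index 2 would land in bucket 3 -- a wrong answer.
    println(ntileBucket(2, 3, 10)) // 3
  }
}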