diff --git a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala
index eaad9afd6a27..fc374e7d1544 100644
--- a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala
+++ b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala
@@ -26,7 +26,7 @@ import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat
 import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat.{DwrfReadFormat, OrcReadFormat, ParquetReadFormat}
 
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
-import org.apache.spark.sql.catalyst.expressions.{Alias, CumeDist, DenseRank, Descending, Expression, Literal, NamedExpression, NthValue, PercentRank, Rand, RangeFrame, Rank, RowNumber, SortOrder, SpecialFrameBoundary, SpecifiedWindowFrame}
+import org.apache.spark.sql.catalyst.expressions.{Alias, CumeDist, DenseRank, Descending, Expression, Literal, NamedExpression, NthValue, NTile, PercentRank, Rand, RangeFrame, Rank, RowNumber, SortOrder, SpecialFrameBoundary, SpecifiedWindowFrame}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Count, Sum}
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
@@ -286,7 +286,7 @@ object BackendSettings extends BackendSettingsApi {
         }
         windowExpression.windowFunction match {
           case _: RowNumber | _: AggregateExpression | _: Rank | _: CumeDist | _: DenseRank |
-              _: PercentRank | _: NthValue =>
+              _: PercentRank | _: NthValue | _: NTile =>
           case _ =>
             allSupported = false
         }
diff --git a/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala b/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala
index c32248af1b8c..4f3dfd7b1cc5 100644
--- a/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala
+++ b/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala
@@ -212,6 +212,12 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla
     Seq("sort", "streaming").foreach {
       windowType =>
         withSQLConf("spark.gluten.sql.columnar.backend.velox.window.type" -> windowType) {
+          runQueryAndCompare(
+            "select ntile(4) over" +
+              " (partition by l_suppkey order by l_orderkey) from lineitem ") {
+            assertWindowOffloaded
+          }
+
           runQueryAndCompare(
             "select row_number() over" +
               " (partition by l_suppkey order by l_orderkey) from lineitem ") {
diff --git a/docs/velox-backend-support-progress.md b/docs/velox-backend-support-progress.md
index 927be47283d1..a5fd0ae87cc9 100644
--- a/docs/velox-backend-support-progress.md
+++ b/docs/velox-backend-support-progress.md
@@ -387,7 +387,7 @@ Gluten supports 199 functions. (Draw to right to see all data types)
 | lag | | | | | | | | | | | | | | | | | | | | | | | |
 | lead | | | | | | | | | | | | | | | | | | | | | | | |
 | nth_value | nth_value | nth_value | PS | | | | | | | | | | | | | | | | | | | | |
-| ntile | | | S | | | | | | | | | | | | | | | | | | | | |
+| ntile | ntile | ntile | S | | | | | | | | | | | | | | | | | | | | |
 | percent_rank | percent_rank | | S | | | | | | | | | | | | | | | | | | | | |
 | rank | rank | | S | | | | | | | | | | | | | | | | | | | | |
 | row_number | row_number | | S | | | | S | S | S | | | | | | | | | | | | | | |
diff --git a/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala b/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala
index 7df4a14e178e..f239a8766206 100644
--- a/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala
@@ -551,6 +551,21 @@ trait SparkPlanExecApi {
             frame.frameType.sql
           )
           windowExpressionNodes.add(windowFunctionNode)
+        case wf @ NTile(buckets: Expression) =>
+          val frame = wExpression.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
+          val childrenNodeList = new JArrayList[ExpressionNode]()
+          val literal = buckets.asInstanceOf[Literal]
+          childrenNodeList.add(LiteralTransformer(literal).doTransform(args))
+          val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
+            WindowFunctionsBuilder.create(args, wf).toInt,
+            childrenNodeList,
+            columnName,
+            ConverterUtils.getTypeNode(wf.dataType, wf.nullable),
+            frame.upper.sql,
+            frame.lower.sql,
+            frame.frameType.sql
+          )
+          windowExpressionNodes.add(windowFunctionNode)
         case _ =>
           throw new UnsupportedOperationException(
             "unsupported window function type: " +
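
Not part of the patch: a minimal spark-shell sketch of the query shape the new NTile case covers, assuming a SparkSession already running with the Gluten Velox backend and a registered `lineitem` table (neither is set up in this diff).

```scala
// Sketch only: `spark` is assumed to be a SparkSession with the Gluten Velox backend
// enabled and a `lineitem` table registered; nothing here configures that.
val df = spark.sql(
  """select l_orderkey,
    |       ntile(4) over (partition by l_suppkey order by l_orderkey) as bucket
    |from lineitem""".stripMargin)

// With this change, the window containing ntile should be offloaded to Velox rather
// than falling back to vanilla Spark; inspect the physical plan to confirm.
df.explain()
```

Note that the patch forwards the bucket count via `buckets.asInstanceOf[Literal]`, so only a constant `ntile(n)` argument is handled by the new case.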