diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala index a8a05c40f1cd..d7faa07a5a2e 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -66,6 +66,7 @@ import org.apache.commons.lang3.ClassUtils import java.lang.{Long => JLong} import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer class CHSparkPlanExecApi extends SparkPlanExecApi { @@ -727,9 +728,10 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { new JArrayList[ExpressionNode](), columnName, ConverterUtils.getTypeNode(aggWindowFunc.dataType, aggWindowFunc.nullable), - WindowExecTransformer.getFrameBound(frame.upper), - WindowExecTransformer.getFrameBound(frame.lower), - frame.frameType.sql + frame.upper, + frame.lower, + frame.frameType.sql, + originalInputAttributes.asJava ) windowExpressionNodes.add(windowFunctionNode) case aggExpression: AggregateExpression => @@ -753,9 +755,10 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { childrenNodeList, columnName, ConverterUtils.getTypeNode(aggExpression.dataType, aggExpression.nullable), - WindowExecTransformer.getFrameBound(frame.upper), - WindowExecTransformer.getFrameBound(frame.lower), - frame.frameType.sql + frame.upper, + frame.lower, + frame.frameType.sql, + originalInputAttributes.asJava ) windowExpressionNodes.add(windowFunctionNode) case wf @ (Lead(_, _, _, _) | Lag(_, _, _, _)) => @@ -802,9 +805,10 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { childrenNodeList, columnName, ConverterUtils.getTypeNode(offsetWf.dataType, offsetWf.nullable), - WindowExecTransformer.getFrameBound(frame.upper), - WindowExecTransformer.getFrameBound(frame.lower), - frame.frameType.sql + frame.upper, + frame.lower, + frame.frameType.sql, + originalInputAttributes.asJava ) windowExpressionNodes.add(windowFunctionNode) case _ => diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala index f06929fff620..21e6246d1271 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala @@ -296,15 +296,9 @@ object VeloxBackendSettings extends BackendSettingsApi { case _ => throw new GlutenNotSupportException(s"$func is not supported.") } - // Block the offloading by checking Velox's current limitations - // when literal bound type is used for RangeFrame. def checkLimitations(swf: SpecifiedWindowFrame, orderSpec: Seq[SortOrder]): Unit = { - def doCheck(bound: Expression, isUpperBound: Boolean): Unit = { + def doCheck(bound: Expression): Unit = { bound match { - case e if e.foldable => - throw new GlutenNotSupportException( - "Window frame of type RANGE does" + - " not support constant arguments in velox backend") case _: SpecialFrameBoundary => case e if e.foldable => orderSpec.foreach( @@ -325,17 +319,11 @@ object VeloxBackendSettings extends BackendSettingsApi { "Only integral type & date type are" + " supported for sort key when literal bound type is used!") }) - val rawValue = e.eval().toString.toLong - if (isUpperBound && rawValue < 0) { - throw new GlutenNotSupportException("Negative upper bound is not supported!") - } else if (!isUpperBound && rawValue > 0) { - throw new GlutenNotSupportException("Positive lower bound is not supported!") - } case _ => } } - doCheck(swf.upper, true) - doCheck(swf.lower, false) + doCheck(swf.upper) + doCheck(swf.lower) } windowExpression.windowSpec.frameSpecification match { @@ -495,4 +483,6 @@ object VeloxBackendSettings extends BackendSettingsApi { override def supportColumnarArrowUdf(): Boolean = true override def generateHdfsConfForLibhdfs(): Boolean = true + + override def needPreComputeRangeFrameBoundary(): Boolean = true } diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala index ae8d64a09937..3cf485aac06b 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala @@ -212,17 +212,56 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla Seq("sort", "streaming").foreach { windowType => withSQLConf("spark.gluten.sql.columnar.backend.velox.window.type" -> windowType) { + runQueryAndCompare( + "select max(l_partkey) over" + + " (partition by l_suppkey order by l_orderkey" + + " RANGE BETWEEN 1 PRECEDING AND CURRENT ROW), " + + "min(l_comment) over" + + " (partition by l_suppkey order by l_linenumber" + + " RANGE BETWEEN 1 PRECEDING AND CURRENT ROW) from lineitem ") { + checkSparkOperatorMatch[WindowExecTransformer] + } + runQueryAndCompare( "select max(l_partkey) over" + " (partition by l_suppkey order by l_orderkey" + " RANGE BETWEEN CURRENT ROW AND 2 FOLLOWING) from lineitem ") { - checkSparkOperatorMatch[WindowExec] + checkSparkOperatorMatch[WindowExecTransformer] } runQueryAndCompare( "select max(l_partkey) over" + " (partition by l_suppkey order by l_orderkey" + " RANGE BETWEEN 6 PRECEDING AND CURRENT ROW) from lineitem ") { + checkSparkOperatorMatch[WindowExecTransformer] + } + + runQueryAndCompare( + "select max(l_partkey) over" + + " (partition by l_suppkey order by l_orderkey" + + " RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from lineitem ") { + checkSparkOperatorMatch[WindowExecTransformer] + } + + runQueryAndCompare( + "select max(l_partkey) over" + + " (partition by l_suppkey order by l_orderkey" + + " RANGE BETWEEN 6 PRECEDING AND 3 PRECEDING) from lineitem ") { + checkSparkOperatorMatch[WindowExecTransformer] + } + + runQueryAndCompare( + "select max(l_partkey) over" + + " (partition by l_suppkey order by l_orderkey" + + " RANGE BETWEEN 3 FOLLOWING AND 6 FOLLOWING) from lineitem ") { + checkSparkOperatorMatch[WindowExecTransformer] + } + + // DecimalType as order by column is not supported + runQueryAndCompare( + "select min(l_comment) over" + + " (partition by l_suppkey order by l_discount" + + " RANGE BETWEEN 1 PRECEDING AND CURRENT ROW) from lineitem ") { checkSparkOperatorMatch[WindowExec] } diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc b/cpp/velox/substrait/SubstraitToVeloxPlan.cc index b82eead2c565..4e875d4790e5 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc +++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc @@ -823,10 +823,11 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait:: nextPlanNodeId(), replicated, unnest, std::move(unnestNames), ordinalityName, childNode); } -const core::WindowNode::Frame createWindowFrame( +const core::WindowNode::Frame SubstraitToVeloxPlanConverter::createWindowFrame( const ::substrait::Expression_WindowFunction_Bound& lower_bound, const ::substrait::Expression_WindowFunction_Bound& upper_bound, - const ::substrait::WindowType& type) { + const ::substrait::WindowType& type, + const RowTypePtr& inputType) { core::WindowNode::Frame frame; switch (type) { case ::substrait::WindowType::ROWS: @@ -839,9 +840,22 @@ const core::WindowNode::Frame createWindowFrame( VELOX_FAIL("the window type only support ROWS and RANGE, and the input type is ", std::to_string(type)); } - auto boundTypeConversion = [](::substrait::Expression_WindowFunction_Bound boundType) + auto specifiedBound = + [&](bool hasOffset, int64_t offset, const ::substrait::Expression& columnRef) -> core::TypedExprPtr { + if (hasOffset) { + VELOX_CHECK( + frame.type != core::WindowNode::WindowType::kRange, + "for RANGE frame offset, we should pre-calculate the range frame boundary and pass the column reference, but got a constant offset.") + return std::make_shared(BIGINT(), variant(offset)); + } else { + VELOX_CHECK( + frame.type != core::WindowNode::WindowType::kRows, "for ROW frame offset, we should pass a constant offset.") + return exprConverter_->toVeloxExpr(columnRef, inputType); + } + }; + + auto boundTypeConversion = [&](::substrait::Expression_WindowFunction_Bound boundType) -> std::tuple { - // TODO: support non-literal expression. if (boundType.has_current_row()) { return std::make_tuple(core::WindowNode::BoundType::kCurrentRow, nullptr); } else if (boundType.has_unbounded_following()) { @@ -849,13 +863,15 @@ const core::WindowNode::Frame createWindowFrame( } else if (boundType.has_unbounded_preceding()) { return std::make_tuple(core::WindowNode::BoundType::kUnboundedPreceding, nullptr); } else if (boundType.has_following()) { + auto following = boundType.following(); return std::make_tuple( core::WindowNode::BoundType::kFollowing, - std::make_shared(BIGINT(), variant(boundType.following().offset()))); + specifiedBound(following.has_offset(), following.offset(), following.ref())); } else if (boundType.has_preceding()) { + auto preceding = boundType.preceding(); return std::make_tuple( core::WindowNode::BoundType::kPreceding, - std::make_shared(BIGINT(), variant(boundType.preceding().offset()))); + specifiedBound(preceding.has_offset(), preceding.offset(), preceding.ref())); } else { VELOX_FAIL("The BoundType is not supported."); } @@ -906,7 +922,7 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait:: windowColumnNames.push_back(windowFunction.column_name()); windowNodeFunctions.push_back( - {std::move(windowCall), std::move(createWindowFrame(lowerBound, upperBound, type)), ignoreNulls}); + {std::move(windowCall), std::move(createWindowFrame(lowerBound, upperBound, type, inputType)), ignoreNulls}); } // Construct partitionKeys diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.h b/cpp/velox/substrait/SubstraitToVeloxPlan.h index 567ebb215078..3a0e677afeaa 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlan.h +++ b/cpp/velox/substrait/SubstraitToVeloxPlan.h @@ -555,6 +555,12 @@ class SubstraitToVeloxPlanConverter { return toVeloxPlan(rel.input()); } + const core::WindowNode::Frame createWindowFrame( + const ::substrait::Expression_WindowFunction_Bound& lower_bound, + const ::substrait::Expression_WindowFunction_Bound& upper_bound, + const ::substrait::WindowType& type, + const RowTypePtr& inputType); + /// The unique identification for each PlanNode. int planNodeId_ = 0; diff --git a/docs/developers/SubstraitModifications.md b/docs/developers/SubstraitModifications.md index 38406425af96..24a9c1a2128d 100644 --- a/docs/developers/SubstraitModifications.md +++ b/docs/developers/SubstraitModifications.md @@ -27,6 +27,7 @@ changed `Unbounded` in `WindowFunction` into `Unbounded_Preceding` and `Unbounde * Added `PartitionColumn` in `LocalFiles`([#2405](https://github.com/apache/incubator-gluten/pull/2405)). * Added `WriteRel` ([#3690](https://github.com/apache/incubator-gluten/pull/3690)). * Added `TopNRel` ([#5409](https://github.com/apache/incubator-gluten/pull/5409)). +* Added `ref` field in window bound `Preceding` and `Following` ([#5626](https://github.com/apache/incubator-gluten/pull/5626)). ## Modifications to type.proto diff --git a/gluten-core/src/main/java/org/apache/gluten/substrait/expression/ExpressionBuilder.java b/gluten-core/src/main/java/org/apache/gluten/substrait/expression/ExpressionBuilder.java index 5d106938cef5..e322e1528cac 100644 --- a/gluten-core/src/main/java/org/apache/gluten/substrait/expression/ExpressionBuilder.java +++ b/gluten-core/src/main/java/org/apache/gluten/substrait/expression/ExpressionBuilder.java @@ -21,6 +21,8 @@ import org.apache.gluten.substrait.type.*; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.Attribute; +import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; @@ -264,9 +266,10 @@ public static WindowFunctionNode makeWindowFunction( List expressionNodes, String columnName, TypeNode outputTypeNode, - String upperBound, - String lowerBound, - String frameType) { + Expression upperBound, + Expression lowerBound, + String frameType, + List originalInputAttributes) { return makeWindowFunction( functionId, expressionNodes, @@ -275,7 +278,8 @@ public static WindowFunctionNode makeWindowFunction( upperBound, lowerBound, frameType, - false); + false, + originalInputAttributes); } public static WindowFunctionNode makeWindowFunction( @@ -283,10 +287,11 @@ public static WindowFunctionNode makeWindowFunction( List expressionNodes, String columnName, TypeNode outputTypeNode, - String upperBound, - String lowerBound, + Expression upperBound, + Expression lowerBound, String frameType, - boolean ignoreNulls) { + boolean ignoreNulls, + List originalInputAttributes) { return new WindowFunctionNode( functionId, expressionNodes, @@ -295,6 +300,7 @@ public static WindowFunctionNode makeWindowFunction( upperBound, lowerBound, frameType, - ignoreNulls); + ignoreNulls, + originalInputAttributes); } } diff --git a/gluten-core/src/main/java/org/apache/gluten/substrait/expression/WindowFunctionNode.java b/gluten-core/src/main/java/org/apache/gluten/substrait/expression/WindowFunctionNode.java index 67d0d6e575ff..b9f1fbc126cc 100644 --- a/gluten-core/src/main/java/org/apache/gluten/substrait/expression/WindowFunctionNode.java +++ b/gluten-core/src/main/java/org/apache/gluten/substrait/expression/WindowFunctionNode.java @@ -16,17 +16,24 @@ */ package org.apache.gluten.substrait.expression; +import org.apache.gluten.exception.GlutenException; +import org.apache.gluten.expression.ExpressionConverter; import org.apache.gluten.substrait.type.TypeNode; import io.substrait.proto.Expression; import io.substrait.proto.FunctionArgument; import io.substrait.proto.FunctionOption; import io.substrait.proto.WindowType; +import org.apache.spark.sql.catalyst.expressions.Attribute; +import org.apache.spark.sql.catalyst.expressions.PreComputeRangeFrameBound; import java.io.Serializable; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import scala.collection.JavaConverters; + public class WindowFunctionNode implements Serializable { private final Integer functionId; private final List expressionNodes = new ArrayList<>(); @@ -34,23 +41,26 @@ public class WindowFunctionNode implements Serializable { private final String columnName; private final TypeNode outputTypeNode; - private final String upperBound; + private final org.apache.spark.sql.catalyst.expressions.Expression upperBound; - private final String lowerBound; + private final org.apache.spark.sql.catalyst.expressions.Expression lowerBound; private final String frameType; private final boolean ignoreNulls; + private final List originalInputAttributes; + WindowFunctionNode( Integer functionId, List expressionNodes, String columnName, TypeNode outputTypeNode, - String upperBound, - String lowerBound, + org.apache.spark.sql.catalyst.expressions.Expression upperBound, + org.apache.spark.sql.catalyst.expressions.Expression lowerBound, String frameType, - boolean ignoreNulls) { + boolean ignoreNulls, + List originalInputAttributes) { this.functionId = functionId; this.expressionNodes.addAll(expressionNodes); this.columnName = columnName; @@ -59,11 +69,13 @@ public class WindowFunctionNode implements Serializable { this.lowerBound = lowerBound; this.frameType = frameType; this.ignoreNulls = ignoreNulls; + this.originalInputAttributes = originalInputAttributes; } private Expression.WindowFunction.Bound.Builder setBound( - Expression.WindowFunction.Bound.Builder builder, String boundType) { - switch (boundType) { + Expression.WindowFunction.Bound.Builder builder, + org.apache.spark.sql.catalyst.expressions.Expression boundType) { + switch (boundType.sql()) { case ("CURRENT ROW"): Expression.WindowFunction.Bound.CurrentRow.Builder currentRowBuilder = Expression.WindowFunction.Bound.CurrentRow.newBuilder(); @@ -80,8 +92,36 @@ private Expression.WindowFunction.Bound.Builder setBound( builder.setUnboundedFollowing(followingBuilder.build()); break; default: - try { - Long offset = Long.valueOf(boundType); + if (boundType instanceof PreComputeRangeFrameBound) { + // Used only when backend is velox and frame type is RANGE. + if (!frameType.equals("RANGE")) { + throw new GlutenException( + "Only Range frame supports PreComputeRangeFrameBound, but got " + frameType); + } + ExpressionNode refNode = + ExpressionConverter.replaceWithExpressionTransformer( + ((PreComputeRangeFrameBound) boundType).child().toAttribute(), + JavaConverters.asScalaIteratorConverter(originalInputAttributes.iterator()) + .asScala() + .toSeq()) + .doTransform(new HashMap()); + Long offset = Long.valueOf(boundType.eval(null).toString()); + if (offset < 0) { + Expression.WindowFunction.Bound.Preceding.Builder refPrecedingBuilder = + Expression.WindowFunction.Bound.Preceding.newBuilder(); + refPrecedingBuilder.setRef(refNode.toProtobuf()); + builder.setPreceding(refPrecedingBuilder.build()); + } else { + Expression.WindowFunction.Bound.Following.Builder refFollowingBuilder = + Expression.WindowFunction.Bound.Following.newBuilder(); + refFollowingBuilder.setRef(refNode.toProtobuf()); + builder.setFollowing(refFollowingBuilder.build()); + } + } else if (boundType.foldable()) { + // Used when + // 1. Velox backend and frame type is ROW + // 2. Clickhouse backend + Long offset = Long.valueOf(boundType.eval(null).toString()); if (offset < 0) { Expression.WindowFunction.Bound.Preceding.Builder offsetPrecedingBuilder = Expression.WindowFunction.Bound.Preceding.newBuilder(); @@ -93,9 +133,9 @@ private Expression.WindowFunction.Bound.Builder setBound( offsetFollowingBuilder.setOffset(offset); builder.setFollowing(offsetFollowingBuilder.build()); } - } catch (NumberFormatException e) { + } else { throw new UnsupportedOperationException( - "Unsupported Window Function Frame Type:" + boundType); + "Unsupported Window Function Frame Bound Type: " + boundType); } } return builder; diff --git a/gluten-core/src/main/resources/substrait/proto/substrait/algebra.proto b/gluten-core/src/main/resources/substrait/proto/substrait/algebra.proto index 877493439f95..0e51baf5ad4c 100644 --- a/gluten-core/src/main/resources/substrait/proto/substrait/algebra.proto +++ b/gluten-core/src/main/resources/substrait/proto/substrait/algebra.proto @@ -996,18 +996,28 @@ message Expression { message Bound { // Defines that the bound extends this far back from the current record. message Preceding { - // A strictly positive integer specifying the number of records that - // the window extends back from the current record. Required. Use - // CurrentRow for offset zero and Following for negative offsets. - int64 offset = 1; + oneof kind { + // A strictly positive integer specifying the number of records that + // the window extends back from the current record. Use + // CurrentRow for offset zero and Following for negative offsets. + int64 offset = 1; + + // the reference to pre-project range frame boundary. + Expression ref = 2; + } } // Defines that the bound extends this far ahead of the current record. message Following { - // A strictly positive integer specifying the number of records that - // the window extends ahead of the current record. Required. Use - // CurrentRow for offset zero and Preceding for negative offsets. - int64 offset = 1; + oneof kind { + // A strictly positive integer specifying the number of records that + // the window extends ahead of the current record. Use + // CurrentRow for offset zero and Preceding for negative offsets. + int64 offset = 1; + + // the reference to pre-project range frame boundary. + Expression ref = 2; + } } // Defines that the bound extends to or from the current record. diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala index d18273af2faa..b7a3bc1b6ef2 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala @@ -152,4 +152,6 @@ trait BackendSettingsApi { def supportColumnarArrowUdf(): Boolean = false def generateHdfsConfForLibhdfs(): Boolean = false + + def needPreComputeRangeFrameBoundary(): Boolean = false } diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala index 8a1baae51092..8bc8e136bd5d 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala @@ -529,9 +529,10 @@ trait SparkPlanExecApi { new JArrayList[ExpressionNode](), columnName, ConverterUtils.getTypeNode(aggWindowFunc.dataType, aggWindowFunc.nullable), - WindowExecTransformer.getFrameBound(frame.upper), - WindowExecTransformer.getFrameBound(frame.lower), - frame.frameType.sql + frame.upper, + frame.lower, + frame.frameType.sql, + originalInputAttributes.asJava ) windowExpressionNodes.add(windowFunctionNode) case aggExpression: AggregateExpression => @@ -554,9 +555,10 @@ trait SparkPlanExecApi { childrenNodeList, columnName, ConverterUtils.getTypeNode(aggExpression.dataType, aggExpression.nullable), - WindowExecTransformer.getFrameBound(frame.upper), - WindowExecTransformer.getFrameBound(frame.lower), - frame.frameType.sql + frame.upper, + frame.lower, + frame.frameType.sql, + originalInputAttributes.asJava ) windowExpressionNodes.add(windowFunctionNode) case wf @ (_: Lead | _: Lag) => @@ -590,10 +592,11 @@ trait SparkPlanExecApi { childrenNodeList, columnName, ConverterUtils.getTypeNode(offsetWf.dataType, offsetWf.nullable), - WindowExecTransformer.getFrameBound(frame.upper), - WindowExecTransformer.getFrameBound(frame.lower), + frame.upper, + frame.lower, frame.frameType.sql, - offsetWf.ignoreNulls + offsetWf.ignoreNulls, + originalInputAttributes.asJava ) windowExpressionNodes.add(windowFunctionNode) case wf @ NthValue(input, offset: Literal, ignoreNulls: Boolean) => @@ -609,10 +612,11 @@ trait SparkPlanExecApi { childrenNodeList, columnName, ConverterUtils.getTypeNode(wf.dataType, wf.nullable), - frame.upper.sql, - frame.lower.sql, + frame.upper, + frame.lower, frame.frameType.sql, - ignoreNulls + ignoreNulls, + originalInputAttributes.asJava ) windowExpressionNodes.add(windowFunctionNode) case wf @ NTile(buckets: Expression) => @@ -625,9 +629,10 @@ trait SparkPlanExecApi { childrenNodeList, columnName, ConverterUtils.getTypeNode(wf.dataType, wf.nullable), - frame.upper.sql, - frame.lower.sql, - frame.frameType.sql + frame.upper, + frame.lower, + frame.frameType.sql, + originalInputAttributes.asJava ) windowExpressionNodes.add(windowFunctionNode) case _ => diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala index ef6a767b5604..6832221a404d 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala @@ -197,16 +197,3 @@ case class WindowExecTransformer( override protected def withNewChildInternal(newChild: SparkPlan): WindowExecTransformer = copy(child = newChild) } - -object WindowExecTransformer { - - /** Gets lower/upper bound represented in string. */ - def getFrameBound(bound: Expression): String = { - // The lower/upper can be either a foldable Expression or a SpecialFrameBoundary. - if (bound.foldable) { - bound.eval().toString - } else { - bound.sql - } - } -} diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala index 50dc55423605..73b8ab2607eb 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala @@ -75,6 +75,17 @@ object PullOutPreProject extends RewriteSingleNode with PullOutProjectHelper { case _ => false } case _ => false + }.isDefined) || + window.windowExpression.exists(_.find { + case we: WindowExpression => + we.windowSpec.frameSpecification match { + case swf: SpecifiedWindowFrame + if needPreComputeRangeFrame(swf) && supportPreComputeRangeFrame( + we.windowSpec.orderSpec) => + true + case _ => false + } + case _ => false }.isDefined) case plan if SparkShimLoader.getSparkShims.isWindowGroupLimitExec(plan) => val window = SparkShimLoader.getSparkShims @@ -174,7 +185,9 @@ object PullOutPreProject extends RewriteSingleNode with PullOutProjectHelper { // Handle windowExpressions. val newWindowExpressions = window.windowExpression.toIndexedSeq.map { - _.transform { case we: WindowExpression => rewriteWindowExpression(we, expressionMap) } + _.transform { + case we: WindowExpression => rewriteWindowExpression(we, newOrderSpec, expressionMap) + } } val newWindow = window.copy( diff --git a/gluten-core/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala b/gluten-core/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala index 505f13f263a2..12055f9e9721 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala @@ -16,11 +16,13 @@ */ package org.apache.gluten.utils -import org.apache.gluten.exception.GlutenNotSupportException +import org.apache.gluten.backendsapi.BackendsApiManager +import org.apache.gluten.exception.{GlutenException, GlutenNotSupportException} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction} import org.apache.spark.sql.execution.aggregate._ +import org.apache.spark.sql.types.{ByteType, DateType, IntegerType, LongType, ShortType} import java.util.concurrent.atomic.AtomicInteger @@ -143,8 +145,49 @@ trait PullOutProjectHelper { ae.copy(aggregateFunction = newAggFunc, filter = newFilter) } + private def needPreComputeRangeFrameBoundary(bound: Expression): Boolean = { + bound match { + case _: PreComputeRangeFrameBound => false + case _ if !bound.foldable => false + case _ => true + } + } + + private def preComputeRangeFrameBoundary( + bound: Expression, + orderSpec: SortOrder, + expressionMap: mutable.HashMap[Expression, NamedExpression]): Expression = { + bound match { + case _: PreComputeRangeFrameBound => bound + case _ if !bound.foldable => bound + case _ if bound.foldable => + val a = expressionMap + .getOrElseUpdate( + bound.canonicalized, + Alias(Add(orderSpec.child, bound), generatePreAliasName)()) + PreComputeRangeFrameBound(a.asInstanceOf[Alias], bound) + } + } + + protected def needPreComputeRangeFrame(swf: SpecifiedWindowFrame): Boolean = { + BackendsApiManager.getSettings.needPreComputeRangeFrameBoundary && + swf.frameType == RangeFrame && + (needPreComputeRangeFrameBoundary(swf.lower) || needPreComputeRangeFrameBoundary(swf.upper)) + } + + protected def supportPreComputeRangeFrame(sortOrders: Seq[SortOrder]): Boolean = { + sortOrders.forall { + _.dataType match { + case ByteType | ShortType | IntegerType | LongType | DateType => true + // Only integral type & date type are supported for sort key with Range Frame + case _ => false + } + } + } + protected def rewriteWindowExpression( we: WindowExpression, + orderSpecs: Seq[SortOrder], expressionMap: mutable.HashMap[Expression, NamedExpression]): WindowExpression = { val newWindowFunc = we.windowFunction match { case windowFunc: WindowFunction => @@ -156,6 +199,22 @@ trait PullOutProjectHelper { case ae: AggregateExpression => rewriteAggregateExpression(ae, expressionMap) case other => other } - we.copy(windowFunction = newWindowFunc) + + val newWindowSpec = we.windowSpec.frameSpecification match { + case swf: SpecifiedWindowFrame if needPreComputeRangeFrame(swf) => + // This is guaranteed by Spark, but we still check it here + if (orderSpecs.size != 1) { + throw new GlutenException( + s"A range window frame with value boundaries expects one and only one " + + s"order by expression: ${orderSpecs.mkString(",")}") + } + val orderSpec = orderSpecs.head + val lowerFrameCol = preComputeRangeFrameBoundary(swf.lower, orderSpec, expressionMap) + val upperFrameCol = preComputeRangeFrameBoundary(swf.upper, orderSpec, expressionMap) + val newFrame = swf.copy(lower = lowerFrameCol, upper = upperFrameCol) + we.windowSpec.copy(frameSpecification = newFrame) + case _ => we.windowSpec + } + we.copy(windowFunction = newWindowFunc, windowSpec = newWindowSpec) } } diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/catalyst/expressions/PreComputeRangeFrameBound.scala b/gluten-core/src/main/scala/org/apache/spark/sql/catalyst/expressions/PreComputeRangeFrameBound.scala new file mode 100644 index 000000000000..73c1cb3de609 --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/spark/sql/catalyst/expressions/PreComputeRangeFrameBound.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.types.{DataType, Metadata} + +/** + * Represents a pre-compute boundary for range frame when boundary is non-SpecialFrameBoundary, + * since Velox doesn't support constant offset for range frame. It acts like the original boundary + * which is foldable and generate the same result when eval is invoked so that if the WindowExec + * fallback to Vanilla Spark it can still work correctly. + * @param child + * The alias to pre-compute projection column + * @param originalBound + * The original boundary which is a foldable expression + */ +case class PreComputeRangeFrameBound(child: Alias, originalBound: Expression) + extends UnaryExpression + with NamedExpression { + + override def foldable: Boolean = true + + override def eval(input: InternalRow): Any = originalBound.eval(input) + + override def genCode(ctx: CodegenContext): ExprCode = originalBound.genCode(ctx) + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + originalBound.genCode(ctx) + + override def name: String = child.name + + override def exprId: ExprId = child.exprId + + override def qualifier: Seq[String] = child.qualifier + + override def newInstance(): NamedExpression = + PreComputeRangeFrameBound(child.newInstance().asInstanceOf[Alias], originalBound) + + override lazy val resolved: Boolean = originalBound.resolved + + override def dataType: DataType = child.dataType + + override def nullable: Boolean = child.nullable + + override def metadata: Metadata = child.metadata + + override def toAttribute: Attribute = child.toAttribute + + override def toString: String = child.toString + + override def hashCode(): Int = child.hashCode() + + override def equals(other: Any): Boolean = other match { + case a: PreComputeRangeFrameBound => + child.equals(a.child) + case _ => false + } + + override def sql: String = child.sql + + override protected def withNewChildInternal(newChild: Expression): PreComputeRangeFrameBound = + copy(child = newChild.asInstanceOf[Alias]) + +}