Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GLUTEN-5625][VL] Support window range frame #5626

Merged
merged 9 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ import org.apache.commons.lang3.ClassUtils
import java.lang.{Long => JLong}
import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

class CHSparkPlanExecApi extends SparkPlanExecApi {
Expand Down Expand Up @@ -727,9 +728,10 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
new JArrayList[ExpressionNode](),
columnName,
ConverterUtils.getTypeNode(aggWindowFunc.dataType, aggWindowFunc.nullable),
WindowExecTransformer.getFrameBound(frame.upper),
WindowExecTransformer.getFrameBound(frame.lower),
frame.frameType.sql
frame.upper,
frame.lower,
frame.frameType.sql,
originalInputAttributes.asJava
)
windowExpressionNodes.add(windowFunctionNode)
case aggExpression: AggregateExpression =>
Expand All @@ -753,9 +755,10 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
childrenNodeList,
columnName,
ConverterUtils.getTypeNode(aggExpression.dataType, aggExpression.nullable),
WindowExecTransformer.getFrameBound(frame.upper),
WindowExecTransformer.getFrameBound(frame.lower),
frame.frameType.sql
frame.upper,
frame.lower,
frame.frameType.sql,
originalInputAttributes.asJava
)
windowExpressionNodes.add(windowFunctionNode)
case wf @ (Lead(_, _, _, _) | Lag(_, _, _, _)) =>
Expand Down Expand Up @@ -802,9 +805,10 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
childrenNodeList,
columnName,
ConverterUtils.getTypeNode(offsetWf.dataType, offsetWf.nullable),
WindowExecTransformer.getFrameBound(frame.upper),
WindowExecTransformer.getFrameBound(frame.lower),
frame.frameType.sql
frame.upper,
frame.lower,
frame.frameType.sql,
originalInputAttributes.asJava
)
windowExpressionNodes.add(windowFunctionNode)
case _ =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,15 +296,9 @@ object VeloxBackendSettings extends BackendSettingsApi {
case _ => throw new GlutenNotSupportException(s"$func is not supported.")
}

// Block the offloading by checking Velox's current limitations
// when literal bound type is used for RangeFrame.
def checkLimitations(swf: SpecifiedWindowFrame, orderSpec: Seq[SortOrder]): Unit = {
def doCheck(bound: Expression, isUpperBound: Boolean): Unit = {
def doCheck(bound: Expression): Unit = {
bound match {
case e if e.foldable =>
throw new GlutenNotSupportException(
"Window frame of type RANGE does" +
" not support constant arguments in velox backend")
case _: SpecialFrameBoundary =>
case e if e.foldable =>
orderSpec.foreach(
Expand All @@ -325,17 +319,11 @@ object VeloxBackendSettings extends BackendSettingsApi {
"Only integral type & date type are" +
" supported for sort key when literal bound type is used!")
})
val rawValue = e.eval().toString.toLong
if (isUpperBound && rawValue < 0) {
throw new GlutenNotSupportException("Negative upper bound is not supported!")
} else if (!isUpperBound && rawValue > 0) {
throw new GlutenNotSupportException("Positive lower bound is not supported!")
}
case _ =>
}
}
doCheck(swf.upper, true)
doCheck(swf.lower, false)
doCheck(swf.upper)
doCheck(swf.lower)
}

windowExpression.windowSpec.frameSpecification match {
Expand Down Expand Up @@ -495,4 +483,6 @@ object VeloxBackendSettings extends BackendSettingsApi {
override def supportColumnarArrowUdf(): Boolean = true

override def generateHdfsConfForLibhdfs(): Boolean = true

// The Velox backend opts in to pre-computing RANGE frame boundaries.
// NOTE(review): presumably consumed on the planner side so that a RANGE bound reaches
// the native plan converter as a column reference rather than a literal offset — confirm.
override def needPreComputeRangeFrameBoundary(): Boolean = true
}
Original file line number Diff line number Diff line change
Expand Up @@ -212,17 +212,56 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla
Seq("sort", "streaming").foreach {
windowType =>
withSQLConf("spark.gluten.sql.columnar.backend.velox.window.type" -> windowType) {
runQueryAndCompare(
"select max(l_partkey) over" +
" (partition by l_suppkey order by l_orderkey" +
" RANGE BETWEEN 1 PRECEDING AND CURRENT ROW), " +
"min(l_comment) over" +
" (partition by l_suppkey order by l_linenumber" +
" RANGE BETWEEN 1 PRECEDING AND CURRENT ROW) from lineitem ") {
checkSparkOperatorMatch[WindowExecTransformer]
}

runQueryAndCompare(
"select max(l_partkey) over" +
" (partition by l_suppkey order by l_orderkey" +
" RANGE BETWEEN CURRENT ROW AND 2 FOLLOWING) from lineitem ") {
checkSparkOperatorMatch[WindowExec]
checkSparkOperatorMatch[WindowExecTransformer]
}

runQueryAndCompare(
"select max(l_partkey) over" +
" (partition by l_suppkey order by l_orderkey" +
" RANGE BETWEEN 6 PRECEDING AND CURRENT ROW) from lineitem ") {
checkSparkOperatorMatch[WindowExecTransformer]
}

runQueryAndCompare(
"select max(l_partkey) over" +
" (partition by l_suppkey order by l_orderkey" +
" RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from lineitem ") {
checkSparkOperatorMatch[WindowExecTransformer]
}

runQueryAndCompare(
"select max(l_partkey) over" +
" (partition by l_suppkey order by l_orderkey" +
" RANGE BETWEEN 6 PRECEDING AND 3 PRECEDING) from lineitem ") {
checkSparkOperatorMatch[WindowExecTransformer]
}

runQueryAndCompare(
"select max(l_partkey) over" +
" (partition by l_suppkey order by l_orderkey" +
" RANGE BETWEEN 3 FOLLOWING AND 6 FOLLOWING) from lineitem ") {
checkSparkOperatorMatch[WindowExecTransformer]
}

// DecimalType as order by column is not supported
runQueryAndCompare(
"select min(l_comment) over" +
" (partition by l_suppkey order by l_discount" +
" RANGE BETWEEN 1 PRECEDING AND CURRENT ROW) from lineitem ") {
checkSparkOperatorMatch[WindowExec]
}

Expand Down
30 changes: 23 additions & 7 deletions cpp/velox/substrait/SubstraitToVeloxPlan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -823,10 +823,11 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
nextPlanNodeId(), replicated, unnest, std::move(unnestNames), ordinalityName, childNode);
}

const core::WindowNode::Frame createWindowFrame(
const core::WindowNode::Frame SubstraitToVeloxPlanConverter::createWindowFrame(
const ::substrait::Expression_WindowFunction_Bound& lower_bound,
const ::substrait::Expression_WindowFunction_Bound& upper_bound,
const ::substrait::WindowType& type) {
const ::substrait::WindowType& type,
const RowTypePtr& inputType) {
core::WindowNode::Frame frame;
switch (type) {
case ::substrait::WindowType::ROWS:
Expand All @@ -839,23 +840,38 @@ const core::WindowNode::Frame createWindowFrame(
VELOX_FAIL("the window type only support ROWS and RANGE, and the input type is ", std::to_string(type));
}

auto boundTypeConversion = [](::substrait::Expression_WindowFunction_Bound boundType)
auto specifiedBound =
[&](bool hasOffset, int64_t offset, const ::substrait::Expression& columnRef) -> core::TypedExprPtr {
if (hasOffset) {
VELOX_CHECK(
frame.type != core::WindowNode::WindowType::kRange,
"for RANGE frame offset, we should pre-calculate the range frame boundary and pass the column reference, but got a constant offset.")
return std::make_shared<core::ConstantTypedExpr>(BIGINT(), variant(offset));
} else {
VELOX_CHECK(
frame.type != core::WindowNode::WindowType::kRows, "for ROW frame offset, we should pass a constant offset.")
return exprConverter_->toVeloxExpr(columnRef, inputType);
}
};

auto boundTypeConversion = [&](::substrait::Expression_WindowFunction_Bound boundType)
-> std::tuple<core::WindowNode::BoundType, core::TypedExprPtr> {
// TODO: support non-literal expression.
if (boundType.has_current_row()) {
return std::make_tuple(core::WindowNode::BoundType::kCurrentRow, nullptr);
} else if (boundType.has_unbounded_following()) {
return std::make_tuple(core::WindowNode::BoundType::kUnboundedFollowing, nullptr);
} else if (boundType.has_unbounded_preceding()) {
return std::make_tuple(core::WindowNode::BoundType::kUnboundedPreceding, nullptr);
} else if (boundType.has_following()) {
auto following = boundType.following();
return std::make_tuple(
core::WindowNode::BoundType::kFollowing,
std::make_shared<core::ConstantTypedExpr>(BIGINT(), variant(boundType.following().offset())));
specifiedBound(following.has_offset(), following.offset(), following.ref()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can `has_offset` still be true for the Velox backend after the Java-side handling?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's reused for both ROW and RANGE frames, so the offset is used for the ROW frame here.

} else if (boundType.has_preceding()) {
auto preceding = boundType.preceding();
return std::make_tuple(
core::WindowNode::BoundType::kPreceding,
std::make_shared<core::ConstantTypedExpr>(BIGINT(), variant(boundType.preceding().offset())));
specifiedBound(preceding.has_offset(), preceding.offset(), preceding.ref()));
} else {
VELOX_FAIL("The BoundType is not supported.");
}
Expand Down Expand Up @@ -906,7 +922,7 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
windowColumnNames.push_back(windowFunction.column_name());

windowNodeFunctions.push_back(
{std::move(windowCall), std::move(createWindowFrame(lowerBound, upperBound, type)), ignoreNulls});
{std::move(windowCall), std::move(createWindowFrame(lowerBound, upperBound, type, inputType)), ignoreNulls});
}

// Construct partitionKeys
Expand Down
6 changes: 6 additions & 0 deletions cpp/velox/substrait/SubstraitToVeloxPlan.h
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,12 @@ class SubstraitToVeloxPlanConverter {
return toVeloxPlan(rel.input());
}

/// Builds the Velox window frame for a window function from its Substrait
/// lower/upper bounds and window type. Only ROWS and RANGE window types are
/// accepted. A PRECEDING/FOLLOWING bound is given either as a constant offset
/// (rejected for RANGE frames) or as a column reference (rejected for ROWS
/// frames); inputType is the row type used to resolve that column reference.
const core::WindowNode::Frame createWindowFrame(
const ::substrait::Expression_WindowFunction_Bound& lower_bound,
const ::substrait::Expression_WindowFunction_Bound& upper_bound,
const ::substrait::WindowType& type,
const RowTypePtr& inputType);

/// The unique identification for each PlanNode.
int planNodeId_ = 0;

Expand Down
1 change: 1 addition & 0 deletions docs/developers/SubstraitModifications.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ changed `Unbounded` in `WindowFunction` into `Unbounded_Preceding` and `Unbounde
* Added `PartitionColumn` in `LocalFiles`([#2405](https://github.com/apache/incubator-gluten/pull/2405)).
* Added `WriteRel` ([#3690](https://github.com/apache/incubator-gluten/pull/3690)).
* Added `TopNRel` ([#5409](https://github.com/apache/incubator-gluten/pull/5409)).
* Added `ref` field in window bound `Preceding` and `Following` ([#5626](https://github.com/apache/incubator-gluten/pull/5626)).

## Modifications to type.proto

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import org.apache.gluten.substrait.type.*;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.Attribute;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.catalyst.util.MapData;
import org.apache.spark.sql.types.*;
Expand Down Expand Up @@ -264,9 +266,10 @@ public static WindowFunctionNode makeWindowFunction(
List<ExpressionNode> expressionNodes,
String columnName,
TypeNode outputTypeNode,
String upperBound,
String lowerBound,
String frameType) {
Expression upperBound,
Expression lowerBound,
String frameType,
List<Attribute> originalInputAttributes) {
return makeWindowFunction(
functionId,
expressionNodes,
Expand All @@ -275,18 +278,20 @@ public static WindowFunctionNode makeWindowFunction(
upperBound,
lowerBound,
frameType,
false);
false,
originalInputAttributes);
}

public static WindowFunctionNode makeWindowFunction(
Integer functionId,
List<ExpressionNode> expressionNodes,
String columnName,
TypeNode outputTypeNode,
String upperBound,
String lowerBound,
Expression upperBound,
Expression lowerBound,
String frameType,
boolean ignoreNulls) {
boolean ignoreNulls,
List<Attribute> originalInputAttributes) {
return new WindowFunctionNode(
functionId,
expressionNodes,
Expand All @@ -295,6 +300,7 @@ public static WindowFunctionNode makeWindowFunction(
upperBound,
lowerBound,
frameType,
ignoreNulls);
ignoreNulls,
originalInputAttributes);
}
}
Loading
Loading