Skip to content

Commit

Permalink
support approx percentile
Browse files Browse the repository at this point in the history
  • Loading branch information
WangGuangxin committed Mar 17, 2024
1 parent 2ca27fb commit 97db869
Show file tree
Hide file tree
Showing 8 changed files with 157 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import io.glutenproject.expression.ConverterUtils.FunctionConfig
import io.glutenproject.extension.columnar.RewriteTypedImperativeAggregate
import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode}
import io.glutenproject.substrait.{AggregationParams, SubstraitContext}
import io.glutenproject.substrait.expression.{AggregateFunctionNode, ExpressionBuilder, ExpressionNode, ScalarFunctionNode}
import io.glutenproject.substrait.expression.{AggregateFunctionNode, DoubleLiteralNode, ExpressionBuilder, ExpressionNode, IntLiteralNode, ScalarFunctionNode}
import io.glutenproject.substrait.extensions.{AdvancedExtensionNode, ExtensionBuilder}
import io.glutenproject.substrait.rel.{RelBuilder, RelNode}
import io.glutenproject.utils.VeloxIntermediateData
Expand Down Expand Up @@ -71,7 +71,7 @@ abstract class HashAggregateExecTransformer(
aggFunc: AggregateFunction,
mode: AggregateMode): Boolean = {
aggFunc match {
case _: HLLAdapter =>
case _: HLLAdapter | _: ApproximatePercentile =>
mode match {
case Partial | Final => true
case _ => false
Expand Down Expand Up @@ -264,6 +264,55 @@ abstract class HashAggregateExecTransformer(
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
}
case p: ApproximatePercentile =>
aggregateMode match {
case Partial =>
// The data type of ApproximatePercentile's third child differs
// between Spark and Velox.
// In Spark, the `accuracy` parameter is a
// positive numeric literal, the `1.0/accuracy` is the relative error
// of the approximation.
// While in Velox, the `accuracy` parameter is the relative error itself.
if (childrenNodeList.size() != 3) {
throw new IllegalArgumentException(
s"Expected three children for " +
s"ApproximatePercentile, but got ${childrenNodeList.size()}")
}
val accuracyChild = childrenNodeList.get(2)
if (!accuracyChild.isInstanceOf[IntLiteralNode]) {
throw new IllegalArgumentException(
s"Expected a Integer Literal " +
s"for ApproximatePercentile's accuracy, but got $accuracyChild")
}
val accuracy = accuracyChild.asInstanceOf[IntLiteralNode].getValue
val newAccuracyNode = new DoubleLiteralNode(1.0 / accuracy)
val newChildrenNodeList = new JArrayList[ExpressionNode]()
newChildrenNodeList.add(childrenNodeList.get(0))
newChildrenNodeList.add(childrenNodeList.get(1))
newChildrenNodeList.add(newAccuracyNode)

// For Partial mode output type is struct
val outputType =
RewriteTypedImperativeAggregate.getPercentileLikeInterminateDataType(p)
val partialNode = ExpressionBuilder.makeAggregateFunction(
VeloxAggregateFunctionsBuilder.create(args, aggregateFunction, aggregateMode),
newChildrenNodeList,
modeKeyWord,
ConverterUtils.getTypeNode(outputType, p.nullable)
)
aggregateNodeList.add(partialNode)
case Final =>
// For final mode output type is as the original type.
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
VeloxAggregateFunctionsBuilder.create(args, aggregateFunction, aggregateMode),
childrenNodeList,
modeKeyWord,
ConverterUtils.getTypeNode(aggregateFunction.dataType, aggregateFunction.nullable)
)
aggregateNodeList.add(aggFunctionNode)
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
}
case _ if aggregateFunction.aggBufferAttributes.size > 1 =>
generateMergeCompanionNode()
case _ =>
Expand Down Expand Up @@ -826,5 +875,13 @@ case class HashAggregateExecPullOutHelper(
aggBufferAttr.copy(dataType = ae.aggregateFunction.dataType)(
aggBufferAttr.exprId,
aggBufferAttr.qualifier))
case ae: AggregateExpression
if RewriteTypedImperativeAggregate.shouldRewriteForPercentileLikeExpr(ae) =>
val aggBufferAttr = ae.aggregateFunction.inputAggBufferAttributes.head
val newAggBufferDataType = RewriteTypedImperativeAggregate
.getPercentileLikeInterminateDataType(ae.aggregateFunction)
Seq(
aggBufferAttr
.copy(dataType = newAggBufferDataType)(aggBufferAttr.exprId, aggBufferAttr.qualifier))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package io.glutenproject.utils

import io.glutenproject.expression.ConverterUtils
import io.glutenproject.extension.columnar.RewriteTypedImperativeAggregate
import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode}

import org.apache.spark.sql.catalyst.expressions.aggregate._
Expand Down Expand Up @@ -81,14 +82,29 @@ object VeloxIntermediateData {
*/
def getInputTypes(aggregateFunc: AggregateFunction, forMergeCompanion: Boolean): Seq[DataType] = {
if (!forMergeCompanion) {
return aggregateFunc.children.map(_.dataType)
aggregateFunc match {
case p: ApproximatePercentile =>
p.children.map(_.dataType) match {
case Seq(childType, percentageType, accuracyType) =>
// The data type of ApproximatePercentile's third child
// differs between Spark and Velox.
return Seq(childType, percentageType, DoubleType)
case s =>
throw new IllegalArgumentException(s"Expected three children for " +
s"ApproximatePercentile, but got $s")
}
case _ =>
return aggregateFunc.children.map(_.dataType)
}
}
aggregateFunc match {
case _ @Type(veloxDataTypes: Seq[DataType]) =>
Seq(StructType(veloxDataTypes.map(StructField("", _)).toArray))
case _: CollectList | _: CollectSet =>
// CollectList and CollectSet should use data type of agg function.
Seq(aggregateFunc.dataType)
case p: ApproximatePercentile =>
Seq(RewriteTypedImperativeAggregate.getPercentileLikeInterminateDataType(p))
case _ =>
// Not use StructType for single column agg intermediate data
aggregateFunc.aggBufferAttributes.map(_.dataType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1233,4 +1233,12 @@ class TestOperator extends VeloxWholeStageTransformerSuite {
checkOperatorMatch[HashAggregateExecTransformer]
}
}

// Runs an approx_percentile query with an array of percentages and an explicit
// accuracy argument, then passes the HashAggregateExecTransformer operator check
// to runQueryAndCompare as its custom check.
test("Support ApproximatePercentile") {
  // The margin-stripped SQL text is kept exactly as it was; only the binding is new.
  val approxPercentileSql = """
|SELECT approx_percentile(col, array(0.5, 0.4, 0.1), 100)
|FROM VALUES (0), (1), (2), (10) AS tab(col)
|""".stripMargin
  runQueryAndCompare(approxPercentileSql) {
    checkOperatorMatch[HashAggregateExecTransformer]
  }
}
15 changes: 13 additions & 2 deletions cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ static const std::unordered_set<std::string> kBlackList = {
"trunc",
"sequence",
"arrays_overlap",
"approx_percentile",
"get_array_struct_fields"};

} // namespace
Expand Down Expand Up @@ -992,6 +991,7 @@ bool SubstraitToVeloxPlanValidator::validateAggRelFunctionType(const ::substrait
std::vector<TypePtr> types;
bool isDecimal = false;
try {
std::cout << "####debug### function spec: " << funcSpec << std::endl;
types = SubstraitParser::sigToTypes(funcSpec);
for (const auto& type : types) {
if (!isDecimal && type->isDecimal()) {
Expand All @@ -1013,6 +1013,16 @@ bool SubstraitToVeloxPlanValidator::validateAggRelFunctionType(const ::substrait

bool resolved = false;
for (const auto& signature : signaturesOpt.value()) {
const auto& formalArgs = signature->argumentTypes();
auto formalArgsCnt = formalArgs.size();
std::cout << "####debug### function signature: " << signature->toString() << ". base function name: " << baseFuncName << ", function name: " << funcName << std::endl;
std::cout << "###debug### signature argument args: " << formalArgsCnt << std::endl;
for (auto i = 0; i < formalArgsCnt; ++i) {
std::cout << "###debug### << signature args " << formalArgs[i].toString() << std::endl;
}
for (auto i = 0; i < types.size(); i++) {
std::cout << "###debug### << acture types args " << types[i]->toString() << std::endl;
}
exec::SignatureBinder binder(*signature, types);
if (binder.tryBind()) {
auto resolveType = binder.tryResolveType(
Expand Down Expand Up @@ -1141,7 +1151,8 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::AggregateRel& ag
"covar_pop",
"covar_samp",
"approx_distinct",
"skewness"};
"skewness",
"approx_percentile"};

for (const auto& funcSpec : funcSpecs) {
auto funcName = SubstraitParser::getNameBeforeDelimiter(funcSpec);
Expand Down
4 changes: 2 additions & 2 deletions ep/build-velox/src/get_velox.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

set -exu

VELOX_REPO=https://github.com/oap-project/velox.git
VELOX_BRANCH=2024_03_15
VELOX_REPO=https://github.com/wangguangxin/velox.git
VELOX_BRANCH=2024_03_15_approx_percentile
VELOX_HOME=""

#Set on run gluten on HDFS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,8 @@ object ExpressionMappings {
Sig[CovSample](COVAR_SAMP),
Sig[Last](LAST),
Sig[First](FIRST),
Sig[Skewness](SKEWNESS)
Sig[Skewness](SKEWNESS),
Sig[ApproximatePercentile](APPROX_PERCENTILE)
)

/** Mapping Spark window expression to Substrait function name */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.types._

object RewriteTypedImperativeAggregate extends Rule[SparkPlan] with PullOutProjectHelper {
private lazy val shouldRewriteTypedImperativeAggregate =
Expand All @@ -40,6 +41,38 @@ object RewriteTypedImperativeAggregate extends Rule[SparkPlan] with PullOutProje
}
}

/**
 * Tells whether the given aggregate expression is a percentile-like aggregate whose buffer
 * attribute type should be rewritten.
 *
 * @param ae
 *   the aggregate expression to inspect
 * @return
 *   `true` only when the function is an [[ApproximatePercentile]] and the mode is `Partial` or
 *   `PartialMerge`; `false` for every other function/mode combination
 */
def shouldRewriteForPercentileLikeExpr(ae: AggregateExpression): Boolean =
  (ae.aggregateFunction, ae.mode) match {
    case (_: ApproximatePercentile, Partial | PartialMerge) => true
    case _ => false
  }

/**
 * Builds the nine-field struct type used as the intermediate data type for percentile-like
 * aggregates.
 *
 * Field names are positional placeholders (`col1` .. `col9`). Three of the fields are typed
 * after the aggregate's input child; the rest use fixed types.
 * NOTE(review): the layout presumably mirrors the native engine's approx_percentile
 * accumulator -- confirm against the Velox implementation before changing field order or types.
 *
 * @param aggFunc
 *   the aggregate function to derive the intermediate type for
 * @return
 *   the intermediate struct type
 * @throws IllegalArgumentException
 *   if `aggFunc` is not an [[ApproximatePercentile]]
 */
def getPercentileLikeInterminateDataType(aggFunc: AggregateFunction): StructType = aggFunc match {
  case percentile: ApproximatePercentile =>
    // Type of the column being aggregated; reused by col6, col7 and col8.
    val inputType = percentile.child.dataType
    val bufferFields = Array(
      StructField("col1", ArrayType(DoubleType)),
      StructField("col2", BooleanType, nullable = false),
      StructField("col3", DoubleType, nullable = false),
      StructField("col4", IntegerType, nullable = false),
      StructField("col5", LongType, nullable = false),
      StructField("col6", inputType, nullable = false),
      StructField("col7", inputType, nullable = false),
      StructField("col8", ArrayType(inputType)),
      StructField("col9", ArrayType(IntegerType))
    )
    StructType(bufferFields)
  case unsupported =>
    throw new IllegalArgumentException(s"Unsupported aggregate function $unsupported")
}

override def apply(plan: SparkPlan): SparkPlan = {
if (!shouldRewriteTypedImperativeAggregate) {
return plan
Expand Down Expand Up @@ -67,6 +100,28 @@ object RewriteTypedImperativeAggregate extends Rule[SparkPlan] with PullOutProje
}
copyBaseAggregateExec(agg)(newResultExpressions = newResultExpressions)

case agg: BaseAggregateExec
if agg.aggregateExpressions.exists(shouldRewriteForPercentileLikeExpr) =>
val exprMap = agg.aggregateExpressions
.filter(shouldRewriteForPercentileLikeExpr)
.map(ae => ae.aggregateFunction.inputAggBufferAttributes.head -> ae)
.toMap
val newResultExpressions = agg.resultExpressions.map {
case attr: AttributeReference =>
exprMap
.get(attr)
.map {
ae =>
attr.copy(dataType = getPercentileLikeInterminateDataType(ae.aggregateFunction))(
exprId = attr.exprId,
qualifier = attr.qualifier
)
}
.getOrElse(attr)
case other => other
}
copyBaseAggregateExec(agg)(newResultExpressions = newResultExpressions)

case _ => plan
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ object ExpressionNames {
final val FIRST_IGNORE_NULL = "first_ignore_null"
final val APPROX_DISTINCT = "approx_distinct"
final val SKEWNESS = "skewness"
final val APPROX_PERCENTILE = "approx_percentile"

// Function names used by Substrait plan.
final val ADD = "add"
Expand Down

0 comments on commit 97db869

Please sign in to comment.