[CORE] Code refactoring for HashAggregateExecBaseTransformer (#4719)
liujiayi771 committed Feb 20, 2024
1 parent 740746e commit c3614f8
Showing 3 changed files with 132 additions and 158 deletions.
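Before the file-by-file diff, here is a minimal sketch of the refactoring pattern this commit applies: the Substrait-building helpers (doTransform, addFunctionNode, getAggRelInternal, getAdvancedExtension, formatExtOptimizationString) move out of the shared HashAggregateExecBaseTransformer and become private members of the backend transformers (HashAggregateExecTransformer, CHHashAggregateExecTransformer). The sketch uses hypothetical stand-in names (AggBaseBefore, AggBaseAfter, BackendAggAfter, optimizationHint, RefactoringSketch) and keeps only formatExtOptimizationString from the real code; it illustrates the pattern, it is not code from the commit.

// A minimal sketch of the pattern, not code from the repository; the class and
// member names below are hypothetical and the real Substrait/Spark types are omitted.

// Before: the shared base class owned backend-specific helpers and exposed them
// as protected, overridable methods.
abstract class AggBaseBefore {
  protected def formatExtOptimizationString(isStreaming: Boolean): String =
    s"isStreaming=${if (isStreaming) "1" else "0"}\n"
}

// After: the base class keeps only the backend-agnostic surface...
abstract class AggBaseAfter {
  def isStreaming: Boolean = false
}

// ...and each backend transformer owns its helpers as private methods, mirroring
// the new HashAggregateExecTransformer code added in the diff below.
class BackendAggAfter extends AggBaseAfter {
  protected def allowFlush: Boolean = true

  private def formatExtOptimizationString(isStreaming: Boolean): String = {
    val isStreamingStr = if (isStreaming) "1" else "0"
    val allowFlushStr = if (allowFlush) "1" else "0"
    s"isStreaming=$isStreamingStr\nallowFlush=$allowFlushStr\n"
  }

  // Public entry point that exercises the now-private helper.
  def optimizationHint: String = formatExtOptimizationString(isStreaming)
}

object RefactoringSketch {
  def main(args: Array[String]): Unit =
    print(new BackendAggAfter().optimizationHint) // prints isStreaming=0 and allowFlush=1
}

The same shift explains the import churn in the diff: once the helpers leave the base class, HashAggregateExecBaseTransformer no longer needs the AdvancedExtensionNode, RelBuilder, StringValue, or Java-collection imports, which move to HashAggregateExecTransformer instead.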
@@ -168,7 +168,7 @@ case class CHHashAggregateExecTransformer(
aggRel
}

override def getAggRelInternal(
private def getAggRelInternal(
context: SubstraitContext,
originalInputAttributes: Seq[Attribute],
operatorId: Long,
@@ -373,7 +373,7 @@ case class CHHashAggregateExecTransformer(
copy(child = newChild)
}

override protected def getAdvancedExtension(
private def getAdvancedExtension(
validation: Boolean = false,
originalInputAttributes: Seq[Attribute] = Seq.empty): AdvancedExtensionNode = {
val enhancement = if (validation) {
@@ -23,7 +23,7 @@ import io.glutenproject.expression.ConverterUtils.FunctionConfig
import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode}
import io.glutenproject.substrait.{AggregationParams, SubstraitContext}
import io.glutenproject.substrait.expression.{AggregateFunctionNode, ExpressionBuilder, ExpressionNode, ScalarFunctionNode}
import io.glutenproject.substrait.extensions.ExtensionBuilder
import io.glutenproject.substrait.extensions.{AdvancedExtensionNode, ExtensionBuilder}
import io.glutenproject.substrait.rel.{RelBuilder, RelNode}
import io.glutenproject.utils.VeloxIntermediateData

@@ -33,6 +33,8 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

import com.google.protobuf.StringValue

import java.lang.{Long => JLong}
import java.util.{ArrayList => JArrayList, HashMap => JHashMap, List => JList}

@@ -56,6 +58,15 @@ abstract class HashAggregateExecTransformer(
resultExpressions,
child) {

override def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].doTransform(context)

val aggParams = new AggregationParams
val operatorId = context.nextOperatorId(this.nodeName)
val relNode = getAggRel(context, operatorId, aggParams, childCtx.root)
TransformContext(childCtx.outputAttributes, output, relNode)
}

override protected def checkAggFuncModeSupport(
aggFunc: AggregateFunction,
mode: AggregateMode): Boolean = {
@@ -176,14 +187,14 @@
// to be read.
protected def allowFlush: Boolean

override protected def formatExtOptimizationString(isStreaming: Boolean): String = {
private def formatExtOptimizationString(isStreaming: Boolean): String = {
val isStreamingStr = if (isStreaming) "1" else "0"
val allowFlushStr = if (allowFlush) "1" else "0"
s"isStreaming=$isStreamingStr\nallowFlush=$allowFlushStr\n"
}

// Create aggregate function node and add to list.
override protected def addFunctionNode(
private def addFunctionNode(
args: java.lang.Object,
aggregateFunction: AggregateFunction,
childrenNodeList: JList[ExpressionNode],
@@ -517,6 +528,121 @@
aggRel
}

private def getAggRelInternal(
context: SubstraitContext,
originalInputAttributes: Seq[Attribute],
operatorId: Long,
input: RelNode = null,
validation: Boolean): RelNode = {
val args = context.registeredFunction
// Get the grouping nodes.
// Use 'child.output' as the base Seq[Attribute]; the originalInputAttributes
// may differ for each backend.
val groupingList = groupingExpressions
.map(
ExpressionConverter
.replaceWithExpressionTransformer(_, child.output)
.doTransform(args))
.asJava
// Get the aggregate function nodes.
val aggFilterList = new JArrayList[ExpressionNode]()
val aggregateFunctionList = new JArrayList[AggregateFunctionNode]()
aggregateExpressions.foreach(
aggExpr => {
if (aggExpr.filter.isDefined) {
val exprNode = ExpressionConverter
.replaceWithExpressionTransformer(aggExpr.filter.get, child.output)
.doTransform(args)
aggFilterList.add(exprNode)
} else {
// The number of filters should be aligned with that of aggregate functions.
aggFilterList.add(null)
}
val aggregateFunc = aggExpr.aggregateFunction
val childrenNodes = aggExpr.mode match {
case Partial =>
aggregateFunc.children.toList.map(
expr => {
ExpressionConverter
.replaceWithExpressionTransformer(expr, originalInputAttributes)
.doTransform(args)
})
case PartialMerge | Final =>
aggregateFunc.inputAggBufferAttributes.toList.map {
attr =>
val sameAttr = originalInputAttributes.find(_.exprId == attr.exprId)
val rewriteAttr = if (sameAttr.isEmpty) {
// When aggregateExpressions includes subquery, Spark's PlanAdaptiveSubqueries
// Rule will transform the subquery within the final agg. The aggregateFunction
// in the aggregateExpressions of the final aggregation will be cloned, resulting
// in creating new aggregateFunction object. The inputAggBufferAttributes will
// also generate new AttributeReference instances with larger exprId, which leads
// to a failure in binding with the output of the partial agg. We need to adapt
// to this situation; when encountering a failure to bind, it is necessary to
// allow the binding of inputAggBufferAttribute with the same name but different
// exprId.
val attrsWithSameName =
originalInputAttributes.drop(groupingExpressions.size).collect {
case a if a.name == attr.name => a
}
val aggBufferAttrsWithSameName = aggregateExpressions.toIndexedSeq
.flatMap(_.aggregateFunction.inputAggBufferAttributes)
.filter(_.name == attr.name)
assert(
attrsWithSameName.size == aggBufferAttrsWithSameName.size,
"The attribute with the same name in final agg inputAggBufferAttribute must" +
"have the same size of corresponding attributes in originalInputAttributes."
)
attrsWithSameName(aggBufferAttrsWithSameName.indexOf(attr))
} else attr
ExpressionConverter
.replaceWithExpressionTransformer(rewriteAttr, originalInputAttributes)
.doTransform(args)
}
case other =>
throw new UnsupportedOperationException(s"$other not supported.")
}
addFunctionNode(
args,
aggregateFunc,
childrenNodes.asJava,
aggExpr.mode,
aggregateFunctionList)
})

val extensionNode = getAdvancedExtension(validation, originalInputAttributes)
RelBuilder.makeAggregateRel(
input,
groupingList,
aggregateFunctionList,
aggFilterList,
extensionNode,
context,
operatorId)
}

private def getAdvancedExtension(
validation: Boolean = false,
originalInputAttributes: Seq[Attribute] = Seq.empty): AdvancedExtensionNode = {
val enhancement = if (validation) {
// Use an extension node to send the input types through the Substrait plan for validation.
val inputTypeNodeList = originalInputAttributes
.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
.asJava
BackendsApiManager.getTransformerApiInstance.packPBMessage(
TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)
} else {
null
}

val optimization =
BackendsApiManager.getTransformerApiInstance.packPBMessage(
StringValue.newBuilder
.setValue(formatExtOptimizationString(isCapableForStreamingAggregation))
.build)
ExtensionBuilder.makeAdvancedExtension(optimization, enhancement)
}

def isStreaming: Boolean = false

def numShufflePartitions: Option[Int] = Some(0)
@@ -21,11 +21,8 @@ import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.expression._
import io.glutenproject.extension.ValidationResult
import io.glutenproject.metrics.MetricsUpdater
import io.glutenproject.substrait.`type`.TypeBuilder
import io.glutenproject.substrait.{AggregationParams, SubstraitContext}
import io.glutenproject.substrait.expression.{AggregateFunctionNode, ExpressionBuilder, ExpressionNode}
import io.glutenproject.substrait.extensions.{AdvancedExtensionNode, ExtensionBuilder}
import io.glutenproject.substrait.rel.{RelBuilder, RelNode}
import io.glutenproject.substrait.rel.RelNode

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -34,12 +31,6 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate._
import org.apache.spark.sql.types._

import com.google.protobuf.StringValue

import java.util.{ArrayList => JArrayList, List => JList}

import scala.collection.JavaConverters._

/** Columnar Based HashAggregateExec. */
abstract class HashAggregateExecBaseTransformer(
requiredChildDistributionExpressions: Option[Seq[Expression]],
@@ -145,33 +136,9 @@ abstract class HashAggregateExecBaseTransformer(
doNativeValidation(substraitContext, relNode)
}

override def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].doTransform(context)

val aggParams = new AggregationParams
val operatorId = context.nextOperatorId(this.nodeName)
val relNode = getAggRel(context, operatorId, aggParams, childCtx.root)
TransformContext(childCtx.outputAttributes, output, relNode)
}

// Members declared in org.apache.spark.sql.execution.AliasAwareOutputPartitioning
override protected def outputExpressions: Seq[NamedExpression] = resultExpressions

protected def addFunctionNode(
args: java.lang.Object,
aggregateFunction: AggregateFunction,
childrenNodeList: JList[ExpressionNode],
aggregateMode: AggregateMode,
aggregateNodeList: JList[AggregateFunctionNode]): Unit = {
aggregateNodeList.add(
ExpressionBuilder.makeAggregateFunction(
AggregateFunctionsBuilder.create(args, aggregateFunction),
childrenNodeList,
modeToKeyWord(aggregateMode),
ConverterUtils.getTypeNode(aggregateFunction.dataType, aggregateFunction.nullable)
))
}

protected def checkAggFuncModeSupport(
aggFunc: AggregateFunction,
mode: AggregateMode): Boolean = {
@@ -205,125 +172,6 @@
}
}

protected def getAggRelInternal(
context: SubstraitContext,
originalInputAttributes: Seq[Attribute],
operatorId: Long,
input: RelNode = null,
validation: Boolean): RelNode = {
val args = context.registeredFunction
// Get the grouping nodes.
// Use 'child.output' as the base Seq[Attribute]; the originalInputAttributes
// may differ for each backend.
val groupingList = groupingExpressions
.map(
ExpressionConverter
.replaceWithExpressionTransformer(_, child.output)
.doTransform(args))
.asJava
// Get the aggregate function nodes.
val aggFilterList = new JArrayList[ExpressionNode]()
val aggregateFunctionList = new JArrayList[AggregateFunctionNode]()
aggregateExpressions.foreach(
aggExpr => {
if (aggExpr.filter.isDefined) {
val exprNode = ExpressionConverter
.replaceWithExpressionTransformer(aggExpr.filter.get, child.output)
.doTransform(args)
aggFilterList.add(exprNode)
} else {
// The number of filters should be aligned with that of aggregate functions.
aggFilterList.add(null)
}
val aggregateFunc = aggExpr.aggregateFunction
val childrenNodes = aggExpr.mode match {
case Partial =>
aggregateFunc.children.toList.map(
expr => {
ExpressionConverter
.replaceWithExpressionTransformer(expr, originalInputAttributes)
.doTransform(args)
})
case PartialMerge | Final =>
aggregateFunc.inputAggBufferAttributes.toList.map {
attr =>
val sameAttr = originalInputAttributes.find(_.exprId == attr.exprId)
val rewriteAttr = if (sameAttr.isEmpty) {
// When aggregateExpressions includes subquery, Spark's PlanAdaptiveSubqueries
// Rule will transform the subquery within the final agg. The aggregateFunction
// in the aggregateExpressions of the final aggregation will be cloned, resulting
// in creating new aggregateFunction object. The inputAggBufferAttributes will
// also generate new AttributeReference instances with larger exprId, which leads
// to a failure in binding with the output of the partial agg. We need to adapt
// to this situation; when encountering a failure to bind, it is necessary to
// allow the binding of inputAggBufferAttribute with the same name but different
// exprId.
val attrsWithSameName =
originalInputAttributes.drop(groupingExpressions.size).collect {
case a if a.name == attr.name => a
}
val aggBufferAttrsWithSameName = aggregateExpressions.toIndexedSeq
.flatMap(_.aggregateFunction.inputAggBufferAttributes)
.filter(_.name == attr.name)
assert(
attrsWithSameName.size == aggBufferAttrsWithSameName.size,
"The attribute with the same name in final agg inputAggBufferAttribute must" +
"have the same size of corresponding attributes in originalInputAttributes."
)
attrsWithSameName(aggBufferAttrsWithSameName.indexOf(attr))
} else attr
ExpressionConverter
.replaceWithExpressionTransformer(rewriteAttr, originalInputAttributes)
.doTransform(args)
}
case other =>
throw new UnsupportedOperationException(s"$other not supported.")
}
addFunctionNode(
args,
aggregateFunc,
childrenNodes.asJava,
aggExpr.mode,
aggregateFunctionList)
})

val extensionNode = getAdvancedExtension(validation, originalInputAttributes)
RelBuilder.makeAggregateRel(
input,
groupingList,
aggregateFunctionList,
aggFilterList,
extensionNode,
context,
operatorId)
}

protected def getAdvancedExtension(
validation: Boolean = false,
originalInputAttributes: Seq[Attribute] = Seq.empty): AdvancedExtensionNode = {
val enhancement = if (validation) {
// Use an extension node to send the input types through the Substrait plan for validation.
val inputTypeNodeList = originalInputAttributes
.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
.asJava
BackendsApiManager.getTransformerApiInstance.packPBMessage(
TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)
} else {
null
}

val optimization =
BackendsApiManager.getTransformerApiInstance.packPBMessage(
StringValue.newBuilder
.setValue(formatExtOptimizationString(isCapableForStreamingAggregation))
.build)
ExtensionBuilder.makeAdvancedExtension(optimization, enhancement)
}

protected def formatExtOptimizationString(isStreaming: Boolean): String = {
s"isStreaming=${if (isStreaming) "1" else "0"}\n"
}

protected def getAggRel(
context: SubstraitContext,
operatorId: Long,