Skip to content

Commit

Permalink
fix corr distinct
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo committed Apr 14, 2023
1 parent cee9e48 commit 8a41884
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -618,4 +618,15 @@ class TestOperator extends WholeStageTransformerSuite {
assert(result.collect()(0).get(0).toString.equals("0.0345678900000000000000000000000000000"))
checkOperatorMatch[GlutenHashAggregateExecTransformer](result)
}

test("corr distinct") {
Seq((1, 1), (2, 2), (2, 2))
.toDF("a", "b").createOrReplaceTempView("view")
runQueryAndCompare("SELECT corr(DISTINCT a, b)," +
"corr(DISTINCT b, a), count(*) FROM view") { df => {
assert(getExecutedPlan(df).count(plan => {
plan.isInstanceOf[GlutenHashAggregateExecTransformer]
}) == 4)
}}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,9 @@ case class GlutenHashAggregateExecTransformer(
childrenNodeList: java.util.ArrayList[ExpressionNode],
aggregateMode: AggregateMode,
aggregateNodeList: java.util.ArrayList[AggregateFunctionNode]): Unit = {
// A special handling for PartialMerge in the execution of count distinct.
// Use partial phase for this aggregation.
val modeKeyWord = modeToKeyWord(if (partialCountInMerge) Partial else aggregateMode)
// This is a special handling for PartialMerge in the execution of distinct.
// Use Partial phase instead for this aggregation.
val modeKeyWord = modeToKeyWord(if (mixedPartialAndMerge) Partial else aggregateMode)
aggregateFunction match {
case _: Average | _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop |
_: Corr | _: CovPopulation | _: CovSample =>
Expand All @@ -240,7 +240,7 @@ case class GlutenHashAggregateExecTransformer(
case PartialMerge =>
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
VeloxAggregateFunctionsBuilder
.create(args, aggregateFunction, partialCountInMerge),
.create(args, aggregateFunction, mixedPartialAndMerge),
childrenNodeList,
modeKeyWord,
getIntermediateTypeNode(aggregateFunction))
Expand All @@ -267,7 +267,7 @@ case class GlutenHashAggregateExecTransformer(
case PartialMerge =>
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
VeloxAggregateFunctionsBuilder
.create(args, aggregateFunction, partialCountInMerge),
.create(args, aggregateFunction, mixedPartialAndMerge),
childrenNodeList,
modeKeyWord,
getIntermediateTypeNode(aggregateFunction))
Expand All @@ -282,17 +282,10 @@ case class GlutenHashAggregateExecTransformer(
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
}
case _: Count if aggregateMode == Partial =>
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
VeloxAggregateFunctionsBuilder.create(args, aggregateFunction),
childrenNodeList,
modeKeyWord,
ConverterUtils.getTypeNode(aggregateFunction.dataType, aggregateFunction.nullable))
aggregateNodeList.add(aggFunctionNode)
case _ =>
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
VeloxAggregateFunctionsBuilder.create(
args, aggregateFunction, aggregateMode == PartialMerge && partialCountInMerge),
args, aggregateFunction, aggregateMode == PartialMerge && mixedPartialAndMerge),
childrenNodeList,
modeKeyWord,
ConverterUtils.getTypeNode(aggregateFunction.dataType, aggregateFunction.nullable))
Expand Down Expand Up @@ -397,6 +390,14 @@ case class GlutenHashAggregateExecTransformer(
val functionInputAttributes = aggregateExpression.aggregateFunction.inputAggBufferAttributes
val aggregateFunction = aggregateExpression.aggregateFunction
aggregateFunction match {
case _: Count | _: Corr if mixedPartialAndMerge && aggregateExpression.mode == Partial =>
val childNodes = new util.ArrayList[ExpressionNode](
aggregateFunction.children.map(attr => {
ExpressionConverter
.replaceWithExpressionTransformer(attr, originalInputAttributes)
.doTransform(args)
}).asJava)
exprNodes.addAll(childNodes)
case Average(_, _) =>
aggregateExpression.mode match {
case PartialMerge | Final =>
Expand Down Expand Up @@ -539,16 +540,6 @@ case class GlutenHashAggregateExecTransformer(
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
}
case _: Count if partialCountInMerge && aggregateExpression.mode == Partial =>
assert(functionInputAttributes.size == 1,
"Only one input attribute is expected for Count.")
val childNodes = new util.ArrayList[ExpressionNode](
aggregateFunction.children.map(attr => {
ExpressionConverter
.replaceWithExpressionTransformer(attr, originalInputAttributes)
.doTransform(args)
}).asJava)
exprNodes.addAll(childNodes)
case _ =>
assert(functionInputAttributes.size == 1, "Only one input attribute is expected.")
val childNodes = new util.ArrayList[ExpressionNode](
Expand Down Expand Up @@ -621,18 +612,17 @@ case class GlutenHashAggregateExecTransformer(
}

/**
* Whether this is a mixed aggregation of partial count and
* other partial-merge aggregation functions.
* @return whether partial count and other partial-merge functions coexist.
* Whether this is a mixed aggregation of partial and partial-merge aggregation functions.
* @return whether partial and partial-merge functions coexist.
*/
def partialCountInMerge: Boolean = {
def mixedPartialAndMerge: Boolean = {
val partialMergeExists = aggregateExpressions.exists(expression => {
expression.mode == PartialMerge
})
val partialCountExists = aggregateExpressions.exists(expression => {
expression.aggregateFunction.isInstanceOf[Count] && expression.mode == Partial
val partialExists = aggregateExpressions.exists(expression => {
expression.mode == Partial
})
partialMergeExists && partialCountExists
partialMergeExists && partialExists
}

/**
Expand Down Expand Up @@ -752,8 +742,8 @@ object VeloxAggregateFunctionsBuilder {
throw new UnsupportedOperationException(s"not currently supported: $aggregateFunc.")
}
// Check whether each backend supports this aggregate function.
if (!BackendsApiManager.getValidatorApiInstance.doAggregateFunctionValidate(
sigName, aggregateFunc)) {
if (!BackendsApiManager.getValidatorApiInstance
.doAggregateFunctionValidate(sigName, aggregateFunc)) {
throw new UnsupportedOperationException(s"not currently supported: $aggregateFunc.")
}
// Use companion function for partial-merge aggregation functions on count distinct.
Expand Down

0 comments on commit 8a41884

Please sign in to comment.