From cee9e4858dc2a370709d1588a69993bc9769e244 Mon Sep 17 00:00:00 2001 From: rui-mo Date: Thu, 6 Apr 2023 12:54:12 +0000 Subject: [PATCH] partial merge --- .../execution/TestOperator.scala | 159 +++++++++--- .../VeloxDataTypeValidationSuite.scala | 22 +- cpp/velox/jni/JniWrapper.cc | 2 +- ep/build-velox/src/get_velox.sh | 4 +- .../HashAggregateExecBaseTransformer.scala | 78 +++--- .../AggregateFunctionsBuilder.scala | 21 +- .../expression/ConverterUtils.scala | 98 +++---- .../WholeStageTransformerSuite.scala | 3 +- .../GlutenHashAggregateExecTransformer.scala | 241 ++++++++++++++---- 9 files changed, 426 insertions(+), 202 deletions(-) diff --git a/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala b/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala index 390192ece1c3..a8df9ff9e478 100644 --- a/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala +++ b/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala @@ -161,23 +161,77 @@ class TestOperator extends WholeStageTransformerSuite { } test("count") { - val df = runQueryAndCompare("select count(*) from lineitem " + - "where l_partkey in (1552, 674, 1062)") { - _ => - } - checkLengthAndPlan(df, 1) + val df = runQueryAndCompare( + "select count(*) from lineitem where l_partkey in (1552, 674, 1062)") { + checkOperatorMatch[GlutenHashAggregateExecTransformer] } + runQueryAndCompare( + "select count(l_quantity), count(distinct l_partkey) from lineitem") { df => { + assert(getExecutedPlan(df).count(plan => { + plan.isInstanceOf[GlutenHashAggregateExecTransformer]}) == 4) + }} } test("avg") { - val df = runQueryAndCompare("select avg(l_partkey) from lineitem " + - "where l_partkey < 1000") { _ => } - checkLengthAndPlan(df, 1) + val df = runQueryAndCompare( + "select avg(l_partkey) from lineitem where l_partkey < 1000") { + checkOperatorMatch[GlutenHashAggregateExecTransformer] } + runQueryAndCompare( + "select avg(l_quantity), count(distinct l_partkey) from lineitem") { df => { + assert(getExecutedPlan(df).count(plan => { + plan.isInstanceOf[GlutenHashAggregateExecTransformer]}) == 4) + }} + runQueryAndCompare( + "select avg(cast (l_quantity as DECIMAL(12, 2))), " + + "count(distinct l_partkey) from lineitem") { df => { + assert(getExecutedPlan(df).count(plan => { + plan.isInstanceOf[GlutenHashAggregateExecTransformer]}) == 4) + }} + runQueryAndCompare( + "select avg(cast (l_quantity as DECIMAL(22, 2))), " + + "count(distinct l_partkey) from lineitem") { df => { + assert(getExecutedPlan(df).count(plan => { + plan.isInstanceOf[GlutenHashAggregateExecTransformer]}) == 4) + }} } test("sum") { - val df = runQueryAndCompare("select sum(l_partkey) from lineitem " + - "where l_partkey < 2000") { _ => } - checkLengthAndPlan(df, 1) + runQueryAndCompare( + "select sum(l_partkey) from lineitem where l_partkey < 2000") { + checkOperatorMatch[GlutenHashAggregateExecTransformer] + } + runQueryAndCompare( + "select sum(l_quantity), count(distinct l_partkey) from lineitem") { df => { + assert(getExecutedPlan(df).count(plan => { + plan.isInstanceOf[GlutenHashAggregateExecTransformer]}) == 4) + }} + runQueryAndCompare( + "select sum(cast (l_quantity as DECIMAL(22, 2))) from lineitem") { + checkOperatorMatch[GlutenHashAggregateExecTransformer] + } + runQueryAndCompare( + "select sum(cast (l_quantity as DECIMAL(12, 2))), " + + "count(distinct l_partkey) from lineitem") { df => { + assert(getExecutedPlan(df).count(plan => { + 
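+        // Four aggregate operators are expected because Spark plans an
+        // aggregation mixing count distinct with other functions in four
+        // stages: partial, partial-merge, partial-distinct and final.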
plan.isInstanceOf[GlutenHashAggregateExecTransformer]}) == 4) + }} + runQueryAndCompare( + "select sum(cast (l_quantity as DECIMAL(22, 2))), " + + "count(distinct l_partkey) from lineitem") { df => { + assert(getExecutedPlan(df).count(plan => { + plan.isInstanceOf[GlutenHashAggregateExecTransformer]}) == 4) + }} + } + + test("min and max") { + runQueryAndCompare( + "select min(l_partkey), max(l_partkey) from lineitem where l_partkey < 2000") { + checkOperatorMatch[GlutenHashAggregateExecTransformer] + } + runQueryAndCompare( + "select min(l_partkey), max(l_partkey), count(distinct l_partkey) from lineitem") { df => { + assert(getExecutedPlan(df).count(plan => { + plan.isInstanceOf[GlutenHashAggregateExecTransformer]}) == 4) + }} } test("groupby") { @@ -328,7 +382,7 @@ class TestOperator extends WholeStageTransformerSuite { } test("union_all three tables") { - val df = runQueryAndCompare( + runQueryAndCompare( """ |select count(orderkey) from ( | select l_orderkey as orderkey from lineitem @@ -383,6 +437,11 @@ class TestOperator extends WholeStageTransformerSuite { |""".stripMargin) { checkOperatorMatch[GlutenHashAggregateExecTransformer] } + runQueryAndCompare( + "select stddev_samp(l_quantity), count(distinct l_partkey) from lineitem") { df => { + assert(getExecutedPlan(df).count(plan => { + plan.isInstanceOf[GlutenHashAggregateExecTransformer]}) == 4) + }} } test("round") { @@ -408,6 +467,11 @@ class TestOperator extends WholeStageTransformerSuite { |""".stripMargin) { checkOperatorMatch[GlutenHashAggregateExecTransformer] } + runQueryAndCompare( + "select stddev_pop(l_quantity), count(distinct l_partkey) from lineitem") { df => { + assert(getExecutedPlan(df).count(plan => { + plan.isInstanceOf[GlutenHashAggregateExecTransformer]}) == 4) + }} } test("var_samp") { @@ -424,6 +488,11 @@ class TestOperator extends WholeStageTransformerSuite { |""".stripMargin) { checkOperatorMatch[GlutenHashAggregateExecTransformer] } + runQueryAndCompare( + "select var_samp(l_quantity), count(distinct l_partkey) from lineitem") { df => { + assert(getExecutedPlan(df).count(plan => { + plan.isInstanceOf[GlutenHashAggregateExecTransformer]}) == 4) + }} } test("var_pop") { @@ -440,6 +509,11 @@ class TestOperator extends WholeStageTransformerSuite { |""".stripMargin) { checkOperatorMatch[GlutenHashAggregateExecTransformer] } + runQueryAndCompare( + "select var_pop(l_quantity), count(distinct l_partkey) from lineitem") { df => { + assert(getExecutedPlan(df).count(plan => { + plan.isInstanceOf[GlutenHashAggregateExecTransformer]}) == 4) + }} } test("bit_and and bit_or") { @@ -450,6 +524,11 @@ class TestOperator extends WholeStageTransformerSuite { |""".stripMargin) { checkOperatorMatch[GlutenHashAggregateExecTransformer] } + runQueryAndCompare( + "select bit_and(l_linenumber), count(distinct l_partkey) from lineitem") { df => { + assert(getExecutedPlan(df).count(plan => { + plan.isInstanceOf[GlutenHashAggregateExecTransformer]}) == 4) + }} runQueryAndCompare( """ |select bit_or(l_linenumber) from lineitem @@ -457,6 +536,11 @@ class TestOperator extends WholeStageTransformerSuite { |""".stripMargin) { checkOperatorMatch[GlutenHashAggregateExecTransformer] } + runQueryAndCompare( + "select bit_or(l_linenumber), count(distinct l_partkey) from lineitem") { df => { + assert(getExecutedPlan(df).count(plan => { + plan.isInstanceOf[GlutenHashAggregateExecTransformer]}) == 4) + }} } test("bool scan") { @@ -491,26 +575,39 @@ class TestOperator extends WholeStageTransformerSuite { } test("corr covar_pop covar_samp") { - 
withSQLConf("spark.sql.adaptive.enabled" -> "false") { - runQueryAndCompare( - """ - |select corr(l_partkey, l_suppkey) from lineitem; - |""".stripMargin) { - checkOperatorMatch[GlutenHashAggregateExecTransformer] - } - runQueryAndCompare( - """ - |select covar_pop(l_partkey, l_suppkey) from lineitem; - |""".stripMargin) { - checkOperatorMatch[GlutenHashAggregateExecTransformer] - } - runQueryAndCompare( - """ - |select covar_samp(l_partkey, l_suppkey) from lineitem; - |""".stripMargin) { - checkOperatorMatch[GlutenHashAggregateExecTransformer] - } + runQueryAndCompare( + """ + |select corr(l_partkey, l_suppkey) from lineitem; + |""".stripMargin) { + checkOperatorMatch[GlutenHashAggregateExecTransformer] + } + runQueryAndCompare( + "select corr(l_partkey, l_suppkey), count(distinct l_orderkey) from lineitem") { df => { + assert(getExecutedPlan(df).count(plan => { + plan.isInstanceOf[GlutenHashAggregateExecTransformer]}) == 4) + }} + runQueryAndCompare( + """ + |select covar_pop(l_partkey, l_suppkey) from lineitem; + |""".stripMargin) { + checkOperatorMatch[GlutenHashAggregateExecTransformer] } + runQueryAndCompare( + "select covar_pop(l_partkey, l_suppkey), count(distinct l_orderkey) from lineitem") { df => { + assert(getExecutedPlan(df).count(plan => { + plan.isInstanceOf[GlutenHashAggregateExecTransformer]}) == 4) + }} + runQueryAndCompare( + """ + |select covar_samp(l_partkey, l_suppkey) from lineitem; + |""".stripMargin) { + checkOperatorMatch[GlutenHashAggregateExecTransformer] + } + runQueryAndCompare( + "select covar_samp(l_partkey, l_suppkey), count(distinct l_orderkey) from lineitem") { df => { + assert(getExecutedPlan(df).count(plan => { + plan.isInstanceOf[GlutenHashAggregateExecTransformer]}) == 4) + }} } test("Cast double to decimal") { diff --git a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxDataTypeValidationSuite.scala b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxDataTypeValidationSuite.scala index d03069bbf078..f3e8b1efcfe4 100644 --- a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxDataTypeValidationSuite.scala +++ b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxDataTypeValidationSuite.scala @@ -202,24 +202,18 @@ class VeloxDataTypeValidationSuite extends WholeStageTransformerSuite { runQueryAndCompare("select int, date from type1 " + " group by grouping sets(int, date) sort by date, int limit 1") { df => { val executedPlan = getExecutedPlan(df) - assert(executedPlan.exists(plan => - plan.find(child => child.isInstanceOf[BatchScanExecTransformer]).isDefined)) - assert(executedPlan.exists(plan => - plan.find(child => child.isInstanceOf[ProjectExecTransformer]).isDefined)) - assert(executedPlan.exists(plan => - plan.find(child => child.isInstanceOf[GlutenHashAggregateExecTransformer]).isDefined)) - assert(executedPlan.exists(plan => - plan.find(child => child.isInstanceOf[SortExecTransformer]).isDefined)) + assert(executedPlan.exists(plan => plan.isInstanceOf[BatchScanExecTransformer])) + assert(executedPlan.exists(plan => plan.isInstanceOf[ProjectExecTransformer])) + assert(executedPlan.exists(plan => plan.isInstanceOf[GlutenHashAggregateExecTransformer])) + assert(executedPlan.exists(plan => plan.isInstanceOf[SortExecTransformer])) }} // Validation: Expand, Filter. 
runQueryAndCompare("select date, string, sum(int) from type1 where date > date '1990-01-09' " + "group by rollup(date, string) order by date, string") { df => { val executedPlan = getExecutedPlan(df) - assert(executedPlan.exists(plan => - plan.find(child => child.isInstanceOf[ExpandExecTransformer]).isDefined)) - assert(executedPlan.exists(plan => - plan.find(child => child.isInstanceOf[GlutenFilterExecTransformer]).isDefined)) + assert(executedPlan.exists(plan => plan.isInstanceOf[ExpandExecTransformer])) + assert(executedPlan.exists(plan => plan.isInstanceOf[GlutenFilterExecTransformer])) }} // Validation: Union. @@ -231,9 +225,7 @@ class VeloxDataTypeValidationSuite extends WholeStageTransformerSuite { | select date as d from type1 |); |""".stripMargin) { df => { - val executedPlan = getExecutedPlan(df) - assert(executedPlan.exists(plan => - plan.find(child => child.isInstanceOf[UnionExecTransformer]).isDefined)) + assert(getExecutedPlan(df).exists(plan => plan.isInstanceOf[UnionExecTransformer])) }} // Validation: Limit. diff --git a/cpp/velox/jni/JniWrapper.cc b/cpp/velox/jni/JniWrapper.cc index bd03ec1177f7..608d2d510318 100644 --- a/cpp/velox/jni/JniWrapper.cc +++ b/cpp/velox/jni/JniWrapper.cc @@ -96,7 +96,7 @@ JNIEXPORT jboolean JNICALL Java_io_glutenproject_vectorized_ExpressionEvaluatorJ try { return planValidator.validate(subPlan); } catch (std::invalid_argument& e) { - LOG(INFO) << "Faled to validate substrait plan because " << e.what(); + LOG(INFO) << "Failed to validate substrait plan because " << e.what(); return false; } JNI_METHOD_END(false) diff --git a/ep/build-velox/src/get_velox.sh b/ep/build-velox/src/get_velox.sh index 3d10ba05edf2..8b698328bdea 100755 --- a/ep/build-velox/src/get_velox.sh +++ b/ep/build-velox/src/get_velox.sh @@ -2,8 +2,8 @@ set -exu -VELOX_REPO=https://github.com/oap-project/velox.git -VELOX_BRANCH=main +VELOX_REPO=https://github.com/rui-mo/velox.git +VELOX_BRANCH=companion #Set on run gluten on HDFS ENABLE_HDFS=OFF diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala index d70b72ebf6f9..31dda33fe71b 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala @@ -219,14 +219,14 @@ abstract class HashAggregateExecBaseTransformer( break } expr.mode match { - case Partial | PartialMerge => + case Partial => for (aggChild <- expr.aggregateFunction.children) { if (!aggChild.isInstanceOf[Attribute] && !aggChild.isInstanceOf[Literal]) { needsProjection = true break } } - // No need to consider pre-projection for Final Agg. + // No need to consider pre-projection for PartialMerge and Final Agg. 
case _ => } } @@ -436,7 +436,7 @@ abstract class HashAggregateExecBaseTransformer( aggregateAttributeList: Seq[Attribute]): List[Attribute] = { var aggregateAttr = new ListBuffer[Attribute]() val size = aggregateExpressions.size - var res_index = 0 + var resIndex = 0 for (expIdx <- 0 until size) { val exp: AggregateExpression = aggregateExpressions(expIdx) val mode = exp.mode @@ -444,25 +444,17 @@ abstract class HashAggregateExecBaseTransformer( aggregateFunc match { case Average(_, _) => mode match { - case Partial => - val avg = aggregateFunc.asInstanceOf[Average] - val aggBufferAttr = avg.inputAggBufferAttributes - for (index <- aggBufferAttr.indices) { - val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index)) - aggregateAttr += attr - } - res_index += 2 - case PartialMerge => + case Partial | PartialMerge => val avg = aggregateFunc.asInstanceOf[Average] val aggBufferAttr = avg.inputAggBufferAttributes for (index <- aggBufferAttr.indices) { val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index)) aggregateAttr += attr } - res_index += 1 + resIndex += 2 case Final => - aggregateAttr += aggregateAttributeList(res_index) - res_index += 1 + aggregateAttr += aggregateAttributeList(resIndex) + resIndex += 1 case other => throw new UnsupportedOperationException(s"not currently supported: $other.") } @@ -476,15 +468,15 @@ abstract class HashAggregateExecBaseTransformer( aggregateAttr += ConverterUtils.getAttrFromExpr(aggBufferAttr.head) val isEmptyAttr = ConverterUtils.getAttrFromExpr(aggBufferAttr(1)) aggregateAttr += isEmptyAttr - res_index += 2 + resIndex += 2 } else { val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr.head) aggregateAttr += attr - res_index += 1 + resIndex += 1 } case Final => - aggregateAttr += aggregateAttributeList(res_index) - res_index += 1 + aggregateAttr += aggregateAttributeList(resIndex) + resIndex += 1 case other => throw new UnsupportedOperationException(s"not currently supported: $other.") } @@ -495,10 +487,10 @@ abstract class HashAggregateExecBaseTransformer( val aggBufferAttr = count.inputAggBufferAttributes val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr.head) aggregateAttr += attr - res_index += 1 + resIndex += 1 case Final => - aggregateAttr += aggregateAttributeList(res_index) - res_index += 1 + aggregateAttr += aggregateAttributeList(resIndex) + resIndex += 1 case other => throw new UnsupportedOperationException(s"not currently supported: $other.") } @@ -510,10 +502,10 @@ abstract class HashAggregateExecBaseTransformer( s"Aggregate function ${aggregateFunc} expects one buffer attribute.") val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr.head) aggregateAttr += attr - res_index += 1 + resIndex += 1 case Final => - aggregateAttr += aggregateAttributeList(res_index) - res_index += 1 + aggregateAttr += aggregateAttributeList(resIndex) + resIndex += 1 case other => throw new UnsupportedOperationException(s"not currently supported: $other.") } @@ -529,10 +521,10 @@ abstract class HashAggregateExecBaseTransformer( val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index)) aggregateAttr += attr } - res_index += expectedBufferSize + resIndex += expectedBufferSize case Final => - aggregateAttr += aggregateAttributeList(res_index) - res_index += 1 + aggregateAttr += aggregateAttributeList(resIndex) + resIndex += 1 case other => throw new UnsupportedOperationException(s"not currently supported: ${other}.") } @@ -548,27 +540,25 @@ abstract class HashAggregateExecBaseTransformer( val attr = 
ConverterUtils.getAttrFromExpr(aggBufferAttr(index)) aggregateAttr += attr } - res_index += expectedBufferSize + resIndex += expectedBufferSize case Final => - aggregateAttr += aggregateAttributeList(res_index) - res_index += 1 + aggregateAttr += aggregateAttributeList(resIndex) + resIndex += 1 case other => throw new UnsupportedOperationException(s"not currently supported: ${other}.") } case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop => mode match { - case Partial => + case Partial | PartialMerge => val aggBufferAttr = aggregateFunc.inputAggBufferAttributes for (index <- aggBufferAttr.indices) { val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index)) aggregateAttr += attr } - res_index += 3 - case PartialMerge => - throw new UnsupportedOperationException("not currently supported: PartialMerge.") + resIndex += 3 case Final => - aggregateAttr += aggregateAttributeList(res_index) - res_index += 1 + aggregateAttr += aggregateAttributeList(resIndex) + resIndex += 1 case other => throw new UnsupportedOperationException(s"not currently supported: $other.") } @@ -582,10 +572,10 @@ abstract class HashAggregateExecBaseTransformer( val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index)) aggregateAttr += attr } - res_index += aggBufferAttr.size + resIndex += aggBufferAttr.size case Final => - aggregateAttr += aggregateAttributeList(res_index) - res_index += 1 + aggregateAttr += aggregateAttributeList(resIndex) + resIndex += 1 case other => throw new UnsupportedOperationException(s"not currently supported: $other.") } @@ -598,10 +588,10 @@ abstract class HashAggregateExecBaseTransformer( val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index)) aggregateAttr += attr } - res_index += aggBufferAttr.size + resIndex += aggBufferAttr.size case Final => - aggregateAttr += aggregateAttributeList(res_index) - res_index += 1 + aggregateAttr += aggregateAttributeList(resIndex) + resIndex += 1 case other => throw new UnsupportedOperationException(s"not currently supported: $other.") } @@ -658,7 +648,7 @@ abstract class HashAggregateExecBaseTransformer( .replaceWithExpressionTransformer(expr, originalInputAttributes) .doTransform(args) }) - case Final => + case PartialMerge | Final => aggregateFunc.inputAggBufferAttributes.toList.map(attr => { ExpressionConverter .replaceWithExpressionTransformer(attr, originalInputAttributes) diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/AggregateFunctionsBuilder.scala b/gluten-core/src/main/scala/io/glutenproject/expression/AggregateFunctionsBuilder.scala index 42cca70019f6..e110a9d3cf4d 100644 --- a/gluten-core/src/main/scala/io/glutenproject/expression/AggregateFunctionsBuilder.scala +++ b/gluten-core/src/main/scala/io/glutenproject/expression/AggregateFunctionsBuilder.scala @@ -21,32 +21,31 @@ import io.glutenproject.backendsapi.BackendsApiManager import io.glutenproject.expression.ConverterUtils.FunctionConfig import io.glutenproject.substrait.expression.ExpressionBuilder import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.types.DataType object AggregateFunctionsBuilder { - - val veloxCorrIntermediateDataOrder = Seq("ck", "n", "xMk", "yMk", "xAvg", "yAvg") - val veloxCovarIntermediateDataOrder = Seq("ck", "n", "xAvg", "yAvg") - def create(args: java.lang.Object, aggregateFunc: AggregateFunction): Long = { val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] - val substraitAggFuncName = - 
ExpressionMappings.aggregate_functions_map.getOrElse(aggregateFunc.getClass, - ExpressionMappings.getAggSigOther(aggregateFunc.prettyName)) - // Check whether Gluten supports this aggregate function + val substraitAggFuncName = ExpressionMappings.aggregate_functions_map.getOrElse( + aggregateFunc.getClass, ExpressionMappings.getAggSigOther(aggregateFunc.prettyName)) + // Check whether Gluten supports this aggregate function. if (substraitAggFuncName.isEmpty) { throw new UnsupportedOperationException(s"not currently supported: $aggregateFunc.") } - // Check whether each backend supports this aggregate function + // Check whether each backend supports this aggregate function. if (!BackendsApiManager.getValidatorApiInstance.doAggregateFunctionValidate( - substraitAggFuncName, aggregateFunc )) { + substraitAggFuncName, aggregateFunc)) { throw new UnsupportedOperationException(s"not currently supported: $aggregateFunc.") } + + val inputTypes: Seq[DataType] = aggregateFunc.children.map(child => child.dataType) + ExpressionBuilder.newScalarFunction( functionMap, ConverterUtils.makeFuncName( substraitAggFuncName, - aggregateFunc.children.map(child => child.dataType), + inputTypes, FunctionConfig.REQ)) } } diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/ConverterUtils.scala b/gluten-core/src/main/scala/io/glutenproject/expression/ConverterUtils.scala index a1eedffcfab0..c877a7302b6d 100644 --- a/gluten-core/src/main/scala/io/glutenproject/expression/ConverterUtils.scala +++ b/gluten-core/src/main/scala/io/glutenproject/expression/ConverterUtils.scala @@ -275,6 +275,56 @@ object ConverterUtils extends Logging { val REQ, OPT, NON = Value } + /** + * Get the signature name of a type based on Substrait's definition in + * https://substrait.io/extensions/#function-signature-compound-names. + * @param dataType: the input data type. + * @return the corresponding signature name. + */ + def getTypeSigName(dataType: DataType): String = { + dataType match { + case BooleanType => // TODO: Not in Substrait yet. + "bool" + case ByteType => "i8" + case ShortType => "i16" + case IntegerType => "i32" + case LongType => "i64" + case FloatType => "fp32" + case DoubleType => "fp64" + case DateType => "date" + case TimestampType => "ts" + case StringType => "str" + case BinaryType => "vbin" + case DecimalType() => + val decimalType = dataType.asInstanceOf[DecimalType] + val precision = decimalType.precision + val scale = decimalType.scale + // TODO: different with Substrait due to more details here. + "dec<" + precision + "," + scale + ">" + case ArrayType(_, _) => + "list" + case StructType(fields) => + // TODO: different with Substrait due to more details here. + var sigName = "struct<" + var index = 0 + fields.foreach(field => { + sigName = sigName.concat(getTypeSigName(field.dataType)) + sigName = sigName.concat(if (index < fields.length - 1) "," else "") + index += 1 + }) + sigName = sigName.concat(">") + sigName + case MapType(_, _, _) => + "map" + case CharType(_) => + "fchar" + case NullType => + "nothing" + case other => + throw new UnsupportedOperationException(s"Type $other not supported.") + } + } + // This method is used to create a function name with input types. // The format would be aligned with that specified in Substrait. 
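+  // For example, with FunctionConfig.REQ, sum over a bigint input is named
+  // "sum:req_i64" and avg over decimal(12, 2) becomes "avg:req_dec<12,2>",
+  // using the signature names produced by getTypeSigName above.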
// The function name Format: @@ -292,51 +342,9 @@ object ConverterUtils extends Logging { throw new UnsupportedOperationException(s"$other is not supported.") } for (idx <- datatypes.indices) { - val datatype = datatypes(idx) - typedFuncName = datatype match { - case BooleanType => - // TODO: Not in Substrait yet. - typedFuncName.concat("bool") - case ByteType => - typedFuncName.concat("i8") - case ShortType => - typedFuncName.concat("i16") - case IntegerType => - typedFuncName.concat("i32") - case LongType => - typedFuncName.concat("i64") - case FloatType => - typedFuncName.concat("fp32") - case DoubleType => - typedFuncName.concat("fp64") - case DateType => - typedFuncName.concat("date") - case TimestampType => - typedFuncName.concat("ts") - case StringType => - typedFuncName.concat("str") - case BinaryType => - typedFuncName.concat("vbin") - case DecimalType() => - val decimalType = datatype.asInstanceOf[DecimalType] - val precision = decimalType.precision - val scale = decimalType.scale - typedFuncName.concat("dec<" + precision + "," + scale + ">") - case ArrayType(_, _) => - typedFuncName.concat("list") - case StructType(_) => - typedFuncName.concat("struct") - case MapType(_, _, _) => - typedFuncName.concat("map") - case CharType(_) => - typedFuncName.concat("fchar") - case NullType => - typedFuncName.concat("nothing") - case other => - throw new UnsupportedOperationException(s"Type $other not supported.") - } - // For the last item, do not need to add _. - if (idx < (datatypes.size - 1)) { + typedFuncName = typedFuncName.concat(getTypeSigName(datatypes(idx))) + // For the last item, no need to append _. + if (idx < datatypes.size - 1) { typedFuncName = typedFuncName.concat("_") } } diff --git a/gluten-core/src/test/scala/io/glutenproject/execution/WholeStageTransformerSuite.scala b/gluten-core/src/test/scala/io/glutenproject/execution/WholeStageTransformerSuite.scala index 1823bc551643..ae8f8102bc73 100644 --- a/gluten-core/src/test/scala/io/glutenproject/execution/WholeStageTransformerSuite.scala +++ b/gluten-core/src/test/scala/io/glutenproject/execution/WholeStageTransformerSuite.scala @@ -199,8 +199,7 @@ abstract class WholeStageTransformerSuite extends GlutenQueryTest with SharedSpa */ def checkOperatorMatch[T <: TransformSupport](df: DataFrame)(implicit tag: ClassTag[T]): Unit = { val executedPlan = getExecutedPlan(df) - assert(executedPlan.exists( - plan => plan.find(child => child.getClass == tag.runtimeClass).isDefined)) + assert(executedPlan.exists(plan => plan.getClass == tag.runtimeClass)) } /** diff --git a/gluten-data/src/main/scala/io/glutenproject/execution/GlutenHashAggregateExecTransformer.scala b/gluten-data/src/main/scala/io/glutenproject/execution/GlutenHashAggregateExecTransformer.scala index c5951c5dfc4c..c8b115d14ef3 100644 --- a/gluten-data/src/main/scala/io/glutenproject/execution/GlutenHashAggregateExecTransformer.scala +++ b/gluten-data/src/main/scala/io/glutenproject/execution/GlutenHashAggregateExecTransformer.scala @@ -18,24 +18,24 @@ package io.glutenproject.execution import scala.collection.JavaConverters._ - import com.google.protobuf.Any +import io.glutenproject.backendsapi.BackendsApiManager +import io.glutenproject.execution.VeloxAggregateFunctionsBuilder.{veloxFourIntermediateTypes, veloxSixIntermediateTypes, veloxThreeIntermediateTypes} import io.glutenproject.expression._ import io.glutenproject.expression.ConverterUtils.FunctionConfig import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode} import 
io.glutenproject.substrait.expression.{AggregateFunctionNode, ExpressionBuilder, ExpressionNode, ScalarFunctionNode} import io.glutenproject.substrait.extensions.ExtensionBuilder import io.glutenproject.substrait.rel.{RelBuilder, RelNode} -import java.util +import java.util import io.glutenproject.substrait.{AggregationParams, SubstraitContext} import io.glutenproject.utils.GlutenDecimalUtil - import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{BooleanType, DecimalType, DoubleType, IntegerType, LongType} +import org.apache.spark.sql.types.{BooleanType, DataType, DecimalType, DoubleType, LongType, StructField, StructType} case class GlutenHashAggregateExecTransformer( requiredChildDistributionExpressions: Option[Seq[Expression]], @@ -65,13 +65,13 @@ case class GlutenHashAggregateExecTransformer( case _: Average | _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop | _: Corr | _: CovPopulation | _: CovSample => expr.mode match { - case Partial => + case Partial | PartialMerge => return true case _ => } case sum: Sum if sum.dataType.isInstanceOf[DecimalType] => expr.mode match { - case Partial => + case Partial | PartialMerge => return true case _ => } @@ -102,7 +102,7 @@ case class GlutenHashAggregateExecTransformer( for (expr <- aggregateExpressions) { expr.mode match { - case Partial => + case Partial | PartialMerge => case _ => throw new UnsupportedOperationException(s"${expr.mode} not supported.") } @@ -179,25 +179,34 @@ case class GlutenHashAggregateExecTransformer( structTypeNodes.add(ConverterUtils.getTypeNode(LongType, nullable = true)) case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop => // Use struct type to represent Velox Row(BIGINT, DOUBLE, DOUBLE). 
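+          // The three fields are the (count, mean, m2) accumulator Velox keeps
+          // for the variance family, in the order fixed by veloxThreeIntermediateTypes.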
- structTypeNodes.add(ConverterUtils.getTypeNode(LongType, nullable = false)) - structTypeNodes.add(ConverterUtils.getTypeNode(DoubleType, nullable = false)) - structTypeNodes.add(ConverterUtils.getTypeNode(DoubleType, nullable = false)) + structTypeNodes.add(ConverterUtils + .getTypeNode(veloxThreeIntermediateTypes.head, nullable = false)) + structTypeNodes.add(ConverterUtils + .getTypeNode(veloxThreeIntermediateTypes(1), nullable = false)) + structTypeNodes.add(ConverterUtils + .getTypeNode(veloxThreeIntermediateTypes(2), nullable = false)) case _: Corr => - structTypeNodes.add(ConverterUtils.getTypeNode(DoubleType, nullable = false)) - structTypeNodes.add(ConverterUtils.getTypeNode(LongType, nullable = false)) - structTypeNodes.add(ConverterUtils.getTypeNode(DoubleType, nullable = false)) - structTypeNodes.add(ConverterUtils.getTypeNode(DoubleType, nullable = false)) - structTypeNodes.add(ConverterUtils.getTypeNode(DoubleType, nullable = false)) - structTypeNodes.add(ConverterUtils.getTypeNode(DoubleType, nullable = false)) + structTypeNodes.add(ConverterUtils + .getTypeNode(veloxSixIntermediateTypes.head, nullable = false)) + structTypeNodes.add(ConverterUtils + .getTypeNode(veloxSixIntermediateTypes(1), nullable = false)) + structTypeNodes.add(ConverterUtils + .getTypeNode(veloxSixIntermediateTypes(2), nullable = false)) + structTypeNodes.add(ConverterUtils + .getTypeNode(veloxSixIntermediateTypes(3), nullable = false)) + structTypeNodes.add(ConverterUtils + .getTypeNode(veloxSixIntermediateTypes(4), nullable = false)) + structTypeNodes.add(ConverterUtils + .getTypeNode(veloxSixIntermediateTypes(5), nullable = false)) case _: CovPopulation | _: CovSample => - structTypeNodes.add(ConverterUtils.getTypeNode(DoubleType, nullable = false)) - structTypeNodes.add(ConverterUtils.getTypeNode(LongType, nullable = false)) - structTypeNodes.add(ConverterUtils.getTypeNode(DoubleType, nullable = false)) - structTypeNodes.add(ConverterUtils.getTypeNode(DoubleType, nullable = false)) - structTypeNodes.add(ConverterUtils.getTypeNode(DoubleType, nullable = true)) - structTypeNodes.add(ConverterUtils.getTypeNode(LongType, nullable = true)) - structTypeNodes.add(ConverterUtils.getTypeNode(DoubleType, nullable = true)) - structTypeNodes.add(ConverterUtils.getTypeNode(DoubleType, nullable = true)) + structTypeNodes.add(ConverterUtils + .getTypeNode(veloxFourIntermediateTypes.head, nullable = false)) + structTypeNodes.add(ConverterUtils + .getTypeNode(veloxFourIntermediateTypes(1), nullable = false)) + structTypeNodes.add(ConverterUtils + .getTypeNode(veloxFourIntermediateTypes(2), nullable = false)) + structTypeNodes.add(ConverterUtils + .getTypeNode(veloxFourIntermediateTypes(3), nullable = false)) case sum: Sum if sum.dataType.isInstanceOf[DecimalType] => structTypeNodes.add(ConverterUtils.getTypeNode(sum.dataType, nullable = true)) structTypeNodes.add(ConverterUtils.getTypeNode(BooleanType, nullable = false)) @@ -214,22 +223,33 @@ case class GlutenHashAggregateExecTransformer( childrenNodeList: java.util.ArrayList[ExpressionNode], aggregateMode: AggregateMode, aggregateNodeList: java.util.ArrayList[AggregateFunctionNode]): Unit = { + // A special handling for PartialMerge in the execution of count distinct. + // Use partial phase for this aggregation. 
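+    // Since a Velox aggregation node carries a single step for all of its
+    // functions, the partial-merge functions are emitted with the partial
+    // keyword and resolved to their "_merge" companion forms, which take
+    // intermediate data as input (see VeloxAggregateFunctionsBuilder.create).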
+ val modeKeyWord = modeToKeyWord(if (partialCountInMerge) Partial else aggregateMode) aggregateFunction match { case _: Average | _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop | _: Corr | _: CovPopulation | _: CovSample => aggregateMode match { case Partial => val partialNode = ExpressionBuilder.makeAggregateFunction( - AggregateFunctionsBuilder.create(args, aggregateFunction), + VeloxAggregateFunctionsBuilder.create(args, aggregateFunction), childrenNodeList, - modeToKeyWord(aggregateMode), + modeKeyWord, getIntermediateTypeNode(aggregateFunction)) aggregateNodeList.add(partialNode) + case PartialMerge => + val aggFunctionNode = ExpressionBuilder.makeAggregateFunction( + VeloxAggregateFunctionsBuilder + .create(args, aggregateFunction, partialCountInMerge), + childrenNodeList, + modeKeyWord, + getIntermediateTypeNode(aggregateFunction)) + aggregateNodeList.add(aggFunctionNode) case Final => val aggFunctionNode = ExpressionBuilder.makeAggregateFunction( - AggregateFunctionsBuilder.create(args, aggregateFunction), + VeloxAggregateFunctionsBuilder.create(args, aggregateFunction), childrenNodeList, - modeToKeyWord(aggregateMode), + modeKeyWord, ConverterUtils.getTypeNode(aggregateFunction.dataType, aggregateFunction.nullable)) aggregateNodeList.add(aggFunctionNode) case other => @@ -239,26 +259,42 @@ case class GlutenHashAggregateExecTransformer( aggregateMode match { case Partial => val partialNode = ExpressionBuilder.makeAggregateFunction( - AggregateFunctionsBuilder.create(args, aggregateFunction), + VeloxAggregateFunctionsBuilder.create(args, aggregateFunction), childrenNodeList, - modeToKeyWord(aggregateMode), + modeKeyWord, getIntermediateTypeNode(aggregateFunction)) aggregateNodeList.add(partialNode) + case PartialMerge => + val aggFunctionNode = ExpressionBuilder.makeAggregateFunction( + VeloxAggregateFunctionsBuilder + .create(args, aggregateFunction, partialCountInMerge), + childrenNodeList, + modeKeyWord, + getIntermediateTypeNode(aggregateFunction)) + aggregateNodeList.add(aggFunctionNode) case Final => val aggFunctionNode = ExpressionBuilder.makeAggregateFunction( - AggregateFunctionsBuilder.create(args, aggregateFunction), + VeloxAggregateFunctionsBuilder.create(args, aggregateFunction), childrenNodeList, - modeToKeyWord(aggregateMode), + modeKeyWord, ConverterUtils.getTypeNode(aggregateFunction.dataType, aggregateFunction.nullable)) aggregateNodeList.add(aggFunctionNode) case other => throw new UnsupportedOperationException(s"$other is not supported.") } + case _: Count if aggregateMode == Partial => + val aggFunctionNode = ExpressionBuilder.makeAggregateFunction( + VeloxAggregateFunctionsBuilder.create(args, aggregateFunction), + childrenNodeList, + modeKeyWord, + ConverterUtils.getTypeNode(aggregateFunction.dataType, aggregateFunction.nullable)) + aggregateNodeList.add(aggFunctionNode) case _ => val aggFunctionNode = ExpressionBuilder.makeAggregateFunction( - AggregateFunctionsBuilder.create(args, aggregateFunction), + VeloxAggregateFunctionsBuilder.create( + args, aggregateFunction, aggregateMode == PartialMerge && partialCountInMerge), childrenNodeList, - modeToKeyWord(aggregateMode), + modeKeyWord, ConverterUtils.getTypeNode(aggregateFunction.dataType, aggregateFunction.nullable)) aggregateNodeList.add(aggFunctionNode) } @@ -279,7 +315,7 @@ case class GlutenHashAggregateExecTransformer( case _: Average | _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop | _: Corr | _: CovPopulation | _: CovSample => expression.mode match { - case Partial => 
+ case Partial | PartialMerge => typeNodeList.add(getIntermediateTypeNode(aggregateFunction)) case Final => typeNodeList.add( @@ -289,7 +325,7 @@ case class GlutenHashAggregateExecTransformer( } case sum: Sum if sum.dataType.isInstanceOf[DecimalType] => expression.mode match { - case Partial => + case Partial | PartialMerge => typeNodeList.add(getIntermediateTypeNode(aggregateFunction)) case Final => typeNodeList.add( @@ -363,9 +399,9 @@ case class GlutenHashAggregateExecTransformer( aggregateFunction match { case Average(_, _) => aggregateExpression.mode match { - case Final => + case PartialMerge | Final => assert(functionInputAttributes.size == 2, - "Final stage of Average expects two input attributes.") + s"${aggregateExpression.mode.toString} of Average expects two input attributes.") // Use a Velox function to combine the intermediate columns into struct. val childNodes = new util.ArrayList[ExpressionNode]( functionInputAttributes.toList.map(attr => { @@ -379,9 +415,10 @@ case class GlutenHashAggregateExecTransformer( } case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop => aggregateExpression.mode match { - case Final => + case PartialMerge | Final => assert(functionInputAttributes.size == 3, - "Final stage of StddevSamp expects three input attributes.") + s"${aggregateExpression.mode.toString} mode of" + + s"${aggregateFunction.getClass.toString} expects three input attributes.") // Use a Velox function to combine the intermediate columns into struct. var index = 0 var newInputAttributes: Seq[Attribute] = Seq() @@ -411,9 +448,9 @@ case class GlutenHashAggregateExecTransformer( } case _: Corr => aggregateExpression.mode match { - case Final => + case PartialMerge | Final => assert(functionInputAttributes.size == 6, - "Final stage of Corr expects 6 input attributes.") + s"${aggregateExpression.mode.toString} mode of Corr expects 6 input attributes.") // Use a Velox function to combine the intermediate columns into struct. var index = 0 var newInputAttributes: Seq[Attribute] = Seq() @@ -421,8 +458,8 @@ case class GlutenHashAggregateExecTransformer( // Velox's Corr order is [ck, n, xMk, yMk, xAvg, yAvg] // Spark's Corr order is [n, xAvg, yAvg, ck, xMk, yMk] val sparkCorrOutputAttr = aggregateFunction.inputAggBufferAttributes.map(_.name) - val veloxInputOrder = AggregateFunctionsBuilder.veloxCorrIntermediateDataOrder.map( - name => sparkCorrOutputAttr.indexOf(name)) + val veloxInputOrder = VeloxAggregateFunctionsBuilder + .veloxCorrIntermediateDataOrder.map(name => sparkCorrOutputAttr.indexOf(name)) for (order <- veloxInputOrder) { val attr = functionInputAttributes(order) val aggExpr: ExpressionTransformer = ExpressionConverter @@ -449,9 +486,10 @@ case class GlutenHashAggregateExecTransformer( } case _: CovPopulation | _: CovSample => aggregateExpression.mode match { - case Final => + case PartialMerge | Final => assert(functionInputAttributes.size == 4, - "Final stage of Corr expects 4 input attributes.") + s"${aggregateExpression.mode.toString} mode of" + + s"${aggregateFunction.getClass.toString} expects 4 input attributes.") // Use a Velox function to combine the intermediate columns into struct. 
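+            // The combining function is assumed to be Velox's row_constructor,
+            // packing the reordered columns into the struct layout that the
+            // covariance merge phase expects.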
var index = 0 var newInputAttributes: Seq[Attribute] = Seq() @@ -459,8 +497,8 @@ case class GlutenHashAggregateExecTransformer( // Velox's Covar order is [ck, n, xAvg, yAvg] // Spark's Covar order is [n, xAvg, yAvg, ck] val sparkCorrOutputAttr = aggregateFunction.inputAggBufferAttributes.map(_.name) - val veloxInputOrder = AggregateFunctionsBuilder.veloxCovarIntermediateDataOrder.map( - name => sparkCorrOutputAttr.indexOf(name)) + val veloxInputOrder = VeloxAggregateFunctionsBuilder + .veloxCovarIntermediateDataOrder.map(name => sparkCorrOutputAttr.indexOf(name)) for (order <- veloxInputOrder) { val attr = functionInputAttributes(order) val aggExpr: ExpressionTransformer = ExpressionConverter @@ -487,7 +525,7 @@ case class GlutenHashAggregateExecTransformer( } case sum: Sum if sum.dataType.isInstanceOf[DecimalType] => aggregateExpression.mode match { - case Final => + case PartialMerge | Final => assert(functionInputAttributes.size == 2, "Final stage of Average expects two input attributes.") // Use a Velox function to combine the intermediate columns into struct. @@ -501,6 +539,16 @@ case class GlutenHashAggregateExecTransformer( case other => throw new UnsupportedOperationException(s"$other is not supported.") } + case _: Count if partialCountInMerge && aggregateExpression.mode == Partial => + assert(functionInputAttributes.size == 1, + "Only one input attribute is expected for Count.") + val childNodes = new util.ArrayList[ExpressionNode]( + aggregateFunction.children.map(attr => { + ExpressionConverter + .replaceWithExpressionTransformer(attr, originalInputAttributes) + .doTransform(args) + }).asJava) + exprNodes.addAll(childNodes) case _ => assert(functionInputAttributes.size == 1, "Only one input attribute is expected.") val childNodes = new util.ArrayList[ExpressionNode]( @@ -572,6 +620,21 @@ case class GlutenHashAggregateExecTransformer( projectRel, groupingList, aggregateFunctionList, aggFilterList, context, operatorId) } + /** + * Whether this is a mixed aggregation of partial count and + * other partial-merge aggregation functions. + * @return whether partial count and other partial-merge functions coexist. + */ + def partialCountInMerge: Boolean = { + val partialMergeExists = aggregateExpressions.exists(expression => { + expression.mode == PartialMerge + }) + val partialCountExists = aggregateExpressions.exists(expression => { + expression.aggregateFunction.isInstanceOf[Count] && expression.mode == Partial + }) + partialMergeExists && partialCountExists + } + /** * Create and return the Rel for the this aggregation. * @param context the Substrait context @@ -587,9 +650,8 @@ case class GlutenHashAggregateExecTransformer( input: RelNode = null, validation: Boolean = false): RelNode = { val originalInputAttributes = child.output - val preProjectionNeeded = needsPreProjection - var aggRel = if (preProjectionNeeded) { + var aggRel = if (needsPreProjection) { aggParams.preProjectionNeeded = true getAggRelWithPreProjection( context, originalInputAttributes, operatorId, input, validation) @@ -628,3 +690,80 @@ case class GlutenHashAggregateExecTransformer( copy(child = newChild) } } + +/** + * An aggregation function builder specifically used by Velox backend. 
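+ * When forMergeCompanion is set, the "_merge" companion form of a function is
+ * resolved, which consumes the intermediate struct type instead of the raw
+ * input types.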
+ */
+object VeloxAggregateFunctionsBuilder {
+
+  val veloxCorrIntermediateDataOrder: Seq[String] = Seq("ck", "n", "xMk", "yMk", "xAvg", "yAvg")
+  val veloxCovarIntermediateDataOrder: Seq[String] = Seq("ck", "n", "xAvg", "yAvg")
+
+  val veloxThreeIntermediateTypes: Seq[DataType] = Seq(LongType, DoubleType, DoubleType)
+  val veloxFourIntermediateTypes: Seq[DataType] = Seq(DoubleType, LongType, DoubleType, DoubleType)
+  val veloxSixIntermediateTypes: Seq[DataType] =
+    Seq(DoubleType, LongType, DoubleType, DoubleType, DoubleType, DoubleType)
+
+  /**
+   * Get the compatible input types for a Velox aggregate function.
+   * @param aggregateFunc: the input aggregate function.
+   * @param forMergeCompanion: whether this is a special case to solve mixed aggregation phases.
+   * @return the input types of a Velox aggregate function.
+   */
+  private def getInputTypes(aggregateFunc: AggregateFunction,
+                            forMergeCompanion: Boolean): Seq[DataType] = {
+    if (!forMergeCompanion) {
+      return aggregateFunc.children.map(child => child.dataType)
+    }
+    if (aggregateFunc.aggBufferAttributes.size == veloxThreeIntermediateTypes.size) {
+      return Seq(StructType(veloxThreeIntermediateTypes.map(intermediateType =>
+        StructField("", intermediateType)).toArray))
+    }
+    if (aggregateFunc.aggBufferAttributes.size == veloxFourIntermediateTypes.size) {
+      return Seq(StructType(veloxFourIntermediateTypes.map(intermediateType =>
+        StructField("", intermediateType)).toArray))
+    }
+    if (aggregateFunc.aggBufferAttributes.size == veloxSixIntermediateTypes.size) {
+      return Seq(StructType(veloxSixIntermediateTypes.map(intermediateType =>
+        StructField("", intermediateType)).toArray))
+    }
+    if (aggregateFunc.aggBufferAttributes.size > 1) {
+      return Seq(StructType(aggregateFunc.aggBufferAttributes.map(attribute =>
+        StructField("", attribute.dataType)).toArray))
+    }
+    aggregateFunc.aggBufferAttributes.map(child => child.dataType)
+  }
+
+  /**
+   * Create a scalar function for the input aggregate function.
+   * @param args: the function map.
+   * @param aggregateFunc: the input aggregate function.
+   * @param forMergeCompanion: whether this is a special case to solve mixed aggregation phases.
+   * @return the id assigned to the created function in the function map.
+   */
+  def create(args: java.lang.Object, aggregateFunc: AggregateFunction,
+             forMergeCompanion: Boolean = false): Long = {
+    val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]]
+
+    val sigName = ExpressionMappings.aggregate_functions_map.getOrElse(
+      aggregateFunc.getClass, ExpressionMappings.getAggSigOther(aggregateFunc.prettyName))
+    // Check whether Gluten supports this aggregate function.
+    if (sigName.isEmpty) {
+      throw new UnsupportedOperationException(s"not currently supported: $aggregateFunc.")
+    }
+    // Check whether each backend supports this aggregate function.
+    if (!BackendsApiManager.getValidatorApiInstance.doAggregateFunctionValidate(
+      sigName, aggregateFunc)) {
+      throw new UnsupportedOperationException(s"not currently supported: $aggregateFunc.")
+    }
+    // Use the "_merge" companion function for partial-merge aggregation functions
+    // mixed with a partial count distinct.
+    val substraitAggFuncName = if (!forMergeCompanion) sigName else sigName + "_merge"
+
+    ExpressionBuilder.newScalarFunction(
+      functionMap,
+      ConverterUtils.makeFuncName(
+        substraitAggFuncName,
+        getInputTypes(aggregateFunc, forMergeCompanion),
+        FunctionConfig.REQ))
+  }
+}
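
// A minimal sketch of the compound naming the builder above produces, assuming
// FunctionConfig.REQ contributes the ":req_" infix in ConverterUtils.makeFuncName;
// the struct signature follows getTypeSigName. CompanionNameSketch is illustrative
// only and not part of the patch.
object CompanionNameSketch {
  def main(args: Array[String]): Unit = {
    val sigName = "sum"                 // mapped Substrait name for Sum
    val companion = sigName + "_merge"  // chosen when forMergeCompanion = true
    // A partial-merge decimal sum consumes its (sum, isEmpty) intermediate struct:
    val compound = companion + ":req_" + "struct<dec<12,2>,bool>"
    println(compound)                   // prints: sum_merge:req_struct<dec<12,2>,bool>
  }
}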