Skip to content

Commit

Permalink
[GLUTEN-4668][CH] Merge two phase hash-based aggregate into one aggre…
Browse files Browse the repository at this point in the history
…gate in the Spark plan when there is no shuffle

Examples:

 HashAggregate(t1.i, SUM, final)
                |                  =>    HashAggregate(t1.i, SUM, complete)
 HashAggregate(t1.i, SUM, partial)

Currently, this feature is only supported for the CH (ClickHouse) backend.

Close #4668.

Co-authored-by: lgbo <lgbo.ustc@gmail.com>
  • Loading branch information
zzcclp and lgbo-ustc committed Feb 18, 2024
1 parent 4048629 commit f5e1689
Show file tree
Hide file tree
Showing 27 changed files with 788 additions and 97 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -282,4 +282,6 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
override def enableNativeWriteFiles(): Boolean = {
GlutenConfig.getConf.enableNativeWriter.getOrElse(false)
}

override def mergeTwoPhasesHashBaseAggregateIfNeed(): Boolean = true
}
Original file line number Diff line number Diff line change
Expand Up @@ -109,16 +109,25 @@ case class CHHashAggregateExecTransformer(
val typeList = new util.ArrayList[TypeNode]()
val nameList = new util.ArrayList[String]()
val (inputAttrs, outputAttrs) = {
if (modes.isEmpty) {
// When there is no aggregate function, it does not need
if (modes.isEmpty || modes.forall(_ == Complete)) {
// When there is no aggregate function or there is complete mode, it does not need
// to handle outputs according to the AggregateMode
for (attr <- child.output) {
typeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
nameList.add(ConverterUtils.genColumnNameWithExprId(attr))
nameList.addAll(ConverterUtils.collectStructFieldNames(attr.dataType))
}
(child.output, output)
} else if (!modes.contains(Partial)) {
} else if (modes.forall(_ == Partial)) {
// partial mode
for (attr <- child.output) {
typeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
nameList.add(ConverterUtils.genColumnNameWithExprId(attr))
nameList.addAll(ConverterUtils.collectStructFieldNames(attr.dataType))
}

(child.output, aggregateResultAttributes)
} else {
// non-partial mode
var resultAttrIndex = 0
for (attr <- aggregateResultAttributes) {
Expand All @@ -135,15 +144,6 @@ case class CHHashAggregateExecTransformer(
resultAttrIndex += 1
}
(aggregateResultAttributes, output)
} else {
// partial mode
for (attr <- child.output) {
typeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
nameList.add(ConverterUtils.genColumnNameWithExprId(attr))
nameList.addAll(ConverterUtils.collectStructFieldNames(attr.dataType))
}

(child.output, aggregateResultAttributes)
}
}

Expand Down Expand Up @@ -212,7 +212,7 @@ case class CHHashAggregateExecTransformer(
val aggregateFunc = aggExpr.aggregateFunction
val childrenNodeList = new util.ArrayList[ExpressionNode]()
val childrenNodes = aggExpr.mode match {
case Partial =>
case Partial | Complete =>
aggregateFunc.children.toList.map(
expr => {
ExpressionConverter
Expand Down Expand Up @@ -446,7 +446,7 @@ case class CHHashAggregateExecPullOutHelper(
}
resIndex += aggBufferAttr.size
resIndex
case Final =>
case Final | Complete =>
aggregateAttr += aggregateAttributeList(resIndex)
resIndex += 1
resIndex
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,13 @@ class GlutenClickHouseColumnarShuffleAQESuite
}

test("TPCH Q18") {
runTPCHQuery(18) { df => }
runTPCHQuery(18) {
df =>
val hashAggregates = collect(df.queryExecution.executedPlan) {
case hash: HashAggregateExecBaseTransformer => hash
}
assert(hashAggregates.size == 3)
}
}

test("TPCH Q19") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,13 @@ class GlutenClickHouseDSV2ColumnarShuffleSuite extends GlutenClickHouseTPCHAbstr
}

test("TPCH Q3") {
runTPCHQuery(3) { df => }
runTPCHQuery(3) {
df =>
val aggs = df.queryExecution.executedPlan.collectWithSubqueries {
case agg: HashAggregateExecBaseTransformer => agg
}
assert(aggs.size == 1)
}
}

test("TPCH Q4") {
Expand Down Expand Up @@ -74,43 +80,91 @@ class GlutenClickHouseDSV2ColumnarShuffleSuite extends GlutenClickHouseTPCHAbstr
}

test("TPCH Q11") {
runTPCHQuery(11) { df => }
runTPCHQuery(11) {
df =>
val aggs = df.queryExecution.executedPlan.collectWithSubqueries {
case agg: HashAggregateExecBaseTransformer => agg
}
assert(aggs.size == 3)
}
}

test("TPCH Q12") {
runTPCHQuery(12) { df => }
}

test("TPCH Q13") {
runTPCHQuery(13) { df => }
runTPCHQuery(13) {
df =>
val aggs = df.queryExecution.executedPlan.collectWithSubqueries {
case agg: HashAggregateExecBaseTransformer => agg
}
assert(aggs.size == 3)
}
}

test("TPCH Q14") {
runTPCHQuery(14) { df => }
runTPCHQuery(14) {
df =>
val aggs = df.queryExecution.executedPlan.collectWithSubqueries {
case agg: HashAggregateExecBaseTransformer => agg
}
assert(aggs.size == 1)
}
}

test("TPCH Q15") {
runTPCHQuery(15) { df => }
runTPCHQuery(15) {
df =>
val aggs = df.queryExecution.executedPlan.collectWithSubqueries {
case agg: HashAggregateExecBaseTransformer => agg
}
assert(aggs.size == 4)
}
}

test("TPCH Q16") {
runTPCHQuery(16, noFallBack = false) { df => }
}

test("TPCH Q17") {
runTPCHQuery(17) { df => }
runTPCHQuery(17) {
df =>
val aggs = df.queryExecution.executedPlan.collectWithSubqueries {
case agg: HashAggregateExecBaseTransformer => agg
}
assert(aggs.size == 3)
}
}

test("TPCH Q18") {
runTPCHQuery(18) { df => }
runTPCHQuery(18) {
df =>
val aggs = df.queryExecution.executedPlan.collectWithSubqueries {
case agg: HashAggregateExecBaseTransformer => agg
}
assert(aggs.size == 4)
}
}

test("TPCH Q19") {
runTPCHQuery(19) { df => }
runTPCHQuery(19) {
df =>
val aggs = df.queryExecution.executedPlan.collectWithSubqueries {
case agg: HashAggregateExecBaseTransformer => agg
}
assert(aggs.size == 1)
}
}

test("TPCH Q20") {
runTPCHQuery(20) { df => }
runTPCHQuery(20) {
df =>
val aggs = df.queryExecution.executedPlan.collectWithSubqueries {
case agg: HashAggregateExecBaseTransformer => agg
}
assert(aggs.size == 1)
}
}

test("TPCH Q21") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ class GlutenClickHouseNativeWriteTableSuite
.set("spark.gluten.sql.enable.native.validation", "false")
// TODO: support default ANSI policy
.set("spark.sql.storeAssignmentPolicy", "legacy")
// .set("spark.gluten.sql.columnar.backend.ch.runtime_config.logger.level", "debug")
.set("spark.sql.warehouse.dir", getWarehouseDir)
.setMaster("local[1]")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ class GlutenClickHouseTPCDSParquetSuite extends GlutenClickHouseTPCDSAbstractSui
.set("spark.memory.offHeap.size", "4g")
.set("spark.gluten.sql.validation.logLevel", "ERROR")
.set("spark.gluten.sql.validation.printStackOnFailure", "true")
// .set("spark.gluten.sql.columnar.backend.ch.runtime_config.logger.level", "debug")
// .setMaster("local[1]")
}

executeTPCDSTest(false)
Expand Down
Loading

0 comments on commit f5e1689

Please sign in to comment.