[SPARK-22084][SQL] Fix performance regression in aggregation strategy #19301
Conversation
```diff
@@ -38,7 +38,7 @@ import org.apache.spark.sql.internal.SQLConf
  * view resolution, in this way, we are able to get the correct view column ordering and
  * omit the extra columns that we don't require);
  * 1.2. Else set the child output attributes to `queryOutput`.
- * 2. Map the `queryQutput` to view output by index, if the corresponding attributes don't match,
+ * 2. Map the `queryOutput` to view output by index, if the corresponding attributes don't match,
```
It looks all the same?
Q -> O (`queryQutput` -> `queryOutput`)
```scala
} else {
  Seq(aggregateFunction.toString, mode, isDistinct)
}
val hashCode = state.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b)
```
What's the purpose here? Can you explain more about how this bug happens?
```scala
val aggregateExpressions = resultExpressions.flatMap { expr =>
  expr.collect {
    case agg: AggregateExpression => agg
  }
}.distinct
```

Before the fix, the exprId of each aggregate expression is different, which can cause `distinct` to fail.
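A minimal sketch of the failure (my own illustration, assuming Spark 2.x catalyst internals; not code from this PR): `AggregateExpression.apply` fills in a fresh `resultId` on every call, and that id participates in case-class equality, so two semantically identical aggregates never compare equal.

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Sum}
import org.apache.spark.sql.types.LongType

val b = AttributeReference("b", LongType)()

// Each apply() allocates resultId = NamedExpression.newExprId, so these two
// SUM(b) expressions differ only in that fresh id ...
val sum1 = AggregateExpression(Sum(b), Complete, isDistinct = false)
val sum2 = AggregateExpression(Sum(b), Complete, isDistinct = false)

// ... and since resultId is a constructor field of the case class, equality
// (and therefore Seq.distinct) keeps both copies:
assert(sum1 != sum2)
assert(Seq(sum1, sum2).distinct.size == 2)
```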
my case:

```sql
select dt,
       geohash_of_latlng,
       sum(mt_cnt),
       sum(ele_cnt),
       round(sum(mt_cnt) * 1.0 * 100 / sum(mt_cnt_all), 2),
       round(sum(ele_cnt) * 1.0 * 100 / sum(ele_cnt_all), 2)
from temp.test_geohash_match_parquet
group by dt, geohash_of_latlng
order by dt, geohash_of_latlng limit 10;
```

before your fix:

```
TakeOrderedAndProject(limit=10, orderBy=[dt#502 ASC NULLS FIRST,geohash_of_latlng#507 ASC NULLS FIRST], output=[dt#502,geohash_of_latlng#507,sum(mt_cnt)#521L,sum(ele_cnt)#522L,round((CAST((CAST((CAST(CAST(sum(CAST(mt_cnt AS BIGINT)) AS DECIMAL(20,0)) AS DECIMAL(21,1)) * CAST(1.0 AS DECIMAL(21,1))) AS DECIMAL(23,1)) * CAST(CAST(100 AS DECIMAL(23,1)) AS DECIMAL(23,1))) AS DECIMAL(38,2)) / CAST(CAST(sum(CAST(mt_cnt_all AS BIGINT)) AS DECIMAL(20,0)) AS DECIMAL(38,2))), 2)#523,round((CAST((CAST((CAST(CAST(sum(CAST(ele_cnt AS BIGINT)) AS DECIMAL(20,0)) AS DECIMAL(21,1)) * CAST(1.0 AS DECIMAL(21,1))) AS DECIMAL(23,1)) * CAST(CAST(100 AS DECIMAL(23,1)) AS DECIMAL(23,1))) AS DECIMAL(38,2)) / CAST(CAST(sum(CAST(ele_cnt_all AS BIGINT)) AS DECIMAL(20,0)) AS DECIMAL(38,2))), 2)#524])
+- *HashAggregate(keys=[dt#502, geohash_of_latlng#507], functions=[sum(cast(mt_cnt#511 as bigint)), sum(cast(ele_cnt#512 as bigint)), sum(cast(mt_cnt#511 as bigint)), sum(cast(mt_cnt_all#513 as bigint)), sum(cast(ele_cnt#512 as bigint)), sum(cast(ele_cnt_all#514 as bigint))])
   +- Exchange(coordinator id: 148401229) hashpartitioning(dt#502, geohash_of_latlng#507, 1000), coordinator[target post-shuffle partition size: 2000000]
      +- *HashAggregate(keys=[dt#502, geohash_of_latlng#507], functions=[partial_sum(cast(mt_cnt#511 as bigint)), partial_sum(cast(ele_cnt#512 as bigint)), partial_sum(cast(mt_cnt#511 as bigint)), partial_sum(cast(mt_cnt_all#513 as bigint)), partial_sum(cast(ele_cnt#512 as bigint)), partial_sum(cast(ele_cnt_all#514 as bigint))])
         +- HiveTableScan [geohash_of_latlng#507, mt_cnt#511, ele_cnt#512, mt_cnt_all#513, ele_cnt_all#514, dt#502], MetastoreRelation temp, test_geohash_match_parquet
```

after your fix:

```
TakeOrderedAndProject(limit=10, orderBy=[dt#467 ASC NULLS FIRST,geohash_of_latlng#472 ASC NULLS FIRST], output=[dt#467,geohash_of_latlng#472,sum(mt_cnt)#486L,sum(ele_cnt)#487L,round((CAST((CAST((CAST(CAST(sum(CAST(mt_cnt AS BIGINT)) AS DECIMAL(20,0)) AS DECIMAL(21,1)) * CAST(1.0 AS DECIMAL(21,1))) AS DECIMAL(23,1)) * CAST(CAST(100 AS DECIMAL(23,1)) AS DECIMAL(23,1))) AS DECIMAL(38,2)) / CAST(CAST(sum(CAST(mt_cnt_all AS BIGINT)) AS DECIMAL(20,0)) AS DECIMAL(38,2))), 2)#488,round((CAST((CAST((CAST(CAST(sum(CAST(ele_cnt AS BIGINT)) AS DECIMAL(20,0)) AS DECIMAL(21,1)) * CAST(1.0 AS DECIMAL(21,1))) AS DECIMAL(23,1)) * CAST(CAST(100 AS DECIMAL(23,1)) AS DECIMAL(23,1))) AS DECIMAL(38,2)) / CAST(CAST(sum(CAST(ele_cnt_all AS BIGINT)) AS DECIMAL(20,0)) AS DECIMAL(38,2))), 2)#489])
+- *HashAggregate(keys=[dt#467, geohash_of_latlng#472], functions=[sum(cast(mt_cnt#476 as bigint)), sum(cast(ele_cnt#477 as bigint)), sum(cast(mt_cnt#476 as bigint)), sum(cast(mt_cnt_all#478 as bigint)), sum(cast(ele_cnt#477 as bigint)), sum(cast(ele_cnt_all#479 as bigint))])
   +- Exchange(coordinator id: 227998366) hashpartitioning(dt#467, geohash_of_latlng#472, 1000), coordinator[target post-shuffle partition size: 2000000]
      +- *HashAggregate(keys=[dt#467, geohash_of_latlng#472], functions=[partial_sum(cast(mt_cnt#476 as bigint)), partial_sum(cast(ele_cnt#477 as bigint)), partial_sum(cast(mt_cnt#476 as bigint)), partial_sum(cast(mt_cnt_all#478 as bigint)), partial_sum(cast(ele_cnt#477 as bigint)), partial_sum(cast(ele_cnt_all#479 as bigint))])
         +- HiveTableScan [geohash_of_latlng#472, mt_cnt#476, ele_cnt#477, mt_cnt_all#478, ele_cnt_all#479, dt#467], MetastoreRelation temp, test_geohash_match_parquet
```
I don't know whether my case can be optimized or not.
@cenyuhai This is an optimization of the physical plan, and your case can be optimized:

```sql
select dt,
       geohash_of_latlng,
       sum(mt_cnt),
       sum(ele_cnt),
       round(sum(mt_cnt) * 1.0 * 100 / sum(mt_cnt_all), 2),
       round(sum(ele_cnt) * 1.0 * 100 / sum(ele_cnt_all), 2)
from values(1, 2, 3, 4, 5, 6) as (dt, geohash_of_latlng, mt_cnt, ele_cnt, mt_cnt_all, ele_cnt_all)
group by dt, geohash_of_latlng
order by dt, geohash_of_latlng limit 10
```

Before:

After:
```diff
 AggregateExpression(
   aggregateFunction,
   mode,
   isDistinct,
-  NamedExpression.newExprId)
+  ExprId(hashCode))
```
I don't think this is the right fix. Semantically, the `b0` and `b1` in `SELECT SUM(b) AS b0, SUM(b) AS b1` are different aggregate functions, so they should have different `resultId`s.

It's kind of an optimization in the aggregate planner: we should detect these semantically different but duplicated aggregate functions and only plan one of them.
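For reference, a hedged sketch of the kind of planner-side deduplication suggested here (my illustration, not the change that was ultimately merged), relying on catalyst's `semanticEquals`, which compares canonicalized expressions and thus ignores the fresh `resultId`:

```scala
// Sketch only: keep the first occurrence of each semantically-equal aggregate
// and plan just that copy.
val aggregateExpressions = resultExpressions
  .flatMap(_.collect { case a: AggregateExpression => a })
  .foldLeft(Seq.empty[AggregateExpression]) { (distinct, agg) =>
    // semanticEquals compares canonicalized forms, which zero out the fresh
    // resultId, so SUM(b) AS b0 and SUM(b) AS b1 fall into one group.
    if (distinct.exists(_.semanticEquals(agg))) distinct else distinct :+ agg
  }
```

Result expressions would then be rewritten to reference the output attributes of the surviving copy, as the existing comment in `patterns.scala` already describes.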
I agree with @cloud-fan. This should be an optimization done in the aggregate planner, instead of forcibly setting the expr id here.
@cloud-fan @viirya I've tried to optimize this in the aggregate planner (https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala#L211):
```scala
// A single aggregate expression might appear multiple times in resultExpressions.
// In order to avoid evaluating an individual aggregate function multiple times, we'll
// build a set of the distinct aggregate expressions and build a function which can
// be used to re-write expressions so that they reference the single copy of the
// aggregate function which actually gets computed.
val aggregateExpressions = resultExpressions.flatMap { expr =>
  expr.collect {
    case agg: AggregateExpression =>
      val aggregateFunction = agg.aggregateFunction
      val state = if (aggregateFunction.resolved) {
        Seq(aggregateFunction.toString, aggregateFunction.dataType,
          aggregateFunction.nullable, agg.mode, agg.isDistinct)
      } else {
        Seq(aggregateFunction.toString, agg.mode, agg.isDistinct)
      }
      val hashCode = state.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b)
      (hashCode, agg)
  }
}.groupBy(_._1).map { case (_, values) =>
  values.head._2
}.toSeq
```
But it's difficult to distinguish between different typed aggregators without an expr id. The current solution works well for all aggregate functions.

I'm not familiar with typed aggregators; any suggestions would be appreciated.
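To illustrate the typed-aggregator difficulty (a hypothetical example of mine, not from the PR): a user-defined `Aggregator` usually inherits `Object.toString`, which embeds an identity hash, so a string-based state is neither stable across instances nor indicative of semantics.

```scala
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Hypothetical factory: the captured `factor` changes the semantics but is
// invisible to the default toString, which prints "...$$anon$1@<identityHash>".
def scaledSum(factor: Long): Aggregator[Long, Long, Long] =
  new Aggregator[Long, Long, Long] {
    def zero: Long = 0L
    def reduce(acc: Long, x: Long): Long = acc + factor * x
    def merge(a: Long, b: Long): Long = a + b
    def finish(acc: Long): Long = acc
    def bufferEncoder: Encoder[Long] = Encoders.scalaLong
    def outputEncoder: Encoder[Long] = Encoders.scalaLong
  }

// scaledSum(1) and scaledSum(2) behave differently, yet no hash built from
// toString can reliably tell them apart (or recognize two uses of the same
// aggregator as equal) the way an expr-id-based identity can.
```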
Regarding the performance regression, I think you should post benchmark numbers.
@viirya The problem is already obvious: the same aggregate expression will be computed multiple times. I will provide a benchmark result later.
I asked because, with subexpression elimination, you may not actually run it multiple times. So the benchmark numbers can tell whether your fix really improves performance.
@viirya

```scala
val N = 500L << 22
val benchmark = new Benchmark("agg", N)
val expressions = (0 until 50).map(i => s"sum(id) as r$i")
benchmark.addCase("agg with optimize", numIters = 2) { iter =>
  sparkSession.range(N).selectExpr(expressions: _*).collect()
}
benchmark.run()
```

Result:
@stanzhai Thanks. I see. Because the aggregation functions are bound to individual buffer slots, they are recognized as different expressions and won't be eliminated.
Can one of the admins verify this patch?
I believe this has been fixed; can we close it?
What changes were proposed in this pull request?

This PR fixes a performance regression in the aggregation strategy that was introduced in Spark 2.0.

For the following SQL:

Before the fix:

After the fix:

How was this patch tested?

Existing tests. Added a new test case.
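A sketch of what such a regression test could look like (hypothetical: the test name, expected count, and use of `HashAggregateExec` are my assumptions in the style of Spark's SQL test suites, with a `spark` session in scope; the PR's actual test may differ):

```scala
import org.apache.spark.sql.execution.aggregate.HashAggregateExec

test("SPARK-22084: duplicated aggregate functions are planned only once") {
  val df = spark.range(10).selectExpr("sum(id) AS a", "sum(id) AS b")
  // Collect every aggregate expression carried by the planned HashAggregates.
  val plannedAggs = df.queryExecution.executedPlan.collect {
    case agg: HashAggregateExec => agg.aggregateExpressions
  }.flatten
  // One partial_sum plus one final sum, rather than two of each: the second
  // sum(id) should reuse the first one's computed result.
  assert(plannedAggs.length == 2)
}
```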