[SPARK-22084][SQL] Fix performance regression in aggregation strategy #19301

Closed
wants to merge 7 commits

Conversation

stanzhai
Contributor

@stanzhai stanzhai commented Sep 21, 2017

What changes were proposed in this pull request?

This PR fixes a performance regression in the aggregation strategy that was introduced in Spark 2.0.

For the following SQL:

SELECT a, SUM(b) AS b0, SUM(b) AS b1 
FROM VALUES(1, 1), (2, 2) AS (a, b) 
GROUP BY a

Before the fix:

== Physical Plan ==
*HashAggregate(keys=[a#11], functions=[sum(cast(b#12 as bigint)), sum(cast(b#12 as bigint))])
+- Exchange hashpartitioning(a#11, 200)
   +- *HashAggregate(keys=[a#11], functions=[partial_sum(cast(b#12 as bigint)), partial_sum(cast(b#12 as bigint))])
      +- LocalTableScan [a#11, b#12]

After the fix:

== Physical Plan ==
*HashAggregate(keys=[a#11], functions=[sum(cast(b#12 as bigint))])
+- Exchange hashpartitioning(a#11, 2)
   +- *HashAggregate(keys=[a#11], functions=[partial_sum(cast(b#12 as bigint))])
      +- LocalTableScan [a#11, b#12]

How was this patch tested?

Existing tests.
Added a new test case.
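
For anyone wanting to reproduce this locally, a minimal sketch (the app name, the `local[*]` master, and the inline-table alias `t` are illustrative choices, not part of this PR):

// Minimal reproduction sketch; before the fix, the printed physical plan
// contains partial_sum(cast(b as bigint)) twice.
import org.apache.spark.sql.SparkSession

object Spark22084Repro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("SPARK-22084 repro")
      .getOrCreate()

    spark.sql(
      """SELECT a, SUM(b) AS b0, SUM(b) AS b1
        |FROM VALUES (1, 1), (2, 2) AS t(a, b)
        |GROUP BY a""".stripMargin).explain()

    spark.stop()
  }
}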

@@ -38,7 +38,7 @@ import org.apache.spark.sql.internal.SQLConf
* view resolution, in this way, we are able to get the correct view column ordering and
* omit the extra columns that we don't require);
* 1.2. Else set the child output attributes to `queryOutput`.
- * 2. Map the `queryQutput` to view output by index, if the corresponding attributes don't match,
+ * 2. Map the `queryOutput` to view output by index, if the corresponding attributes don't match,
Contributor

They look the same?

Contributor Author

Q -> O

      } else {
        Seq(aggregateFunction.toString, mode, isDistinct)
      }
      val hashCode = state.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b)
Contributor

What's the purpose here?

@cloud-fan
Contributor

Can you explain more about how this bug happens?

@stanzhai
Contributor Author

stanzhai commented Sep 21, 2017

@cloud-fan

https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala#L211

      val aggregateExpressions = resultExpressions.flatMap { expr =>
        expr.collect {
          case agg: AggregateExpression => agg
        }
      }.distinct

Before the fix, the exprId of each aggregate expression is different, which causes distinct to fail (no duplicates are removed).
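
To illustrate the point, a minimal plain-Scala sketch (the `FakeAgg` class and its fields are hypothetical stand-ins, not Catalyst's AggregateExpression): equality includes the per-instance resultId, so two occurrences of SUM(b) never compare equal and distinct removes nothing.

// Plain-Scala sketch, not Catalyst code: a stand-in aggregate whose equality
// includes a per-instance resultId, mirroring how each SUM(b) in the select
// list gets a fresh id.
case class FakeAgg(function: String, mode: String, isDistinct: Boolean, resultId: Long)

object DistinctSketch extends App {
  private var counter = 0L
  private def newId(): Long = { counter += 1; counter }

  // Two occurrences of SUM(b), as produced by `SUM(b) AS b0, SUM(b) AS b1`.
  val aggs = Seq(
    FakeAgg("sum(b)", "Complete", isDistinct = false, newId()),
    FakeAgg("sum(b)", "Complete", isDistinct = false, newId()))

  println(aggs.distinct.size)                             // 2: the resultIds differ, nothing is deduplicated
  println(aggs.map(_.copy(resultId = 0L)).distinct.size)  // 1: ignoring resultId collapses the duplicates
}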

@cenyuhai
Contributor

My case:

select dt,
geohash_of_latlng,
sum(mt_cnt),
sum(ele_cnt),
round(sum(mt_cnt) * 1.0 * 100 / sum(mt_cnt_all), 2),
round(sum(ele_cnt) * 1.0 * 100 / sum(ele_cnt_all), 2)
from temp.test_geohash_match_parquet
group by dt, geohash_of_latlng
order by dt, geohash_of_latlng limit 10;

Before your fix:

TakeOrderedAndProject(limit=10, orderBy=[dt#502 ASC NULLS FIRST,geohash_of_latlng#507 ASC NULLS FIRST], output=[dt#502,geohash_of_latlng#507,sum(mt_cnt)#521L,sum(ele_cnt)#522L,round((CAST((CAST((CAST(CAST(sum(CAST(mt_cnt AS BIGINT)) AS DECIMAL(20,0)) AS DECIMAL(21,1)) * CAST(1.0 AS DECIMAL(21,1))) AS DECIMAL(23,1)) * CAST(CAST(100 AS DECIMAL(23,1)) AS DECIMAL(23,1))) AS DECIMAL(38,2)) / CAST(CAST(sum(CAST(mt_cnt_all AS BIGINT)) AS DECIMAL(20,0)) AS DECIMAL(38,2))), 2)#523,round((CAST((CAST((CAST(CAST(sum(CAST(ele_cnt AS BIGINT)) AS DECIMAL(20,0)) AS DECIMAL(21,1)) * CAST(1.0 AS DECIMAL(21,1))) AS DECIMAL(23,1)) * CAST(CAST(100 AS DECIMAL(23,1)) AS DECIMAL(23,1))) AS DECIMAL(38,2)) / CAST(CAST(sum(CAST(ele_cnt_all AS BIGINT)) AS DECIMAL(20,0)) AS DECIMAL(38,2))), 2)#524])
+- *HashAggregate(keys=[dt#502, geohash_of_latlng#507], functions=[sum(cast(mt_cnt#511 as bigint)), sum(cast(ele_cnt#512 as bigint)), sum(cast(mt_cnt#511 as bigint)), sum(cast(mt_cnt_all#513 as bigint)), sum(cast(ele_cnt#512 as bigint)), sum(cast(ele_cnt_all#514 as bigint))])
   +- Exchange(coordinator id: 148401229) hashpartitioning(dt#502, geohash_of_latlng#507, 1000), coordinator[target post-shuffle partition size: 2000000]
      +- *HashAggregate(keys=[dt#502, geohash_of_latlng#507], functions=[partial_sum(cast(mt_cnt#511 as bigint)), partial_sum(cast(ele_cnt#512 as bigint)), partial_sum(cast(mt_cnt#511 as bigint)), partial_sum(cast(mt_cnt_all#513 as bigint)), partial_sum(cast(ele_cnt#512 as bigint)), partial_sum(cast(ele_cnt_all#514 as bigint))])
         +- HiveTableScan [geohash_of_latlng#507, mt_cnt#511, ele_cnt#512, mt_cnt_all#513, ele_cnt_all#514, dt#502], MetastoreRelation temp, test_geohash_match_parquet

After your fix:

TakeOrderedAndProject(limit=10, orderBy=[dt#467 ASC NULLS FIRST,geohash_of_latlng#472 ASC NULLS FIRST], output=[dt#467,geohash_of_latlng#472,sum(mt_cnt)#486L,sum(ele_cnt)#487L,round((CAST((CAST((CAST(CAST(sum(CAST(mt_cnt AS BIGINT)) AS DECIMAL(20,0)) AS DECIMAL(21,1)) * CAST(1.0 AS DECIMAL(21,1))) AS DECIMAL(23,1)) * CAST(CAST(100 AS DECIMAL(23,1)) AS DECIMAL(23,1))) AS DECIMAL(38,2)) / CAST(CAST(sum(CAST(mt_cnt_all AS BIGINT)) AS DECIMAL(20,0)) AS DECIMAL(38,2))), 2)#488,round((CAST((CAST((CAST(CAST(sum(CAST(ele_cnt AS BIGINT)) AS DECIMAL(20,0)) AS DECIMAL(21,1)) * CAST(1.0 AS DECIMAL(21,1))) AS DECIMAL(23,1)) * CAST(CAST(100 AS DECIMAL(23,1)) AS DECIMAL(23,1))) AS DECIMAL(38,2)) / CAST(CAST(sum(CAST(ele_cnt_all AS BIGINT)) AS DECIMAL(20,0)) AS DECIMAL(38,2))), 2)#489])
+- *HashAggregate(keys=[dt#467, geohash_of_latlng#472], functions=[sum(cast(mt_cnt#476 as bigint)), sum(cast(ele_cnt#477 as bigint)), sum(cast(mt_cnt#476 as bigint)), sum(cast(mt_cnt_all#478 as bigint)), sum(cast(ele_cnt#477 as bigint)), sum(cast(ele_cnt_all#479 as bigint))])
   +- Exchange(coordinator id: 227998366) hashpartitioning(dt#467, geohash_of_latlng#472, 1000), coordinator[target post-shuffle partition size: 2000000]
      +- *HashAggregate(keys=[dt#467, geohash_of_latlng#472], functions=[partial_sum(cast(mt_cnt#476 as bigint)), partial_sum(cast(ele_cnt#477 as bigint)), partial_sum(cast(mt_cnt#476 as bigint)), partial_sum(cast(mt_cnt_all#478 as bigint)), partial_sum(cast(ele_cnt#477 as bigint)), partial_sum(cast(ele_cnt_all#479 as bigint))])
         +- HiveTableScan [geohash_of_latlng#472, mt_cnt#476, ele_cnt#477, mt_cnt_all#478, ele_cnt_all#479, dt#467], MetastoreRelation temp, test_geohash_match_parquet

@cenyuhai
Contributor

I don't know whether my case can be optimized or not.

@cenyuhai
Contributor

Should sum(mt_cnt) and sum(ele_cnt) be computed again?

@stanzhai
Contributor Author

@cenyuhai This is an optimization of the physical plan, and your case can be optimized:

select dt,
geohash_of_latlng,
sum(mt_cnt),
sum(ele_cnt),
round(sum(mt_cnt) * 1.0 * 100 / sum(mt_cnt_all), 2),
round(sum(ele_cnt) * 1.0 * 100 / sum(ele_cnt_all), 2)
from values(1, 2, 3, 4, 5, 6) as (dt, geohash_of_latlng, mt_cnt, ele_cnt, mt_cnt_all, ele_cnt_all)
group by dt, geohash_of_latlng
order by dt, geohash_of_latlng limit 10

Before:

== Physical Plan ==
TakeOrderedAndProject(limit=10, orderBy=[dt#26 ASC NULLS FIRST,geohash_of_latlng#27 ASC NULLS FIRST], output=[dt#26,geohash_of_latlng#27,sum(mt_cnt)#38L,sum(ele_cnt)#39L,round((CAST((CAST((CAST(CAST(sum(CAST(mt_cnt AS BIGINT)) AS DECIMAL(20,0)) AS DECIMAL(21,1)) * CAST(1.0 AS DECIMAL(21,1))) AS DECIMAL(23,1)) * CAST(CAST(100 AS DECIMAL(23,1)) AS DECIMAL(23,1))) AS DECIMAL(38,2)) / CAST(CAST(sum(CAST(mt_cnt_all AS BIGINT)) AS DECIMAL(20,0)) AS DECIMAL(38,2))), 2)#40,round((CAST((CAST((CAST(CAST(sum(CAST(ele_cnt AS BIGINT)) AS DECIMAL(20,0)) AS DECIMAL(21,1)) * CAST(1.0 AS DECIMAL(21,1))) AS DECIMAL(23,1)) * CAST(CAST(100 AS DECIMAL(23,1)) AS DECIMAL(23,1))) AS DECIMAL(38,2)) / CAST(CAST(sum(CAST(ele_cnt_all AS BIGINT)) AS DECIMAL(20,0)) AS DECIMAL(38,2))), 2)#41])
+- *HashAggregate(keys=[dt#26, geohash_of_latlng#27], functions=[sum(cast(mt_cnt#28 as bigint)), sum(cast(ele_cnt#29 as bigint)), sum(cast(mt_cnt#28 as bigint)), sum(cast(mt_cnt_all#30 as bigint)), sum(cast(ele_cnt#29 as bigint)), sum(cast(ele_cnt_all#31 as bigint))])
   +- Exchange hashpartitioning(dt#26, geohash_of_latlng#27, 200)
      +- *HashAggregate(keys=[dt#26, geohash_of_latlng#27], functions=[partial_sum(cast(mt_cnt#28 as bigint)), partial_sum(cast(ele_cnt#29 as bigint)), partial_sum(cast(mt_cnt#28 as bigint)), partial_sum(cast(mt_cnt_all#30 as bigint)), partial_sum(cast(ele_cnt#29 as bigint)), partial_sum(cast(ele_cnt_all#31 as bigint))])
         +- LocalTableScan [dt#26, geohash_of_latlng#27, mt_cnt#28, ele_cnt#29, mt_cnt_all#30, ele_cnt_all#31]

After:

== Physical Plan ==
TakeOrderedAndProject(limit=10, orderBy=[dt#28 ASC NULLS FIRST,geohash_of_latlng#29 ASC NULLS FIRST], output=[dt#28,geohash_of_latlng#29,sum(mt_cnt)#34L,sum(ele_cnt)#35L,round((CAST((CAST((CAST(CAST(sum(CAST(mt_cnt AS BIGINT)) AS DECIMAL(20,0)) AS DECIMAL(21,1)) * CAST(1.0 AS DECIMAL(21,1))) AS DECIMAL(23,1)) * CAST(CAST(100 AS DECIMAL(23,1)) AS DECIMAL(23,1))) AS DECIMAL(38,2)) / CAST(CAST(sum(CAST(mt_cnt_all AS BIGINT)) AS DECIMAL(20,0)) AS DECIMAL(38,2))), 2)#36,round((CAST((CAST((CAST(CAST(sum(CAST(ele_cnt AS BIGINT)) AS DECIMAL(20,0)) AS DECIMAL(21,1)) * CAST(1.0 AS DECIMAL(21,1))) AS DECIMAL(23,1)) * CAST(CAST(100 AS DECIMAL(23,1)) AS DECIMAL(23,1))) AS DECIMAL(38,2)) / CAST(CAST(sum(CAST(ele_cnt_all AS BIGINT)) AS DECIMAL(20,0)) AS DECIMAL(38,2))), 2)#37])
+- *HashAggregate(keys=[dt#28, geohash_of_latlng#29], functions=[sum(cast(mt_cnt#30 as bigint)), sum(cast(ele_cnt#31 as bigint)), sum(cast(mt_cnt_all#32 as bigint)), sum(cast(ele_cnt_all#33 as bigint))])
   +- Exchange hashpartitioning(dt#28, geohash_of_latlng#29, 200)
      +- *HashAggregate(keys=[dt#28, geohash_of_latlng#29], functions=[partial_sum(cast(mt_cnt#30 as bigint)), partial_sum(cast(ele_cnt#31 as bigint)), partial_sum(cast(mt_cnt_all#32 as bigint)), partial_sum(cast(ele_cnt_all#33 as bigint))])
         +- LocalTableScan [dt#28, geohash_of_latlng#29, mt_cnt#30, ele_cnt#31, mt_cnt_all#32, ele_cnt_all#33]

  AggregateExpression(
    aggregateFunction,
    mode,
    isDistinct,
-   NamedExpression.newExprId)
+   ExprId(hashCode))
Contributor

@cloud-fan cloud-fan Sep 22, 2017

I don't think this is the right fix. Semantically the b0 and b1 in SELECT SUM(b) AS b0, SUM(b) AS b1 are different aggregate functions, so they should have different resultIds.

It's kind of an optimization in the aggregate planner: we should detect these semantically different but duplicated aggregate functions and only plan one of them.

Member

I agree with @cloud-fan. This should be an optimization done in the aggregate planner, instead of forcibly setting the expr id here.

Contributor Author

@stanzhai stanzhai Sep 25, 2017

@cloud-fan @viirya I've tried to optimize this in the aggregate planner (https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala#L211):

      // A single aggregate expression might appear multiple times in resultExpressions.
      // In order to avoid evaluating an individual aggregate function multiple times, we'll
      // build a set of the distinct aggregate expressions and build a function which can
      // be used to re-write expressions so that they reference the single copy of the
      // aggregate function which actually gets computed.
      val aggregateExpressions = resultExpressions.flatMap { expr =>
        expr.collect {
          case agg: AggregateExpression =>
            val aggregateFunction = agg.aggregateFunction
            val state = if (aggregateFunction.resolved) {
              Seq(aggregateFunction.toString, aggregateFunction.dataType,
                aggregateFunction.nullable, agg.mode, agg.isDistinct)
            } else {
              Seq(aggregateFunction.toString, agg.mode, agg.isDistinct)
            }
            val hashCode = state.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b)
            (hashCode, agg)
        }
      }.groupBy(_._1).map { case (_, values) =>
        values.head._2
      }.toSeq

But it's difficult to distinguish between different typed aggregators without the expr id. The current solution works well for all aggregate functions.

I'm not familiar with typed aggregators; any suggestions would be appreciated.
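
As a rough illustration of the planner-side approach, a plain-Scala sketch (the `Agg` class is a hypothetical stand-in, not Catalyst code) that keys aggregates on their structure instead of a raw hashCode, keeps one copy per key, and records which resultId each original occurrence should be rewritten to:

// Plain-Scala sketch of structural deduplication; Agg and its fields are
// hypothetical stand-ins for Catalyst's AggregateExpression.
case class Agg(function: String, isDistinct: Boolean, resultId: Long)

object PlannerDedupSketch extends App {
  val aggs = Seq(
    Agg("sum(b)", isDistinct = false, resultId = 1L),
    Agg("sum(b)", isDistinct = false, resultId = 2L), // semantically the same as the first
    Agg("sum(c)", isDistinct = false, resultId = 3L))

  // Structural key: everything except the per-instance resultId.
  def key(a: Agg): (String, Boolean) = (a.function, a.isDistinct)

  // One canonical aggregate per key (keep the first occurrence).
  val canonical: Map[(String, Boolean), Agg] =
    aggs.groupBy(key).map { case (k, occurrences) => k -> occurrences.head }

  // Rewrite table: every original resultId points at the id that is actually planned.
  val rewrite: Map[Long, Long] =
    aggs.map(a => a.resultId -> canonical(key(a)).resultId).toMap

  println(canonical.values.map(_.resultId).toSeq.sorted) // List(1, 3): only two aggregates are planned
  println(rewrite)                                       // Map(1 -> 1, 2 -> 1, 3 -> 3)
}

Keying on the full structure rather than a combined hashCode avoids merging two genuinely different aggregates through an accidental hash collision.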

@viirya
Member

viirya commented Sep 22, 2017

Regarding the performance regression, I think you should post benchmark numbers.

@stanzhai
Contributor Author

@viirya The problem is already obvious: the same aggregate expression will be computed multiple times. I will provide a benchmark result later.

@viirya
Member

viirya commented Sep 22, 2017

I asked because, with subexpression elimination, you may not actually run it multiple times. So the benchmark numbers can tell whether your fix really improves performance.

@stanzhai
Contributor Author

@viirya
Benchmark code:

// Benchmark helper from Spark's test sources (org.apache.spark.util.Benchmark in the 2.x line);
// `sparkSession` is an existing SparkSession. Both cases run the same 50 identical sum(id)
// aggregates; only the build under test differs between the two reported rows.
import org.apache.spark.util.Benchmark

val N = 500L << 22
val benchmark = new Benchmark("agg", N)
val expressions = (0 until 50).map(i => s"sum(id) as r$i")

benchmark.addCase("agg with optimize", numIters = 2) { iter =>
  sparkSession.range(N).selectExpr(expressions: _*).collect()
}

benchmark.run()

Result:

Java HotSpot(TM) 64-Bit Server VM 1.8.0_91-b14 on Mac OS X 10.12.6
Intel(R) Core(TM) i5-4278U CPU @ 2.60GHz

agg:                                     Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
agg with optimize                             1306 / 1354       1605.7           0.6       1.0X
agg without optimize                      121799 / 148115         17.2          58.1       1.0X

@viirya
Member

viirya commented Sep 22, 2017

@stanzhai Thanks. I see. Because the aggregation functions are bound to individual buffer slots, they are recognized as different expressions and won't be eliminated.

@AmplabJenkins

Can one of the admins verify this patch?

@cloud-fan
Contributor

I believe this has been fixed; can we close it?

@stanzhai stanzhai closed this Jan 19, 2018