[SPARK-41513][SQL] Implement an accumulator to collect per mapper row count metrics #39057
Conversation
 *
 * @since 3.4.0
 */
class MapperRowCounter extends AccumulatorV2[jl.Long, java.util.List[java.util.List[jl.Long]]] {
We can put it in the sql module, or probably in the same file as the shuffle node.
Moved to sql module.
Can one of the admins verify this patch?
sql/core/src/main/scala/org/apache/spark/sql/util/MapperRowCounter.scala
def setPartitionId(id: Long): Unit = {
  this.synchronized {
    val p = id
what does this do?
hmm, I can remove this assignment and use id directly.
sql/core/src/main/scala/org/apache/spark/sql/util/MapperRowCounter.scala
 *
 * @since 3.4.0
 */
class MapperRowCounter extends AccumulatorV2[jl.Long, java.util.List[(jl.Long, jl.Long)]] {
TaskContext.partitionId is int, do we really need long here?
The first long here (aka IN) defines

/**
 * Takes the inputs and accumulates.
 */
def add(v: IN): Unit

So this should be a long for the row count?
row count should be long, map index should be int.
I think I misunderstood the type parameter. Changed to Integer for the map index.
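To make the agreed-upon types concrete (Integer for the map index, Long for the row count), here is a standalone Java sketch of the accumulator's semantics. It is a hypothetical stand-in, not the PR's actual Scala code, and it deliberately omits the Spark AccumulatorV2 base class; class and method names mirror the discussion (setPartitionId, add, merge, isZero).

```java
import java.util.AbstractMap.SimpleEntry;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Hypothetical stand-in for the PR's MapperRowCounter: each entry pairs an
// int map (partition) index with a long row count, mirroring
// AccumulatorV2[jl.Long, java.util.List[(jl.Integer, jl.Long)]].
class MapperRowCounterSketch {
    private final List<Map.Entry<Integer, Long>> counts = new ArrayList<>();

    boolean isZero() { return counts.isEmpty(); }

    // Called once per task, before any add, to register this mapper's id.
    synchronized void setPartitionId(int id) {
        if (!isZero()) throw new IllegalStateException("already initialized");
        counts.add(new SimpleEntry<>(id, 0L));
    }

    // Accumulates v rows into the task's single local entry.
    synchronized void add(long v) {
        if (counts.size() != 1)
            throw new IllegalStateException("agg must have been initialized");
        Map.Entry<Integer, Long> e = counts.get(0);
        e.setValue(e.getValue() + v);
    }

    // Driver-side merge: concatenates the per-mapper entries.
    synchronized void merge(MapperRowCounterSketch other) {
        counts.addAll(other.counts);
    }

    List<Map.Entry<Integer, Long>> value() { return counts; }
}
```

After merging the task-local accumulators on the driver, value() holds one (mapIndex, rowCount) pair per mapper.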
override def add(v: jl.Long): Unit = {
  this.synchronized {
    assert(!isZero, "agg must have been initialized")
assert(getOrCreate.size == 1)?
done
if (isZero) {
  getOrCreate.add((id, 0))
} else {
  val n = getOrCreate.get(0)._2
when can we hit this branch?
This is because I don't know the invocation sequence of the accumulator APIs on the executor side, so I added this branch to be safe. If setPartitionId is always called before any add, then we can assert on isZero and remove this branch. What do you think?
let's add an assert
done
thanks, merging to master!
What changes were proposed in this pull request?
In the current Spark optimizer, a single-partition shuffle might be created for a limit if the limit is not the last non-action operation (e.g. a filter follows the limit and the data size exceeds a threshold). The output partitions feeding into this limit may be sorted. The single-partition shuffle approach has a correctness bug in that case: shuffle read partitions can arrive out of partition order, and the limit exec simply takes the first limit rows, which loses the order and produces a wrong result. This shuffle is also relatively costly. Meanwhile, the naive way to correct the bug is to sort all the data fed into the limit again, which adds another overhead.
So we propose a row-count-based AQE algorithm that optimizes this problem in two ways:
Note that 1. is only applied in the sorted-partition case, while 2. applies to the general single-partition shuffle + limit case.
The algorithm works as follows:
This is the first step toward implementing the idea in https://issues.apache.org/jira/browse/SPARK-41512: a row count accumulator that will be used to collect per-mapper row count metrics.
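To illustrate the correctness bug described above, and how per-mapper row counts could be used against it, here is a small self-contained Java simulation. It is not Spark code: the data, the arrival order, and the limitInOrder helper are all hypothetical, sketching the general direction of the SPARK-41512 idea rather than the final algorithm.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class LimitOrderDemo {
    // Serves "global limit n" by walking map outputs in map-index order,
    // using per-mapper row counts to know how many rows each mapper holds.
    static List<Integer> limitInOrder(List<List<Integer>> mapOutputs,
                                      long[] rowCounts, int n) {
        List<Integer> out = new ArrayList<>();
        for (int p = 0; p < rowCounts.length && n > 0; p++) {
            int take = (int) Math.min(rowCounts[p], n);
            out.addAll(mapOutputs.get(p).subList(0, take));
            n -= take;
        }
        return out;
    }

    public static void main(String[] args) {
        // Sorted output of three mappers; global order is mapper 0, 1, 2.
        List<List<Integer>> mapOutputs = Arrays.asList(
            Arrays.asList(1, 2), Arrays.asList(3, 4), Arrays.asList(5, 6));

        // A single-partition shuffle read may deliver blocks in any order.
        List<Integer> shuffled = new ArrayList<>();
        for (int p : new int[]{2, 0, 1}) shuffled.addAll(mapOutputs.get(p));

        // Naive limit 3 over the shuffled stream loses the sort order.
        System.out.println(shuffled.subList(0, 3));
        // A row-count-aware limit preserves it, with no re-sort.
        System.out.println(limitInOrder(mapOutputs, new long[]{2, 2, 2}, 3));
    }
}
```

With arrival order {2, 0, 1}, the naive take yields [5, 6, 1], while the row-count-aware walk yields [1, 2, 3].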
Why are the changes needed?
It is the first piece of an optimization algorithm for global limit with a single-partition shuffle.
Does this PR introduce any user-facing change?
NO
How was this patch tested?
UT