-
Notifications
You must be signed in to change notification settings - Fork 28k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-38679][CORE] Expose the number partitions in a stage to TaskContext #35995
Conversation
c86b83e
to
ecf7112
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
It looks like barrier execution can also use this api to simplify the implementation of |
7354419
to
b48a670
Compare
Thank you @jiangxb1987 for reviewing. Could you please take another look at the change? Had to exclude the new method from binary compatibility test. |
@cloud-fan @Ngone51 Although this one adds a new API, it's a pretty straightforward change. It looks pretty safe to me to backport into 3.3. What do you think? Also cc @MaxGekk |
Since this is either not a bug fix nor in the allow list https://lists.apache.org/thread/zrd7lcm5f5f3md7wffjy7x6w2pdmxxp7, we cannot just silently merge to branch-3.3. @zsxwing @vkorukanti Could you write an email to the thread in the dev list, and explain why we need this in 3.3 and cannot postpone to 3.4. |
@MaxGekk Thanks for the feedback. We will not merge this to 3.3. |
@@ -215,6 +215,8 @@ class BarrierTaskContext private[spark] ( | |||
|
|||
override def partitionId(): Int = taskContext.partitionId() | |||
|
|||
override def numPartitions(): Int = taskContext.numPartitions() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can remove lazy val numTasks
in this file and use numPartitions()
directly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated in the latest commit.
Add a new api to expose total partition count in a task. so that the task knows what fraction of the computation is doing. With this extra information, users can generate 32bit unique int ids as below rather than using `monotonically_increasing_id` which generates 64bit long ids. ```scala val rdd = ... rdd.mapPartitions { rowsIter => val partitionId = TaskContext.get().partitionId() val numPartitions = TaskContext.get().numPartitions() var i = 0 rowsIter.map { row => val rowId = partitionId + i * numPartitions i += 1 (rowId, row) } } ``` Test: Added new unit tests to verify the number of partitions retrieved from TaskContext is expected.
b48a670
to
c975242
Compare
thanks, merging to master! |
What changes were proposed in this pull request?
Add a new api to expose total partition count in the stage belonging to the task in TaskContext,
Why are the changes needed?
Add a new api to expose total partition count in the stage belonging to the task in TaskContext, so that the task knows what fraction of the computation is doing.
With this extra information, users can generate 32bit unique int ids as below rather than using
monotonically_increasing_id
which generates 64bit long ids.Does this PR introduce any user-facing change?
Yes. We add a new API
TaskContext.numPartitions
.How was this patch tested?
Added new unit tests to verify the number of partitions retrieved from TaskContext is expected.