Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-38679][CORE] Expose the number partitions in a stage to TaskContext #35995

Closed
wants to merge 5 commits into from

Conversation

vkorukanti
Copy link
Member

@vkorukanti vkorukanti commented Mar 29, 2022

What changes were proposed in this pull request?

Add a new api to expose total partition count in the stage belonging to the task in TaskContext,

Why are the changes needed?

Add a new api to expose total partition count in the stage belonging to the task in TaskContext, so that the task knows what fraction of the computation is doing.

With this extra information, users can generate 32bit unique int ids as below rather than using monotonically_increasing_id which generates 64bit long ids.

rdd.mapPartitions { rowsIter =>
  val partitionId = TaskContext.get().partitionId()
  val numPartitions = TaskContext.get().numPartitions()
  var i = 0
  rowsIter.map { row =>
    val rowId = partitionId + i * numPartitions
    i += 1
    (rowId, row)
  }
}

Does this PR introduce any user-facing change?

Yes. We add a new API TaskContext.numPartitions.

How was this patch tested?

Added new unit tests to verify the number of partitions retrieved from TaskContext is expected.

@HeartSaVioR HeartSaVioR changed the title [SPARK-38679] Expose the number partitions in a stage to TaskContext [SPARK-38679][CORE] Expose the number partitions in a stage to TaskContext Mar 29, 2022
Copy link
Contributor

@jiangxb1987 jiangxb1987 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@jiangxb1987
Copy link
Contributor

It looks like barrier execution can also use this api to simplify the implementation of numTasks : )

@vkorukanti
Copy link
Member Author

Thank you @jiangxb1987 for reviewing. Could you please take another look at the change? Had to exclude the new method from binary compatibility test.

@zsxwing
Copy link
Member

zsxwing commented Mar 29, 2022

@cloud-fan @Ngone51 Although this one adds a new API, it's a pretty straightforward change. It looks pretty safe to me to backport into 3.3. What do you think? Also cc @MaxGekk

@MaxGekk
Copy link
Member

MaxGekk commented Mar 29, 2022

What do you think? Also cc @MaxGekk

Since this is either not a bug fix nor in the allow list https://lists.apache.org/thread/zrd7lcm5f5f3md7wffjy7x6w2pdmxxp7, we cannot just silently merge to branch-3.3. @zsxwing @vkorukanti Could you write an email to the thread in the dev list, and explain why we need this in 3.3 and cannot postpone to 3.4.

@zsxwing
Copy link
Member

zsxwing commented Mar 29, 2022

@MaxGekk Thanks for the feedback. We will not merge this to 3.3.

@@ -215,6 +215,8 @@ class BarrierTaskContext private[spark] (

override def partitionId(): Int = taskContext.partitionId()

override def numPartitions(): Int = taskContext.numPartitions()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can remove lazy val numTasks in this file and use numPartitions() directly.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated in the latest commit.

Add a new api to expose total partition count in a task. so that the task knows what fraction of the computation is doing.

With this extra information, users can generate 32bit unique int ids as below rather than using `monotonically_increasing_id` which generates 64bit long ids.

```scala
val rdd = ...
rdd.mapPartitions { rowsIter =>
  val partitionId = TaskContext.get().partitionId()
  val numPartitions = TaskContext.get().numPartitions()
  var i = 0
  rowsIter.map { row =>
    val rowId = partitionId + i * numPartitions
    i += 1
    (rowId, row)
  }
}
```

Test: Added new unit tests to verify the number of partitions retrieved from TaskContext is expected.
@cloud-fan
Copy link
Contributor

thanks, merging to master!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
8 participants