
RoPE enhancements #1478

Merged
sudhakarsingh27 merged 28 commits into NVIDIA:main from sudhakarsingh27:rope_enhancement
Apr 22, 2025

Conversation

@sudhakarsingh27 (Collaborator) commented Feb 11, 2025

Description

TL;DR: Enable staggered application of RoPE embeddings to different sequences within the same batch.

During generation, different sequences in a batch may have different start positions (technically different end positions as well, but those are bounded by the maximum sequence length in the batch, so we can afford to ignore them for now). This change modifies the RoPE kernel to apply the embeddings at a per-sequence offset, controlled by a new start_positions argument.

(The start_positions and related changes are directly adapted from #829 which was authored by @pggPL)

  1. start_positions is only intended for generation/inference and works with the sbhd/bshd/thd input tensor formats.
  2. start_positions is not intended for context-parallel (CP) use cases, as CP is not used during inference/generation. Supporting CP should be possible, but it is out of scope for this PR.
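The staggered application described above can be sketched as a plain NumPy reference implementation (hypothetical, for illustration only; the PR's actual change is a fused CUDA kernel, and this sketch assumes a bshd layout and the half-split rotation convention, which may differ from TE's internals):

```python
import numpy as np

def rope_staggered(x, start_positions, base=10000.0):
    """Apply RoPE to x of shape [b, s, h, d], offsetting each
    sequence's positions by its entry in start_positions."""
    b, s, h, d = x.shape
    half = d // 2
    inv_freq = 1.0 / (base ** (np.arange(half) / half))
    out = np.empty_like(x)
    for i in range(b):
        # Each sequence starts at its own offset instead of position 0.
        pos = start_positions[i] + np.arange(s)
        ang = np.outer(pos, inv_freq)           # [s, half]
        cos = np.cos(ang)[:, None, :]           # broadcast over heads
        sin = np.sin(ang)[:, None, :]
        x1, x2 = x[i, ..., :half], x[i, ..., half:]
        out[i, ..., :half] = x1 * cos - x2 * sin
        out[i, ..., half:] = x1 * sin + x2 * cos
    return out
```

The key property is that applying RoPE with start position p to a suffix of a sequence matches applying it to the full sequence from position 0, which is exactly what incremental generation needs.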

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Non-breaking change to the apply_rotary_pos_emb function, since start_positions is added as a keyword argument with a default value.
  • Breaking changes to FusedRoPEFunc and all the extensions/kernels that it calls internally.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…make staggered rope application faster

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27 sudhakarsingh27 changed the title RoPE functionality enhancements RoPE enhancements Feb 11, 2025
@sudhakarsingh27 sudhakarsingh27 self-assigned this Feb 11, 2025
@cyanguwa cyanguwa added the 2.3.0 label Mar 12, 2025
@cyanguwa cyanguwa requested a review from yaox12 March 31, 2025 18:35
@yaox12 (Member) left a comment

Generally LGTM. With #1626, the fused and unfused versions have the same support matrix for RoPE options, so it would be better to merge the unfused implementation into apply_rotary_pos_emb in rope.py.

@cyanguwa (Collaborator) commented Apr 11, 2025

I agree with @yaox12's comments. I think we need to add documentation about our support matrix for start_positions and how to use it (None, or freqs of shape [s, b, 1, d]). We should also expand support to all three qkv_formats and to both CP and non-CP cases.
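One way to read the [s, b, 1, d] freqs shape mentioned above is as a per-sequence angle table: because the batch dimension is explicit, each batch element can carry its own position offset baked into its frequencies. A hypothetical NumPy construction of such a tensor (TE's actual freqs convention may differ; the concatenated-halves layout here is an assumption):

```python
import numpy as np

def build_staggered_freqs(seq_len, start_positions, dim, base=10000.0):
    """Build a freqs tensor of shape [s, b, 1, d] where each batch
    element's angles are offset by its start position."""
    half = dim // 2
    inv_freq = 1.0 / (base ** (np.arange(half) / half))
    b = len(start_positions)
    freqs = np.empty((seq_len, b, 1, dim))
    for i, p0 in enumerate(start_positions):
        pos = p0 + np.arange(seq_len)
        ang = np.outer(pos, inv_freq)  # [s, half]
        # Duplicate the half-dim angles to fill the full head dim.
        freqs[:, i, 0, :] = np.concatenate([ang, ang], axis=-1)
    return freqs
```

With this layout, row t of sequence i holds the angles for absolute position start_positions[i] + t, so a kernel consuming it needs no separate offset argument.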

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
pre-commit-ci bot and others added 7 commits April 16, 2025 20:41
@sudhakarsingh27 sudhakarsingh27 requested a review from yaox12 April 18, 2025 07:36
@sudhakarsingh27 (Collaborator, Author) commented:

/te-ci pytorch

sudhakarsingh27 and others added 7 commits April 20, 2025 16:36
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
sudhakarsingh27 and others added 4 commits April 21, 2025 13:14
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27 (Collaborator, Author) commented:

/te-ci pytorch

yaox12 previously approved these changes Apr 22, 2025

@yaox12 (Member) left a comment

LGTM.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27 sudhakarsingh27 merged commit 94bff09 into NVIDIA:main Apr 22, 2025
11 checks passed
KshitijLakhani pushed a commit that referenced this pull request Apr 23, 2025
* add support for `sb1d` freqs tensor in Fused RoPE

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* add `start_positions` variable to `apply_rotary_pos_emb` function to make staggered rope application faster

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add pytorch path for `start_positions` and corresponding tests

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add tests for start_positions with thd

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes from feedback

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove start_positions from backward pass

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* from feedback

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make notes shorter

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

---------

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>