[PyTorch][Common] Refactor RoPE#1626

Merged
yaox12 merged 8 commits into NVIDIA:main from yaox12:xiny/refactor_rope
Apr 7, 2025

Conversation


@yaox12 yaox12 commented Mar 31, 2025

Description

Refactor RoPE so that:

  1. Both fused and non-fused RoPE now support CP with the SBHD/BSHD/THD formats.
  2. When CP > 1, the freqs tensor is required to be the full (unsliced) tensor.
    1. Previously, the freqs tensor was expected to be the full tensor for the THD format, but pre-sliced (outside of TE) for the other formats (sbhd/bshd). This PR unifies them.
    2. This is not a breaking change: the previous apply_rotary_pos_emb did not accept cp_size and cp_rank for sbhd and bshd, so passing a pre-sliced freqs tensor with cp_size = 1 and cp_rank = 0 still works as before.
  3. Both fused and non-fused RoPE support interleaved mode.
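To illustrate point 2, here is a minimal pure-Python sketch of how a full freqs tensor can be sliced for one context-parallel rank, assuming the load-balanced CP layout in which the sequence is split into 2 * cp_size chunks and rank r holds chunks r and 2*cp_size - 1 - r. The helper name `slice_freqs_for_cp_rank` and the use of plain lists are illustrative only; TE performs the equivalent slicing internally on tensors.

```python
def slice_freqs_for_cp_rank(freqs, cp_size, cp_rank):
    """Slice a full-sequence freqs list for one context-parallel rank.

    Assumes the load-balanced CP layout: the sequence is split into
    2 * cp_size chunks and rank r holds chunks r and 2*cp_size - 1 - r.
    Illustrative helper only, not TE's API.
    """
    total = len(freqs)
    assert total % (2 * cp_size) == 0, "sequence length must divide evenly"
    chunk = total // (2 * cp_size)
    lo = freqs[cp_rank * chunk:(cp_rank + 1) * chunk]
    hi_start = (2 * cp_size - 1 - cp_rank) * chunk
    hi = freqs[hi_start:hi_start + chunk]
    return lo + hi

# With cp_size = 1 and cp_rank = 0 the "slice" is the full tensor, which is
# why passing pre-sliced freqs with those defaults keeps the old behavior.
```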

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

yaox12 added 2 commits March 30, 2025 18:16
Signed-off-by: Xin Yao <xiny@nvidia.com>
Signed-off-by: Xin Yao <xiny@nvidia.com>
@yaox12 yaox12 force-pushed the xiny/refactor_rope branch from 2525377 to d727f1c on March 31, 2025 01:17
Signed-off-by: Xin Yao <xiny@nvidia.com>
@yaox12 yaox12 requested review from cyanguwa and xrennvidia March 31, 2025 02:20

tomlifu commented Mar 31, 2025

Rotary interleaved part looks good to me.
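For readers unfamiliar with the two pairings, here is a minimal pure-Python sketch of the difference (the helper `rope_rotate` is illustrative only, not TE's fused kernel): the default layout rotates element i against element i + d/2, while interleaved mode rotates adjacent pairs (2i, 2i + 1).

```python
import math

def rope_rotate(x, pos, base=10000.0, interleaved=False):
    """Apply rotary position embedding to one head vector x (even length d).

    interleaved=False pairs element i with i + d/2 (the half-rotation layout);
    interleaved=True pairs adjacent elements (2i, 2i + 1).
    Illustrative helper, not TE's API.
    """
    d = len(x)
    half = d // 2
    out = [0.0] * d
    for i in range(half):
        theta = pos / (base ** (2 * i / d))
        c, s = math.cos(theta), math.sin(theta)
        if interleaved:
            a, b = x[2 * i], x[2 * i + 1]
            out[2 * i] = a * c - b * s
            out[2 * i + 1] = a * s + b * c
        else:
            a, b = x[i], x[i + half]
            out[i] = a * c - b * s
            out[i + half] = a * s + b * c
    return out
```

Both layouts apply the same 2D rotations and preserve the vector norm; they differ only in which elements are paired.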


yaox12 commented Apr 1, 2025

/te-ci pytorch

@cyanguwa cyanguwa added the 2.3.0 label Apr 1, 2025
@yaox12 yaox12 mentioned this pull request Apr 2, 2025
6 tasks
if tensor_format == "sbhd":
output = tex.fused_rope_forward(t, freqs, False)
elif tensor_format == "bshd":
output = tex.fused_rope_forward(t.transpose(0, 1), freqs, True).transpose(0, 1)
Collaborator

curious where the t.transpose happens now

Member Author

@yaox12 yaox12 Apr 7, 2025

Previously, the C++ interface fused_rope_forward accepted only the sbhd format, so bshd was handled by transposing the input to sbhd and then transposing the result back. A flag called transpose_output_memory ensured that transposing the result did not trigger an actual memory copy. This approach was not intuitive. Now I've added a new argument, qkv_format, to support all formats natively.
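The old workaround described above can be sketched in plain Python (using a nested-list "tensor" for illustration; transpose_output_memory and the real kernels are not modeled): swapping the first two dimensions maps bshd to sbhd, and swapping again restores the original layout.

```python
def swap_first_two_dims(t):
    """Swap the first two dimensions of a nested-list "tensor".

    Applied to a bshd layout this yields sbhd, which is what the old
    t.transpose(0, 1) hack did before the kernel accepted qkv_format.
    Illustration only; the real code operates on torch tensors.
    """
    d0, d1 = len(t), len(t[0])
    return [[t[i][j] for i in range(d0)] for j in range(d1)]

# bshd with b=2, s=3 (the h and d dims are elided as scalars for brevity)
bshd = [[10, 11, 12], [20, 21, 22]]
sbhd = swap_first_two_dims(bshd)       # sbhd view fed to the kernel
roundtrip = swap_first_two_dims(sbhd)  # transposed back, as the hack did
```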

yaox12 added 3 commits April 6, 2025 18:58
Signed-off-by: Xin Yao <xiny@nvidia.com>
Signed-off-by: Xin Yao <xiny@nvidia.com>

yaox12 commented Apr 7, 2025

/te-ci pytorch


yaox12 commented Apr 7, 2025

CI passed except for unrelated failures; the Blackwell runners got stuck.

@yaox12 yaox12 merged commit ba605f1 into NVIDIA:main Apr 7, 2025
11 of 12 checks passed
wdykas pushed a commit to wdykas/TransformerEngine that referenced this pull request Apr 14, 2025
* refactor to add cp support for sbhd/bshd

Signed-off-by: Xin Yao <xiny@nvidia.com>

* support interleaved

Signed-off-by: Xin Yao <xiny@nvidia.com>

* format

Signed-off-by: Xin Yao <xiny@nvidia.com>

* add interleaved to RotaryPositionEmbedding in test

Signed-off-by: Xin Yao <xiny@nvidia.com>

* update

Signed-off-by: Xin Yao <xiny@nvidia.com>

* merge sbhd/bshd and thd functions

Signed-off-by: Xin Yao <xiny@nvidia.com>

---------

Signed-off-by: Xin Yao <xiny@nvidia.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>


5 participants