Merged
Signed-off-by: Xin Yao <xiny@nvidia.com>
Force-pushed from 2525377 to d727f1c
yaox12 (Contributor) commented on Mar 31, 2025:
Rotary interleaved part looks good to me.
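As context for the interleaved mode under review, here is a minimal NumPy sketch of interleaved rotary embedding, which rotates adjacent channel pairs (2i, 2i+1) instead of the half-split pairing (i, i + d/2). This is illustrative only, not TransformerEngine's fused kernel, and the function name is hypothetical:

```python
import numpy as np

def rope_interleaved(t, freqs):
    """Toy interleaved RoPE: rotate channel pairs (0,1), (2,3), ...

    t:     (s, b, h, d) activations
    freqs: (s, 1, 1, d // 2) rotation angles, one per channel pair
    Illustrative sketch, not TE's actual kernel.
    """
    cos, sin = np.cos(freqs), np.sin(freqs)
    t1, t2 = t[..., 0::2], t[..., 1::2]  # even / odd channels
    out = np.empty_like(t)
    out[..., 0::2] = t1 * cos - t2 * sin
    out[..., 1::2] = t1 * sin + t2 * cos
    return out
```

Since each pair is a 2-D rotation, the transform preserves per-position norms, and with all-zero `freqs` it reduces to the identity.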
xrennvidia
approved these changes
Apr 1, 2025
Signed-off-by: Xin Yao <xiny@nvidia.com>
yaox12 (Member, Author) commented:
/te-ci pytorch
cyanguwa
approved these changes
Apr 3, 2025
if tensor_format == "sbhd":
    output = tex.fused_rope_forward(t, freqs, False)
elif tensor_format == "bshd":
    output = tex.fused_rope_forward(t.transpose(0, 1), freqs, True).transpose(0, 1)
cyanguwa (Collaborator) commented:
Curious where the t.transpose happens now.
yaox12 (Member, Author) replied:
Previously, the C++ interface fused_rope_forward only accepted the sbhd format, so for bshd we worked around it by transposing bshd to sbhd and then transposing the result back. A flag called transpose_output_memory ensured that transposing the result didn't actually perform a memory copy. That approach was not intuitive, so this PR adds a new qkv_format argument to support all formats natively.
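The old workaround and its dispatch logic can be sketched in NumPy. This is illustrative only: `rope_sbhd` stands in for the fused sbhd-only kernel, and `apply_rope` mirrors the transpose-in/transpose-out hack that the new native `qkv_format` support replaces. The function names are hypothetical:

```python
import numpy as np

def rope_sbhd(t, freqs):
    """Toy half-rotation RoPE on sbhd input; stands in for the fused kernel.

    t: (s, b, h, d); freqs: (s, 1, 1, d) rotation angles.
    """
    d = t.shape[-1]
    cos, sin = np.cos(freqs), np.sin(freqs)
    t1, t2 = t[..., : d // 2], t[..., d // 2:]
    rotated = np.concatenate((-t2, t1), axis=-1)
    return t * cos + rotated * sin

def apply_rope(t, freqs, qkv_format="sbhd"):
    """Old-style dispatch: bshd is handled by transposing to sbhd and back."""
    if qkv_format == "sbhd":
        return rope_sbhd(t, freqs)
    if qkv_format == "bshd":
        # The transpose-in / transpose-out hack: convert bshd -> sbhd,
        # run the sbhd-only kernel, then transpose the result back.
        return rope_sbhd(t.transpose(1, 0, 2, 3), freqs).transpose(1, 0, 2, 3)
    raise ValueError(f"unsupported qkv_format: {qkv_format}")
```

Both paths compute the same values; the cost of the hack is the extra layout juggling (and, in the real kernel, the transpose_output_memory bookkeeping), which native format support removes.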
Signed-off-by: Xin Yao <xiny@nvidia.com>
yaox12 (Member, Author) commented:
/te-ci pytorch
yaox12 (Member, Author) commented:
CI passed except for irrelevant failures; the Blackwell runners got stuck.
wdykas pushed a commit to wdykas/TransformerEngine that referenced this pull request on Apr 14, 2025:
* refactor to add cp support for sbhd/bshd
* support interleaved
* format
* add interleaved to RotaryPositionEmbedding in test
* update
* merge sbhd/bshd and thd functions

Signed-off-by: Xin Yao <xiny@nvidia.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>
Description

Refactor RoPE to:
* Support CP with the sbhd/bshd formats. The `freqs` tensor is supposed to be the full tensor with the THD format, while it is sliced (outside of TE) with the other formats (sbhd/bshd); this PR unifies them so that the full `freqs` tensor is required for all formats. `apply_rotary_pos_emb` previously didn't accept `cp_size` and `cp_rank` for sbhd and bshd, so passing a sliced `freqs` with `cp_size = 1` and `cp_rank = 0` should still work as before.
* Support the `interleaved` mode.
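How a full `freqs` tensor might be sliced per CP rank can be sketched as follows. This assumes the load-balanced split commonly used with context parallelism and causal attention, where the sequence is cut into `2 * cp_size` chunks and rank `r` keeps chunks `r` and `2 * cp_size - 1 - r`; the helper name is hypothetical, not TE's API:

```python
import numpy as np

def slice_freqs_for_cp(freqs, cp_size, cp_rank):
    """Slice full RoPE freqs (sequence-first) for one context-parallel rank.

    Assumes the load-balanced split used with causal attention: the
    sequence is cut into 2 * cp_size chunks and rank r keeps chunks
    r and 2 * cp_size - 1 - r. Illustrative sketch, not TE's code.
    """
    chunks = np.split(freqs, 2 * cp_size, axis=0)
    return np.concatenate(
        [chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]], axis=0
    )
```

Note that with `cp_size = 1` and `cp_rank = 0` the helper returns the full tensor unchanged, which matches the stated backward-compatibility behavior.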