[MoE][PyTorch] Add prob permutation to mask-based MoE permutation; Fix FP8 related codes#1468

Merged: timmoon10 merged 15 commits into NVIDIA:main from hxbai:permute_probs on Feb 18, 2025

@hxbai hxbai commented Feb 9, 2025

Description

Add probs permutation to the mask-based permutation. With this, the probs can be applied in the MoE expert MLP instead of in the unpermutation, which avoids saving the large input tensor of the unpermute function for backward.

Fix FP8 tensor usage in the permutation code, since TE 2.0 made breaking changes to the FP8 tensor interfaces.

Depiction of the probs application:
[image]
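
The two placements of the probs can be compared with a plain PyTorch sketch. This is not the TE implementation: the helpers `permute`, `experts`, and `unpermute` are hypothetical stand-ins, each expert is assumed to be a single linear layer, and the probs are applied to the expert output for simplicity. It only illustrates that weighting the permuted rows before the unpermute gives the same result as passing the probs to the unpermute.

```python
import torch


def permute(tokens, routing_map, probs=None):
    """Mask-based permute: gather one copy of each token per routed expert,
    grouped by expert. routing_map is a [num_tokens, num_experts] bool mask."""
    # nonzero() on the transposed mask yields (expert, token) pairs in expert-major order.
    expert_idx, token_idx = routing_map.t().nonzero(as_tuple=True)
    permuted = tokens[token_idx]
    # Optionally permute the probs alongside the tokens: one prob per permuted row.
    permuted_probs = probs[token_idx, expert_idx] if probs is not None else None
    return permuted, permuted_probs, token_idx, expert_idx


def experts(permuted, expert_idx, weights):
    """Hypothetical expert MLP: one linear layer per expert (an assumption)."""
    out = torch.empty_like(permuted)
    for e, w in enumerate(weights):
        rows = expert_idx == e
        out[rows] = permuted[rows] @ w
    return out


def unpermute(permuted, token_idx, expert_idx, num_tokens, merging_probs=None):
    """Sum the expert outputs back into token order, optionally weighted by probs."""
    if merging_probs is not None:
        permuted = permuted * merging_probs[token_idx, expert_idx].unsqueeze(-1)
    out = torch.zeros(num_tokens, permuted.shape[-1], dtype=permuted.dtype)
    out.index_add_(0, token_idx, permuted)
    return out


torch.manual_seed(0)
num_tokens, num_experts, hidden, topk = 6, 4, 8, 2
tokens = torch.randn(num_tokens, hidden, dtype=torch.float64)
probs = torch.rand(num_tokens, num_experts, dtype=torch.float64)
topk_idx = probs.topk(topk, dim=1).indices
routing_map = torch.zeros(num_tokens, num_experts, dtype=torch.bool)
routing_map[torch.arange(num_tokens).unsqueeze(1), topk_idx] = True
weights = [torch.randn(hidden, hidden, dtype=torch.float64) for _ in range(num_experts)]

# Probs applied in the unpermute (the unpermute must see the unweighted expert output).
permuted, _, tok, exp = permute(tokens, routing_map)
out_unpermute_probs = unpermute(
    experts(permuted, exp, weights), tok, exp, num_tokens, merging_probs=probs
)

# Probs permuted with the tokens and applied before the unpermute.
permuted, permuted_probs, tok, exp = permute(tokens, routing_map, probs)
expert_out = experts(permuted, exp, weights) * permuted_probs.unsqueeze(-1)
out_permuted_probs = unpermute(expert_out, tok, exp, num_tokens)
```

In the second variant the unpermute is a plain scatter-add that needs no probs, which is the property the description relies on.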

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add probs permutation to moe_permute and moe_sort_chunks_by_index
  • Fix FP8-related code
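
As a rough illustration of what permuting the probs alongside the tokens means for the chunk sort, here is a plain PyTorch reference. The function name and signature below are hypothetical and are not the TE API; the sketch assumes the permuted probs are a 1-D tensor with one entry per permuted row.

```python
import torch


def sort_chunks_by_index_reference(tokens, split_sizes, sorted_index, probs=None):
    """Split rows into chunks of the given sizes and concatenate them in the
    order given by sorted_index. If probs is given, reorder it the same way."""
    sizes = split_sizes.tolist()
    order = sorted_index.tolist()
    chunks = torch.split(tokens, sizes, dim=0)
    sorted_tokens = torch.cat([chunks[i] for i in order], dim=0)
    if probs is None:
        return sorted_tokens, None
    prob_chunks = torch.split(probs, sizes, dim=0)
    sorted_probs = torch.cat([prob_chunks[i] for i in order], dim=0)
    return sorted_tokens, sorted_probs


# Row r holds the value r, and its prob is r / 10, so the pairing is easy to check.
tokens = torch.arange(6, dtype=torch.float32).unsqueeze(-1).repeat(1, 2)
probs = torch.arange(6, dtype=torch.float32) / 10
split_sizes = torch.tensor([1, 2, 3])
sorted_index = torch.tensor([2, 0, 1])

sorted_tokens, sorted_probs = sort_chunks_by_index_reference(
    tokens, split_sizes, sorted_index, probs
)
```

After the sort, each prob still sits at the same row as the token it belongs to.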

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

hxbai and others added 9 commits February 9, 2025 07:17
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
@yaox12 yaox12 mentioned this pull request Feb 13, 2025
13 tasks

yaox12 commented Feb 13, 2025

/te-ci pytorch

print(f"chunk sort\t\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")


def _test_permutation_mask_map_alongside_probs(
Collaborator

Hi,

I found this test almost identical to _test_permutation_mask_map, except that it has TP and therefore calls te_sort_chunks_by_index.

Is it correct that _test_permutation_mask_map is the special case of _test_permutation_mask_map_alongside_probs with tp_size=1?

If so, I suggest we combine these two tests and eliminate the code duplication.

Contributor Author

Hi, they are not exactly the same.

For _test_permutation_mask_map, the probs are applied in the unpermutation, and the results of permute and unpermute are verified separately.

_test_permutation_mask_map_alongside_probs is roughly an end-to-end test. The probs are applied in the unpermutation for the PyTorch version and to the permute output for the TE version. We want to make sure that the two methods produce the same final output.

Collaborator

I see. Thank you!

merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
inp *= merging_prob
accumulator += inp
if PERMUTE_PROBS:
Collaborator

Hi,

I wonder in which cases we can have PERMUTE_PROBS != WITH_MERGING_PROBS.
I read the code as follows: when WITH_MERGING_PROBS=True, we do the accumulation and then reset the probs to zero in the if PERMUTE_PROBS block. In which cases do we not need to reset them?

Contributor Author

Hi, Phuong,

Yes, this part is easy to confuse. I just added a diagram to the PR description to depict the usage.

The previous code covers the right workflow, and this PR adds support for the left workflow. In the right workflow, we don't permute the probs tensor and pass it directly to the unpermute operation. In the left workflow, we permute the probs tensor and apply it in the GroupedGEMM; in this case, no probs are passed to the unpermute operation.

  • Left: moe_permute(probs=probs), moe_unpermute(merging_probs=None)
  • Right: moe_permute(probs=None), moe_unpermute(merging_probs=probs)

This kernel is used by both permute_bwd and unpermute_fwd:

  • permute_bwd in the left workflow: PERMUTE_PROBS=True and WITH_MERGING_PROBS=False
  • permute_bwd in the right workflow: PERMUTE_PROBS=False and WITH_MERGING_PROBS=False
  • unpermute_fwd in the left workflow: PERMUTE_PROBS=False and WITH_MERGING_PROBS=False
  • unpermute_fwd in the right workflow: PERMUTE_PROBS=False and WITH_MERGING_PROBS=True

So these two args are never True at the same time.
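
To make the flag combinations above concrete, here is a plain PyTorch sketch of what such a shared kernel computes. It is an interpretation of this thread, not the Triton kernel itself: the function name, the index-based arguments, and the exact handling of the probs are assumptions. Passing `merging_probs` corresponds to WITH_MERGING_PROBS, and passing `permuted_probs` corresponds to PERMUTE_PROBS.

```python
import torch


def unpermute_reference(permuted, token_idx, expert_idx, num_tokens, num_experts,
                        merging_probs=None, permuted_probs=None):
    """Hypothetical reference for a kernel shared by permute_bwd and unpermute_fwd."""
    with_merging_probs = merging_probs is not None  # unpermute_fwd, right workflow
    permute_probs = permuted_probs is not None      # permute_bwd, left workflow
    # Per the thread, the two flags are never set together.
    assert not (with_merging_probs and permute_probs)

    rows = permuted.float()  # accumulate in fp32
    if with_merging_probs:
        rows = rows * merging_probs[token_idx, expert_idx].float().unsqueeze(-1)
    out = torch.zeros(num_tokens, permuted.shape[-1], dtype=torch.float32)
    out.index_add_(0, token_idx, rows)

    unpermuted_probs = None
    if permute_probs:
        # Scatter the permuted probs back to [num_tokens, num_experts];
        # entries for experts a token was not routed to stay zero.
        unpermuted_probs = torch.zeros(num_tokens, num_experts, dtype=permuted_probs.dtype)
        unpermuted_probs[token_idx, expert_idx] = permuted_probs
    return out.to(permuted.dtype), unpermuted_probs


# 3 tokens, 2 experts; permuted rows are (token 0, expert 0), (2, 0), (1, 1), (2, 1).
token_idx = torch.tensor([0, 2, 1, 2])
expert_idx = torch.tensor([0, 0, 1, 1])
permuted = torch.ones(4, 2)

# permute_bwd in the left workflow: PERMUTE_PROBS=True, WITH_MERGING_PROBS=False.
out_bwd, probs_bwd = unpermute_reference(
    permuted, token_idx, expert_idx, 3, 2,
    permuted_probs=torch.tensor([0.1, 0.2, 0.3, 0.4]),
)

# unpermute_fwd in the right workflow: PERMUTE_PROBS=False, WITH_MERGING_PROBS=True.
out_fwd, probs_fwd = unpermute_reference(
    permuted, token_idx, expert_idx, 3, 2,
    merging_probs=torch.full((3, 2), 0.5),
)
```

With neither argument passed, the function reduces to a plain scatter-add, which matches the two remaining cases in the list.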

Member

@timmoon10 timmoon10 left a comment

Mostly looks reasonable, but we should make sure the user-facing APIs are not too messy.

Comment thread transformer_engine/pytorch/permutation.py
Comment thread transformer_engine/pytorch/permutation.py
Comment thread transformer_engine/pytorch/permutation.py
Comment thread tests/pytorch/test_permutation.py Outdated
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
@ptrendx ptrendx added the 2.1.0 label Feb 15, 2025
@phu0ngng
Collaborator

/te-ci pytorch

Member

@timmoon10 timmoon10 left a comment

LGTM. Test failures in pipeline 24020411 are unrelated.

@timmoon10 timmoon10 merged commit eb9857d into NVIDIA:main Feb 18, 2025
ptrendx pushed a commit that referenced this pull request Feb 19, 2025
…x FP8 related codes (#1468)

* add prob permute; fix fp8tensor

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert unnecessary changes in UT

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* remove unnecessary probs dtype convert

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* keep the output nums if probs is not provided

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refine the doc string

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* fix lint

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* use fp32 compute type

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* style fix

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* fix empty input return

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* separate prob related functions out

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>

5 participants