[MoE][PyTorch] Add prob permutation to mask-based MoE permutation; Fix FP8 related codes#1468
Conversation
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
|
/te-ci pytorch |
| print(f"chunk sort\t\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") | ||
|
|
||
|
|
||
| def _test_permutation_mask_map_alongside_probs( |
There was a problem hiding this comment.
Hi,
I found this test almost identical to the _test_permutation_mask_map except that we have TP thus involves calling te_sort_chunks_by_index.
Is it correct that the _test_permutation_mask_map is the special case of _test_permutation_mask_map_alongside_probs, in which the tp_size=1?
If yes, I suggest we combine these two test and eliminate code duplications.
There was a problem hiding this comment.
Hi, it is not exactly same.
For _test_permutation_mask_map, the probs are applied to the unpermutation and the results of permute and unpermute are verified separately.
For _test_permutation_mask_map_alongside_probs, it is roughly an end-to-end test. The probs are applied to the unpermutation for the PyTorch version and applied to the permute output for the TE version. We want to make sure that the two methods can have same final output value.
| merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type) | ||
| inp *= merging_prob | ||
| accumulator += inp | ||
| if PERMUTE_PROBS: |
There was a problem hiding this comment.
Hi,
I wonder in which case we can have PERMUTE_PROBS != WITH_MERGING_PROBS.
I interpret the code as when we have WITH_MERGING_PROBS=True, we do the accumulation then reset the prob to zeros in the if PERMUTE_PROBS block. Then in which case we don't need to reset them?
There was a problem hiding this comment.
Hi, Phuong,
Yes, it is easily confused by this part. I just added a graph to this PR description to depict the usage.
The previous codes are for the right workflow and this PR is to support the left workflow. For the right workflow, we don't permute the probs tensor and pass it directly to the unpermute operation. For the left workflow, we permute the probs tensor and apply it on the GroupedGEMM; for this case, no probs is passed to the unpermute operation.
- Left: moe_permute(probs=probs), moe_unpermute(merging_probs=None)
- Right: moe_permute(probs=None), moe_unpermute(merging_probs=probs)
For this kernel, it is used by both permute_bwd and unpermute_fwd:
- permute_bwd in the left workflow:
PERMUTE_PROBS=TrueandWITH_MERGING_PROBS=False - permute_bwd in the right workflow:
PERMUTE_PROBS=FalseandWITH_MERGING_PROBS=False - unpermute_fwd in the left workflow:
PERMUTE_PROBS=FalseandWITH_MERGING_PROBS=False - unpermute_fwd in the left workflow:
PERMUTE_PROBS=FalseandWITH_MERGING_PROBS=True
So, these two args would not be True at the same time.
timmoon10
left a comment
There was a problem hiding this comment.
Mostly looks reasonable, but we should make sure the user-facing APIs are not too messy.
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
for more information, see https://pre-commit.ci
|
/te-ci pytorch |
…x FP8 related codes (#1468) * add prob permute; fix fp8tensor Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert unnecessary changes in UT Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com> * remove unnecessary probs dtype convert Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com> * keep the output nums if probs is not provided Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refine the doc string Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com> * fix lint Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com> * use fp32 compute type Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com> * style fix Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com> * fix empty input return Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com> * separate prob related functions out Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xin Yao <xiny@nvidia.com> Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
Description
Add probs permutation codes to the mask-based permutation. With this, we can apply the probs to the MoE expert MLP rather than to the unpermutation to avoid saving huge input tensor of unpermute function.
Fix FP8 Tensor usages in the permutation codes since TE 2.0 has some breaking changes on FP8 Tensor interfaces.
Depiction for probs application:

Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
moe_permuteandmoe_sort_chunks_by_indexChecklist: