Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix fused-moe in triton2.2.0 #1654

Merged
merged 1 commit into from
May 24, 2024
Merged

fix fused-moe in triton2.2.0 #1654

merged 1 commit into from
May 24, 2024

Conversation

grimoire
Copy link
Collaborator

Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily receiving feedbacks. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers.

Motivation

Please describe the motivation of this PR and the goal you want to achieve through this PR.

Modification

Please briefly describe what modification is made in this PR.

BC-breaking (Optional)

Does the modification introduce changes that break the backward-compatibility of the downstream repositories?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.

Use cases (Optional)

If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.

Checklist

  1. Pre-commit or other linting tools are used to fix the potential lint issues.
  2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure the correctness.
  3. If the modification has a dependency on downstream projects of a newer version, this PR should be tested with all supported versions of downstream projects.
  4. The documentation has been modified accordingly, like docstring or example tutorials.

@@ -269,7 +269,7 @@ def _start_end_kernel(TopkIdx, SortedIdx, ExpStart, ExpEnd,
sidx = tl.load(SortedIdx + sidx_off, mask=sidx_mask, other=0)
tidx = tl.load(TopkIdx + sidx, mask=sidx_mask, other=num_experts)
tidx_mask = tidx == exp_id
cnt += tl.sum(tidx_mask)
cnt += tl.sum(tidx_mask.to(tl.int32))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it compatible with triton 2.1.0?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, tested.

@zhulinJulia24
Copy link
Collaborator

fixed

Copy link
Collaborator

@zhulinJulia24 zhulinJulia24 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@lvhan028 lvhan028 merged commit cb59a8e into InternLM:main May 24, 2024
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants