moe triton kernel use tma. #1100
No description provided.
Conversation
Summary of Changes (@gemini-code-assist): This pull request significantly enhances the performance of the Mixture-of-Experts (MoE) Triton kernels by introducing Tensor Memory Accelerator (TMA) capabilities. It refactors the data handling within the grouped_matmul kernel to accommodate TMA's requirements.
Code Review
This pull request introduces support for the Tensor Memory Accelerator (TMA) in the MoE Triton kernels, aiming to improve performance on modern GPU architectures. The changes are substantial, involving a refactoring of the grouped_matmul kernel and its associated helper functions. Key modifications include adding a utility to detect TMA support, restructuring data to accommodate TMA requirements such as token start indices, and conditionally enabling TMA for tensor loading within the kernel. My review focuses on improving code clarity, robustness, and maintainability. The overall implementation appears solid and well aligned with the goal of enabling TMA.
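To make the mechanics concrete before diving into the diff hunks below, here is a minimal sketch (not from this PR; the kernel and variable names are illustrative) of the host-side TensorDescriptor plus in-kernel descriptor-load pattern the change builds on. It assumes a recent Triton that exposes triton.tools.tensor_descriptor and an SM90+ (Hopper-class) GPU:

```python
import torch
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor


@triton.jit
def _tma_copy_kernel(in_desc, out_ptr, M, K, BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr):
    pid = tl.program_id(0)
    off_m = pid * BLOCK_M
    # Descriptor load: the TMA engine fetches the [BLOCK_M, BLOCK_K] tile at
    # element offsets (off_m, 0); bounds come from the descriptor's shape and
    # strides, so the load itself needs no explicit mask.
    tile = in_desc.load([off_m, 0])
    offs_m = off_m + tl.arange(0, BLOCK_M)
    offs_k = tl.arange(0, BLOCK_K)
    tl.store(out_ptr + offs_m[:, None] * K + offs_k[None, :], tile)


x = torch.randn(256, 64, device="cuda", dtype=torch.float16)
y = torch.empty_like(x)
BLOCK_M, BLOCK_K = 64, 64
# Host side: describe the global tensor once; the innermost block dimension
# typically must span a multiple of 16 bytes for TMA.
desc = TensorDescriptor(x, x.shape, x.stride(), [BLOCK_M, BLOCK_K])
grid = (triton.cdiv(x.shape[0], BLOCK_M),)
_tma_copy_kernel[grid](desc, y, x.shape[0], x.shape[1], BLOCK_M=BLOCK_M, BLOCK_K=BLOCK_K)
assert torch.equal(x, y)
```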
From the kernel diff:

```python
a_scale_ptrs = token_scale_ptr + (a_m_index // topk_num) * (token_stride_0 // block_size_k)
token_scale_stride0 = token_stride_0 // block_size_k
if TOKEN_INPUT_USE_TMA:
    assert MUL_ROUTED_WEIGHT is True
```
Inside a Triton kernel, Python's assert statement may not behave as expected or could be ignored by the compiler. For compile-time constants like MUL_ROUTED_WEIGHT, it's better to use tl.static_assert to ensure the check is performed during kernel compilation.
Suggested change:

```diff
-    assert MUL_ROUTED_WEIGHT is True
+    tl.static_assert(MUL_ROUTED_WEIGHT is True, "TOKEN_INPUT_USE_TMA is only supported for down projection (MUL_ROUTED_WEIGHT=True)")
```
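For context, tl.static_assert is evaluated while the kernel is compiled for a particular constexpr specialization, so an invalid flag combination fails at compile time instead of being silently dropped. A minimal illustration (names are hypothetical, not from this PR):

```python
import triton
import triton.language as tl


@triton.jit
def _guarded_kernel(USE_TMA: tl.constexpr, MUL_ROUTED_WEIGHT: tl.constexpr):
    if USE_TMA:
        # Checked once per (USE_TMA, MUL_ROUTED_WEIGHT) specialization at
        # compile time; compilation aborts with this message if it fails.
        tl.static_assert(MUL_ROUTED_WEIGHT, "USE_TMA requires MUL_ROUTED_WEIGHT")
```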
From the host-side launch code:

```python
if TOKEN_INPUT_USE_TMA:
    from triton.tools.tensor_descriptor import TensorDescriptor

    block_num = triton.cdiv(n, BLOCK_SIZE_N) * mblocks_to_expert_id.shape[0]
    token_desc = TensorDescriptor(
        token_inputs, token_inputs.shape, token_inputs.stride(), [BLOCK_SIZE_M, BLOCK_SIZE_K]
    )
else:
    token_desc = None

if WEIGHT_USE_TMA:
    from triton.tools.tensor_descriptor import TensorDescriptor

    weight_desc = TensorDescriptor(
        expert_weights, expert_weights.shape, expert_weights.stride(), [1, BLOCK_SIZE_N, BLOCK_SIZE_K]
    )
else:
    weight_desc = None
```
The import from triton.tools.tensor_descriptor import TensorDescriptor is duplicated. It can be moved to a higher scope (e.g., inside the if support_tma: block at line 752) to avoid redundancy. After moving the import, this section can be simplified.
Suggested change:

```python
if TOKEN_INPUT_USE_TMA:
    token_desc = TensorDescriptor(
        token_inputs, token_inputs.shape, token_inputs.stride(), [BLOCK_SIZE_M, BLOCK_SIZE_K]
    )
else:
    token_desc = None

if WEIGHT_USE_TMA:
    weight_desc = TensorDescriptor(
        expert_weights, expert_weights.shape, expert_weights.stride(), [1, BLOCK_SIZE_N, BLOCK_SIZE_K]
    )
else:
    weight_desc = None
```

The next thread concerns the TMA support detection utility:

```python
try:
    from triton.tools.tensor_descriptor import TensorDescriptor

    support_tma = torch.cuda.get_device_capability() >= (9, 0)
    if support_tma:
        logger.info("triton support tensor_descriptor")
        return True
    else:
        assert False
except:
    logger.info("triton not support tensor_descriptor")
    return False
```
The use of assert False to control logic flow, combined with a bare except: clause, is not ideal. This can make the code harder to understand and might hide unexpected errors. A more explicit and clear approach would be to use a standard if/else structure and catch specific exceptions like ImportError.
Suggested change:

```python
try:
    from triton.tools.tensor_descriptor import TensorDescriptor

    if torch.cuda.get_device_capability() >= (9, 0):
        logger.info("triton support tensor_descriptor")
        return True
    else:
        logger.info("triton not support tensor_descriptor: requires CUDA capability >= 9.0")
        return False
except ImportError:
    logger.info("triton not support tensor_descriptor: failed to import TensorDescriptor")
    return False
```
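Building on that suggestion, one further refinement worth considering (a sketch, not part of this PR; the function name and the caching choice are assumptions) is to memoize the probe so the import attempt and device query run only once per process:

```python
import functools
import logging

import torch

logger = logging.getLogger(__name__)


@functools.lru_cache(maxsize=1)
def triton_supports_tma() -> bool:
    """Probe once whether descriptor-based (TMA) loads are usable."""
    try:
        from triton.tools.tensor_descriptor import TensorDescriptor  # noqa: F401
    except ImportError:
        logger.info("triton not support tensor_descriptor: failed to import TensorDescriptor")
        return False
    if torch.cuda.get_device_capability() >= (9, 0):
        logger.info("triton support tensor_descriptor")
        return True
    logger.info("triton not support tensor_descriptor: requires CUDA capability >= 9.0")
    return False
```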