
Conversation

@hiworldwzj (Collaborator)

No description provided.

@gemini-code-assist

Summary of Changes

Hello @hiworldwzj, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the performance of Mixture-of-Experts (MoE) Triton kernels by introducing Tensor Memory Access (TMA) capabilities. It refactors the data handling within the moe_align2_kernel to consolidate expert and token information into a single structure and implements conditional TMA-based data loading in grouped_matmul_kernel. A new utility function detects hardware support for TMA, ensuring these optimizations are applied only where beneficial, leading to more efficient memory operations on modern GPU architectures.

Highlights

  • TMA Integration for MoE Kernels: The core change involves integrating Tensor Memory Access (TMA) into the Triton kernels for Mixture-of-Experts (MoE) operations, specifically within grouped_matmul_kernel.
  • Consolidated Data Structure: The moe_align2_kernel now uses a single mblocks_to_tuple_info tensor to store expert ID, M-index, and token start index, simplifying data passing and improving data locality (see the sketch after this list).
  • Dynamic TMA Activation: The system dynamically checks for Triton's TensorDescriptor support and GPU compute capability (SM 9.0+) to enable TMA, optimizing memory access for compatible hardware.
  • Conditional Data Loading: grouped_matmul_kernel now conditionally loads token inputs and expert weights using TMA when available, potentially improving performance by leveraging specialized hardware units.
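For illustration, here is a minimal sketch of what such a consolidated per-M-block tensor could look like. The shape, dtype, and column order below are assumptions for exposition, not the layout actually used by moe_align2_kernel:

```python
# Hypothetical sketch of a consolidated per-M-block info tensor.
# Shape, dtype, and column order are assumptions, not the PR's actual layout.
import torch

num_m_blocks = 8  # assumed number of aligned M-blocks
mblocks_to_tuple_info = torch.empty((num_m_blocks, 3), dtype=torch.int32)
# mblocks_to_tuple_info[i, 0] -> expert id that owns M-block i
# mblocks_to_tuple_info[i, 1] -> M-index of the block within that expert's tokens
# mblocks_to_tuple_info[i, 2] -> token start index, which lets a TMA load address
#                                a contiguous run of tokens for the block
```

Packing the three values side by side means a single row read gives grouped_matmul_kernel everything it needs for one M-block, which is the data-locality benefit noted above.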
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature              | Command             | Description
-------------------- | ------------------- | -----------
Code Review          | /gemini review      | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary     | Provides a summary of the current pull request in its current state.
Comment              | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help                 | /gemini help        | Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder at the root of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces support for Tensor Memory Access (TMA) in the MoE Triton kernels, aiming to enhance performance on modern GPU architectures. The changes are substantial, involving a refactoring of the grouped_matmul kernel and its associated helper functions. Key modifications include adding a utility to detect TMA support, restructuring data to accommodate TMA requirements like token start indices, and conditionally enabling TMA for tensor loading within the kernel. My review focuses on improving code clarity, robustness, and maintainability. The overall implementation appears solid and well-aligned with the goal of enabling TMA.

```python
a_scale_ptrs = token_scale_ptr + (a_m_index // topk_num) * (token_stride_0 // block_size_k)
token_scale_stride0 = token_stride_0 // block_size_k
if TOKEN_INPUT_USE_TMA:
    assert MUL_ROUTED_WEIGHT is True
```


Severity: high

Inside a Triton kernel, Python's assert statement may not behave as expected or could be ignored by the compiler. For compile-time constants like MUL_ROUTED_WEIGHT, it's better to use tl.static_assert to ensure the check is performed during kernel compilation.

Suggested change
```diff
-    assert MUL_ROUTED_WEIGHT is True
+    tl.static_assert(MUL_ROUTED_WEIGHT is True, "TOKEN_INPUT_USE_TMA is only supported for down projection (MUL_ROUTED_WEIGHT=True)")
```
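To make the distinction concrete, here is a minimal, self-contained sketch (the kernel and its body are invented for illustration) of tl.static_assert guarding a tl.constexpr flag; the check runs when the specialization is compiled, whereas a plain assert on the same flag may simply be dropped by the compiler:

```python
import triton
import triton.language as tl

@triton.jit
def _demo_kernel(x_ptr, MUL_ROUTED_WEIGHT: tl.constexpr):
    # Evaluated at compile time: compilation of this specialization fails
    # with the given message whenever MUL_ROUTED_WEIGHT is False.
    tl.static_assert(MUL_ROUTED_WEIGHT, "this specialization requires MUL_ROUTED_WEIGHT")
    val = tl.load(x_ptr)
    tl.store(x_ptr, val)
```

Because the check fires during compilation, an unsupported combination of constexpr flags is rejected before the kernel is ever launched.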

Comment on lines 771 to 787
```python
if TOKEN_INPUT_USE_TMA:
    from triton.tools.tensor_descriptor import TensorDescriptor

    block_num = triton.cdiv(n, BLOCK_SIZE_N) * mblocks_to_expert_id.shape[0]
    token_desc = TensorDescriptor(
        token_inputs, token_inputs.shape, token_inputs.stride(), [BLOCK_SIZE_M, BLOCK_SIZE_K]
    )
else:
    token_desc = None

if WEIGHT_USE_TMA:
    from triton.tools.tensor_descriptor import TensorDescriptor

    weight_desc = TensorDescriptor(
        expert_weights, expert_weights.shape, expert_weights.stride(), [1, BLOCK_SIZE_N, BLOCK_SIZE_K]
    )
else:
    weight_desc = None
```


Severity: medium

The import from triton.tools.tensor_descriptor import TensorDescriptor is duplicated. It can be moved to a higher scope (e.g., inside the if support_tma: block at line 752) to avoid redundancy. After moving the import, this section can be simplified.

```python
if TOKEN_INPUT_USE_TMA:
    token_desc = TensorDescriptor(
        token_inputs, token_inputs.shape, token_inputs.stride(), [BLOCK_SIZE_M, BLOCK_SIZE_K]
    )
else:
    token_desc = None

if WEIGHT_USE_TMA:
    weight_desc = TensorDescriptor(
        expert_weights, expert_weights.shape, expert_weights.stride(), [1, BLOCK_SIZE_N, BLOCK_SIZE_K]
    )
else:
    weight_desc = None
```
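As a further illustration of the hoisting, here is a hedged sketch that folds both branches behind a single import; the helper name make_tma_descriptors and its signature are invented, and this is not the PR's actual code:

```python
import torch

def make_tma_descriptors(
    token_inputs: torch.Tensor,
    expert_weights: torch.Tensor,
    token_input_use_tma: bool,
    weight_use_tma: bool,
    BLOCK_SIZE_M: int,
    BLOCK_SIZE_N: int,
    BLOCK_SIZE_K: int,
):
    """Build the optional TMA descriptors with one, hoisted import."""
    token_desc = weight_desc = None
    if token_input_use_tma or weight_use_tma:
        # Imported once, reused by both branches below.
        from triton.tools.tensor_descriptor import TensorDescriptor

        if token_input_use_tma:
            token_desc = TensorDescriptor(
                token_inputs, token_inputs.shape, token_inputs.stride(), [BLOCK_SIZE_M, BLOCK_SIZE_K]
            )
        if weight_use_tma:
            weight_desc = TensorDescriptor(
                expert_weights, expert_weights.shape, expert_weights.stride(), [1, BLOCK_SIZE_N, BLOCK_SIZE_K]
            )
    return token_desc, weight_desc
```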

Comment on lines +234 to +245
```python
try:
    from triton.tools.tensor_descriptor import TensorDescriptor

    support_tma = torch.cuda.get_device_capability() >= (9, 0)
    if support_tma:
        logger.info("triton support tensor_descriptor")
        return True
    else:
        assert False
except:
    logger.info("triton not support tensor_descriptor")
    return False
```


Severity: medium

The use of assert False to control logic flow, combined with a bare except: clause, is not ideal. This can make the code harder to understand and might hide unexpected errors. A more explicit and clear approach would be to use a standard if/else structure and catch specific exceptions like ImportError.

Suggested change
```diff
-try:
-    from triton.tools.tensor_descriptor import TensorDescriptor
-
-    support_tma = torch.cuda.get_device_capability() >= (9, 0)
-    if support_tma:
-        logger.info("triton support tensor_descriptor")
-        return True
-    else:
-        assert False
-except:
-    logger.info("triton not support tensor_descriptor")
-    return False
+try:
+    from triton.tools.tensor_descriptor import TensorDescriptor
+
+    if torch.cuda.get_device_capability() >= (9, 0):
+        logger.info("triton support tensor_descriptor")
+        return True
+    else:
+        logger.info("triton not support tensor_descriptor: requires CUDA capability >= 9.0")
+        return False
+except ImportError:
+    logger.info("triton not support tensor_descriptor: failed to import TensorDescriptor")
+    return False
```
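As a usage note (not part of the review), detection helpers like this are often memoized so the import probe and capability query run only once per process; the function name triton_supports_tma below is invented for illustration:

```python
from functools import lru_cache

import torch

@lru_cache(maxsize=1)
def triton_supports_tma() -> bool:
    """Probe once whether Triton's TensorDescriptor and an SM 9.0+ GPU are available."""
    try:
        from triton.tools.tensor_descriptor import TensorDescriptor  # noqa: F401
    except ImportError:
        return False
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
```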

@shihaobai merged commit 98d385a into main on Nov 11, 2025 (1 check passed).
@shihaobai deleted the wzj branch on November 11, 2025 at 12:57.
