
[PyTorch] Add ops for MoE grouped MLP#2664

Merged
timmoon10 merged 6 commits into NVIDIA:main from timmoon10:tmoon/grouped-linear-op
Feb 12, 2026

Conversation

@timmoon10
Collaborator

Description

This PR adds ops needed for the grouped MLP block in Mixture-of-Experts models. In particular, it adds a grouped linear op (similar to the GroupedLinear module) and a ScaledSwiGLU op. It is the same as #2622, but doesn't include the fused ops with experimental kernels. Closes #2560.
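For orientation, here is a plain-PyTorch sketch of the grouped MLP pattern these ops target (this is not the Transformer Engine API; the SwiGLU gate/value split convention and the placement of the per-token scale are assumptions for illustration): tokens are arranged so that each expert's tokens are contiguous, a grouped linear applies per-expert FC1 weights, a scaled SwiGLU applies the activation with per-token scales, and a second grouped linear applies FC2.

```python
import torch
import torch.nn.functional as F

def grouped_mlp_sketch(x, split_sizes, fc1_weights, fc2_weights, scales):
    # x: [num_tokens, hidden], with the tokens for expert i occupying
    # split_sizes[i] contiguous rows; scales: [num_tokens] per-token scale
    # (e.g. a router weight).
    x_splits = torch.split(x, split_sizes, dim=0)
    s_splits = torch.split(scales, split_sizes, dim=0)
    out_splits = []
    for xi, si, w1, w2 in zip(x_splits, s_splits, fc1_weights, fc2_weights):
        h = F.linear(xi, w1)                         # grouped linear (FC1), per expert
        gate, value = h.chunk(2, dim=-1)             # SwiGLU split convention assumed
        h = F.silu(gate) * value * si.unsqueeze(-1)  # scaled SwiGLU
        out_splits.append(F.linear(h, w2))           # grouped linear (FC2), per expert
    return torch.cat(out_splits, dim=0)
```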

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add grouped linear op
  • Add scaled SwiGLU op
  • Handle edge cases in noop_cat function

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 added the enhancement (New feature or request) label on Feb 9, 2026
@greptile-apps
Contributor

greptile-apps bot commented Feb 9, 2026

Greptile Overview

Greptile Summary

This PR adds fusible operations for MoE (Mixture-of-Experts) grouped MLP blocks, specifically GroupedLinear and ScaledSwiGLU. The implementation is clean and well-tested.

Key Changes:

  • GroupedLinear: performs multiple parallel linear transformations by splitting the input along its first dimension, applying a separate weight and bias to each split, and concatenating the results; includes comprehensive FP8/MXFP8 quantization support and Megatron-LM gradient fusion (see the reference sketch after this list)
  • ScaledSwiGLU: adds element-wise post-scaling to the SwiGLU activation, enabling per-expert gating in MoE models
  • GLU interleaving: both ops support an experimental block-interleaved format for advanced fused kernels
  • Edge-case fix: noop_cat now handles manually configured views from split_quantize by checking storage bounds before creating strided views
  • Refactoring: moves the SwiGLU variants from activation.py to a dedicated swiglu.py module for better organization
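To make the first bullet concrete, here is a minimal reference of the grouped-linear semantics it describes (split along dim 0, a separate weight/bias per group, concatenate), written in plain PyTorch rather than against TE's quantized grouped-GEMM kernels:

```python
import torch
import torch.nn.functional as F

def grouped_linear_reference(x, split_sizes, weights, biases=None):
    # x: [sum(split_sizes), in_features]; weights[i]: [out_features, in_features].
    splits = torch.split(x, split_sizes, dim=0)
    outs = []
    for i, xi in enumerate(splits):
        bias = biases[i] if biases is not None else None
        outs.append(F.linear(xi, weights[i], bias))  # separate weight/bias per group
    return torch.cat(outs, dim=0)                    # concatenate along the split dim
```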

Implementation Quality:

  • Robust parameter validation and consistency checks across all groups
  • Proper handling of quantization configurations (FP8, MXFP8)
  • Comprehensive test coverage including edge cases (zero-sized splits, various dtypes, quantization modes)
  • Gradient correctness verified against PyTorch reference implementations

Confidence Score: 5/5

  • This PR is safe to merge with high confidence
  • The implementation is well-architected with comprehensive test coverage, proper error handling, and careful consideration of edge cases. All previous review concerns have been addressed by the senior developer with valid technical explanations. The code follows established patterns in the codebase and includes extensive validation logic.
  • No files require special attention

Important Files Changed

  • transformer_engine/pytorch/ops/basic/grouped_linear.py: New file implementing the GroupedLinear operation for MoE models; performs multiple parallel linear transformations with support for FP8 quantization, MXFP8 weights, and Megatron-LM main_grad fusion
  • transformer_engine/pytorch/ops/basic/swiglu.py: New file with SwiGLU variants; moves SwiGLU and ClampedSwiGLU from activation.py and adds the new ScaledSwiGLU op with element-wise post-scaling; includes support for experimental GLU interleaving
  • transformer_engine/pytorch/module/_common.py: Improves noop_cat edge-case handling; adds a storage-size check before creating a strided view to prevent out-of-bounds access with manually configured views from split_quantize (see the sketch after this list)
  • tests/pytorch/test_fusible_ops.py: Adds comprehensive tests for GroupedLinear, ScaledSwiGLU, and interleaved GLU formats; includes edge-case testing with zero-sized splits and a combined grouped MLP block
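The _common.py entry above is easiest to picture with a simplified sketch of the no-op concat idea (this is not TE's actual implementation; the bounds check at the end stands in for the storage-size check the entry describes): if the inputs are already back-to-back views of one storage, return a single view instead of copying, and fall back to a real concat whenever the combined view would run past the end of the storage.

```python
import torch

def noop_cat_sketch(tensors):
    # Assumes all tensors share a dtype and are concatenated along dim 0.
    first = tensors[0]
    storage = first.untyped_storage()
    offset = first.storage_offset()
    total = 0
    for t in tensors:
        same_storage = t.untyped_storage().data_ptr() == storage.data_ptr()
        back_to_back = t.storage_offset() == offset + total
        if not (same_storage and t.is_contiguous() and back_to_back):
            return torch.cat(tensors, dim=0)  # not a fusible layout: real concat
        total += t.numel()
    # Edge case: a manually configured view may live in a storage that is too
    # small for the fused view, so check the bounds before constructing it.
    if (offset + total) * first.element_size() > storage.nbytes():
        return torch.cat(tensors, dim=0)
    out_shape = (sum(t.shape[0] for t in tensors),) + tuple(first.shape[1:])
    return torch.as_strided(first, out_shape, first.stride(), offset)
```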

Sequence Diagram

sequenceDiagram
    participant User
    participant GroupedLinear
    participant tex
    participant general_grouped_gemm
    participant ScaledSwiGLU
    
    Note over User,ScaledSwiGLU: MoE Grouped MLP Forward Pass
    
    User->>GroupedLinear: forward(input, split_sizes)
    GroupedLinear->>GroupedLinear: validate split_sizes length
    GroupedLinear->>tex: split_quantize(input, split_sizes, quantizers)
    tex-->>GroupedLinear: quantized input splits [x0, x1, ..., xN]
    GroupedLinear->>general_grouped_gemm: gemm(weights, input_splits, m_splits, bias)
    general_grouped_gemm-->>GroupedLinear: concatenated output
    GroupedLinear-->>User: output tensor
    
    User->>ScaledSwiGLU: forward(input, scales)
    ScaledSwiGLU->>ScaledSwiGLU: remove GLU interleaving if needed
    ScaledSwiGLU->>tex: swiglu(swiglu_in, quantizer)
    tex-->>ScaledSwiGLU: swiglu_out
    ScaledSwiGLU->>ScaledSwiGLU: multiply by scales.unsqueeze(-1)
    ScaledSwiGLU-->>User: scaled output
    
    Note over User,ScaledSwiGLU: Backward Pass
    
    User->>ScaledSwiGLU: backward(grad_output)
    ScaledSwiGLU->>tex: swiglu(saved_input, None)
    tex-->>ScaledSwiGLU: swiglu_out (recomputed)
    ScaledSwiGLU->>ScaledSwiGLU: grad_scales = vecdot(swiglu_out, grad_output)
    ScaledSwiGLU->>tex: dswiglu(grad_output * scales, saved_input, None)
    tex-->>ScaledSwiGLU: grad_input
    ScaledSwiGLU-->>User: grad_input, grad_scales
    
    User->>GroupedLinear: backward(grad_output)
    GroupedLinear->>tex: split_quantize(grad_output, split_sizes, quantizers)
    tex-->>GroupedLinear: grad output splits
    GroupedLinear->>general_grouped_gemm: dgrad gemm (weights, grad_splits)
    general_grouped_gemm-->>GroupedLinear: grad_input
    GroupedLinear->>general_grouped_gemm: wgrad gemm (input_splits, grad_splits)
    general_grouped_gemm-->>GroupedLinear: grad_weights
    GroupedLinear-->>User: grad_input, grad_weights, grad_biases
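The backward half of the diagram is the subtle part, so here is a hedged PyTorch sketch of the same recompute pattern (assumptions: a plain-PyTorch swiglu_ref and autograd stand in for tex.swiglu/tex.dswiglu, and the gate/value split convention may differ from TE's): only the SwiGLU input and the scales are saved, the activation output is recomputed in backward to form grad_scales, and the SwiGLU gradient is taken with the scale-weighted output gradient.

```python
import torch

def swiglu_ref(x):
    gate, value = x.chunk(2, dim=-1)  # split convention is an assumption
    return torch.nn.functional.silu(gate) * value

class ScaledSwiGLUSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, swiglu_in, scales):
        ctx.save_for_backward(swiglu_in, scales)  # no activation output is cached
        return swiglu_ref(swiglu_in) * scales.unsqueeze(-1)

    @staticmethod
    def backward(ctx, grad_output):
        swiglu_in, scales = ctx.saved_tensors
        with torch.enable_grad():
            swiglu_in = swiglu_in.detach().requires_grad_(True)
            swiglu_out = swiglu_ref(swiglu_in)  # recompute instead of caching
        # grad wrt scales: per-row dot product of activation output and grad_output
        grad_scales = torch.linalg.vecdot(swiglu_out, grad_output, dim=-1)
        # grad wrt input: SwiGLU backward applied to the scale-weighted grad_output
        (grad_input,) = torch.autograd.grad(
            swiglu_out, swiglu_in, grad_output * scales.unsqueeze(-1)
        )
        return grad_input, grad_scales
```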


Comment on lines +340 to +341
swiglu_out = tex.swiglu(swiglu_in, None)
out = swiglu_out * scales.unsqueeze(-1)
Member

Considering it is implemented with 2 kernels anyway, what is the benefit of having this operation here? I would prefer to have the ScaleWithExtraInput basic op or something like that instead.

Collaborator Author

This was the approach I used in my initial implementation (#2605), but it's not compatible with the fused GEMM + SwiGLU kernel (https://github.com/NVIDIA/cudnn-frontend/blob/main/python/cudnn/grouped_gemm/grouped_gemm_swiglu/api.py). If we have a standalone scale op, then we need to cache its input for the backward pass. However, the fused kernel assumes you are doing activation recompute, so it only outputs the SwiGLU input and the scaled output. Rather than intertwining the implementations of the SwiGLU and scale ops to support activation recompute, I just implemented a new op that does it explicitly.
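For contrast with the explanation above, here is a minimal sketch (hypothetical naming, plain PyTorch) of the standalone scale op the reviewer suggested: its backward needs the SwiGLU output to form grad_scales, so that activation output has to be cached, whereas the fused GEMM + SwiGLU kernel never hands that tensor back for saving (it only returns the SwiGLU input for recompute).

```python
import torch

class ScaleWithExtraInputSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, swiglu_out, scales):
        ctx.save_for_backward(swiglu_out, scales)  # must keep the activation output alive
        return swiglu_out * scales.unsqueeze(-1)

    @staticmethod
    def backward(ctx, grad_output):
        swiglu_out, scales = ctx.saved_tensors
        grad_swiglu_out = grad_output * scales.unsqueeze(-1)
        # requires the cached swiglu_out, which the fused kernel does not provide
        grad_scales = torch.linalg.vecdot(swiglu_out, grad_output, dim=-1)
        return grad_swiglu_out, grad_scales
```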

Review suggestion from @ptrendx.

Contributor

@greptile-apps greptile-apps bot left a comment

7 files reviewed, no comments

@timmoon10
Collaborator Author

/te-ci pytorch

@timmoon10 merged commit 3774aa3 into NVIDIA:main on Feb 12, 2026
19 of 24 checks passed

Labels

enhancement New feature or request


Development

Successfully merging this pull request may close these issues.

[PyTorch] Support grouped linear op in te.Sequential

2 participants