[PyTorch] Add ops for MoE grouped MLP #2664
Conversation
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Greptile Overview

Greptile Summary

This PR adds fusible operations for MoE (Mixture-of-Experts) grouped MLP blocks, specifically a grouped linear op and a ScaledSwiGLU op.

Key Changes:
Implementation Quality:
Confidence Score: 5/5
Important Files Changed
Sequence Diagram

sequenceDiagram
participant User
participant GroupedLinear
participant tex
participant general_grouped_gemm
participant ScaledSwiGLU
Note over User,ScaledSwiGLU: MoE Grouped MLP Forward Pass
User->>GroupedLinear: forward(input, split_sizes)
GroupedLinear->>GroupedLinear: validate split_sizes length
GroupedLinear->>tex: split_quantize(input, split_sizes, quantizers)
tex-->>GroupedLinear: quantized input splits [x0, x1, ..., xN]
GroupedLinear->>general_grouped_gemm: gemm(weights, input_splits, m_splits, bias)
general_grouped_gemm-->>GroupedLinear: concatenated output
GroupedLinear-->>User: output tensor
User->>ScaledSwiGLU: forward(input, scales)
ScaledSwiGLU->>ScaledSwiGLU: remove GLU interleaving if needed
ScaledSwiGLU->>tex: swiglu(swiglu_in, quantizer)
tex-->>ScaledSwiGLU: swiglu_out
ScaledSwiGLU->>ScaledSwiGLU: multiply by scales.unsqueeze(-1)
ScaledSwiGLU-->>User: scaled output
Note over User,ScaledSwiGLU: Backward Pass
User->>ScaledSwiGLU: backward(grad_output)
ScaledSwiGLU->>tex: swiglu(saved_input, None)
tex-->>ScaledSwiGLU: swiglu_out (recomputed)
ScaledSwiGLU->>ScaledSwiGLU: grad_scales = vecdot(swiglu_out, grad_output)
ScaledSwiGLU->>tex: dswiglu(grad_output * scales, saved_input, None)
tex-->>ScaledSwiGLU: grad_input
ScaledSwiGLU-->>User: grad_input, grad_scales
User->>GroupedLinear: backward(grad_output)
GroupedLinear->>tex: split_quantize(grad_output, split_sizes, quantizers)
tex-->>GroupedLinear: grad output splits
GroupedLinear->>general_grouped_gemm: dgrad gemm (weights, grad_splits)
general_grouped_gemm-->>GroupedLinear: grad_input
GroupedLinear->>general_grouped_gemm: wgrad gemm (input_splits, grad_splits)
general_grouped_gemm-->>GroupedLinear: grad_weights
GroupedLinear-->>User: grad_input, grad_weights, grad_biases
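
For orientation, here is a loop-based plain-PyTorch sketch of the math the diagram describes. This is illustrative only, not the Transformer Engine API: the function name `grouped_mlp_reference`, the weight-list arguments, the gate/up ordering of the GLU halves, and the per-token layout of `scales` are all assumptions.

```python
import torch
import torch.nn.functional as F

def grouped_mlp_reference(x, split_sizes, fc1_weights, fc2_weights, scales):
    """Loop-based reference for the flow above: grouped FC1 -> scaled SwiGLU -> grouped FC2.

    Assumed shapes: x is [num_tokens, hidden], scales is [num_tokens] (e.g. router
    probabilities), fc1_weights[i] is [2 * ffn_hidden, hidden], fc2_weights[i] is
    [hidden, ffn_hidden].
    """
    x_splits = torch.split(x, split_sizes, dim=0)
    scale_splits = torch.split(scales, split_sizes, dim=0)
    outs = []
    for x_i, s_i, w1, w2 in zip(x_splits, scale_splits, fc1_weights, fc2_weights):
        h = x_i @ w1.t()                   # per-expert FC1 (one member of the grouped GEMM)
        gate, up = h.chunk(2, dim=-1)      # GLU halves (ordering is an assumption)
        act = F.silu(gate) * up            # SwiGLU
        act = act * s_i.unsqueeze(-1)      # scale each token's activations
        outs.append(act @ w2.t())          # per-expert FC2
    return torch.cat(outs, dim=0)
```

In the actual ops, the per-expert GEMMs are dispatched through general_grouped_gemm and the inputs are quantized with tex.split_quantize, as shown in the diagram.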
swiglu_out = tex.swiglu(swiglu_in, None)
out = swiglu_out * scales.unsqueeze(-1)
Considering it is implemented with 2 kernels anyway, what is the benefit of having this operation here? I would prefer to have the ScaleWithExtraInput basic op or something like that instead.
This was the approach I used in my initial implementation (#2605), but it's not compatible with the fused GEMM + SwiGLU kernel (https://github.com/NVIDIA/cudnn-frontend/blob/main/python/cudnn/grouped_gemm/grouped_gemm_swiglu/api.py). If we have a standalone scale op, then we need to cache its input for the backward pass. However, the fused kernel assumes you are doing activation recompute, and it only outputs the SwiGLU input and the scale output. Rather than intertwining the SwiGLU and scale implementations to support activation recompute, I just implemented a new op that does it explicitly.
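
To make the recompute argument concrete, here is a minimal sketch in plain PyTorch. It is an assumption-laden illustration, not the actual implementation: the class name `ScaledSwiGLUSketch` is hypothetical, tex.swiglu/tex.dswiglu are replaced by explicit PyTorch math, and the GLU half ordering is assumed. It shows how a combined scaled-SwiGLU op can run its backward pass from only the SwiGLU input and the scales, which is what the fused kernel provides.

```python
import torch
import torch.nn.functional as F

class ScaledSwiGLUSketch(torch.autograd.Function):
    """Combined SwiGLU + scale with activation recompute in the backward pass."""

    @staticmethod
    def forward(ctx, swiglu_in, scales):
        gate, up = swiglu_in.chunk(2, dim=-1)     # GLU halves (ordering assumed)
        swiglu_out = F.silu(gate) * up
        # Save only the SwiGLU *input* and the scales; the SwiGLU output is
        # recomputed in backward, matching the activation-recompute assumption.
        ctx.save_for_backward(swiglu_in, scales)
        return swiglu_out * scales.unsqueeze(-1)

    @staticmethod
    def backward(ctx, grad_out):
        swiglu_in, scales = ctx.saved_tensors
        with torch.enable_grad():
            x = swiglu_in.detach().requires_grad_(True)
            gate, up = x.chunk(2, dim=-1)
            swiglu_out = F.silu(gate) * up        # recompute the activation
            # Gradient w.r.t. the SwiGLU input, with the upstream gradient
            # pre-scaled by the per-token scales.
            (grad_in,) = torch.autograd.grad(
                swiglu_out, x, grad_out * scales.unsqueeze(-1)
            )
        # Gradient w.r.t. the scales: dot product over the hidden dimension.
        grad_scales = torch.linalg.vecdot(swiglu_out, grad_out)
        return grad_in, grad_scales
```

A standalone scale op, by contrast, would have to save its own input, i.e. the SwiGLU output, which the fused GEMM + SwiGLU kernel never materializes.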
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Review suggestion from @ptrendx. Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
/te-ci pytorch
Description
This PR adds ops needed for the grouped MLP block in Mixture-of-Experts models. In particular, it adds a grouped linear op (similar to the GroupedLinear module) and a ScaledSwiGLU op. It is the same as #2622, but doesn't include the fused ops with experimental kernels. Closes #2560.

Type of change
Changes
noop_cat function

Checklist: