In [1]:
import os

# Restrict this process to physical GPUs 4-7 (visible to torch as cuda:0-3).
# CUDA reads this variable when it initialises, so it must be set before torch
# touches the GPU.
os.environ.update({"CUDA_VISIBLE_DEVICES": "4,5,6,7"})

In [4]:
import torch
from torch import nn
import torch.nn.functional as F
from torchao.prototype.moe_training.scaled_grouped_mm import _to_fp8_rowwise_then_scaled_grouped_mm
from torchao.prototype.moe_training import scaled_grouped_mm

def pad_for_alignment(grouped_inputs, experts_count, alignment=16):
    """Pad expert-grouped rows so every expert group length is a multiple of ``alignment``.

    Rows of ``grouped_inputs`` are expected to be ordered by expert: the first
    ``experts_count[0]`` rows belong to expert 0, the next ``experts_count[1]``
    rows to expert 1, and so on. Each group is copied into a zero-initialised
    buffer in which group ``e`` occupies ``experts_count[e]`` rounded up to a
    multiple of ``alignment`` rows; the extra rows stay zero.

    Args:
        grouped_inputs: tensor of shape ``(total_tokens, *feature_dims)`` whose
            rows are ordered by expert.
        experts_count: 1-D integer tensor with the number of rows per expert.
            Presumably sums to ``total_tokens``; not checked here, because
            checking would need a host read of a device value.
        alignment: positive row multiple each group is padded to.

    Returns:
        ``(padded_inputs, ctx)``. ``padded_inputs`` has shape
        ``(ctx['total_padded'], *feature_dims)`` and the dtype/device of
        ``grouped_inputs``. ``ctx`` is a dict reusable with ``pad_tensor`` /
        ``unpad_tensor``:

        - ``'cu_original'``: int32 inclusive cumulative counts, i.e. the end row
          of each group in the unpadded layout.
        - ``'cu_padded'``: int32 inclusive cumulative padded counts, i.e. the end
          row of each group in the padded layout.
        - ``'dest_indices'``: for every original row, its row in the padded buffer.
        - ``'total_padded'``: 0-d tensor holding the padded row count.
        - ``'total_tokens'``: Python int, the number of original rows.

    Raises:
        ValueError: if ``alignment`` is smaller than 1.
    """
    if alignment < 1:
        raise ValueError(f"alignment must be a positive integer, got {alignment}")

    # Round each group's row count up to the next multiple of `alignment`.
    experts_count_padded = ((experts_count + alignment - 1) // alignment) * alignment

    # Prefix sums with a leading 0, so cu[e] is the first row of expert e and
    # cu[e + 1] is one past its last row.
    zero = torch.zeros(1, dtype=torch.int32, device=experts_count.device)
    cu_original = torch.cat([zero, experts_count.cumsum(0).to(torch.int32)])
    cu_padded = torch.cat([zero, experts_count_padded.cumsum(0).to(torch.int32)])

    total_tokens = grouped_inputs.shape[0]
    token_indices = torch.arange(total_tokens, device=grouped_inputs.device)
    # Owning expert of each row = number of group end offsets <= row index.
    # right=True means empty groups (end == start) are stepped over.
    expert_ids = torch.searchsorted(cu_original[1:], token_indices, right=True)
    position_in_group = token_indices - cu_original[expert_ids]
    dest_indices = cu_padded[expert_ids] + position_in_group

    total_padded = cu_padded[-1]
    # int() reads the padded size back to the host, which synchronises with the
    # device when these tensors live on a GPU.
    padded_inputs = grouped_inputs.new_zeros(
        (int(total_padded), *grouped_inputs.shape[1:])
    )
    padded_inputs[dest_indices] = grouped_inputs

    ctx = {
        'cu_original': cu_original[1:],  # For unpadded LoRA ops
        'cu_padded': cu_padded[1:],
        'dest_indices': dest_indices,
        'total_padded': total_padded,
        'total_tokens': total_tokens,
    }
    return padded_inputs, ctx


def pad_tensor(tensor, ctx):
    """Scatter ``tensor`` into the padded layout described by ``ctx``.

    Args:
        tensor: rows in the original (unpadded) expert-grouped order, shape
            ``(num_rows, *feature_dims)`` with one row per entry of
            ``ctx['dest_indices']``.
        ctx: padding context as returned by ``pad_for_alignment``. Only
            ``'total_padded'`` (int or 0-d tensor) and ``'dest_indices'`` are read.

    Returns:
        A zero-filled tensor of shape ``(total_padded, *feature_dims)`` with the
        dtype/device of ``tensor``. Row ``i`` of ``tensor`` is written to row
        ``ctx['dest_indices'][i]``.
    """
    # int() accepts both a Python int and a 0-d tensor. For a GPU tensor it
    # reads the value back to the host, which is a device sync.
    total_padded = int(ctx['total_padded'])
    padded = tensor.new_zeros((total_padded, *tensor.shape[1:]))
    padded[ctx['dest_indices']] = tensor
    return padded


def unpad_tensor(padded_tensor, ctx):
    """Gather the real (non-padding) rows back out of a padded tensor.

    Row ``i`` of the result is row ``ctx['dest_indices'][i]`` of
    ``padded_tensor``, which restores the original expert-grouped row order and
    drops the padding rows. The gather uses on-device indices and does not read
    anything back to the host.
    """
    gather_rows = ctx['dest_indices']
    return padded_tensor[gather_rows, ...]

In [5]:
# --- Experiment configuration ---------------------------------------------
top_k = 4              # experts selected per token
num_experts = 32
norm_topk_prob = True  # renormalise the top-k routing weights so they sum to 1
alignment = 16         # row multiple each expert group is padded to
seed = 0               # fixed seed so the data, and the errors reported later, are reproducible

batch_size, sequence_length, hidden_dim = 32, 100, 1024
intermediate_dim = 512  # output features of each expert's projection

torch.manual_seed(seed)
hidden_states = torch.randn(
    batch_size, sequence_length, hidden_dim, dtype=torch.bfloat16, device='cuda'
)
inputs = hidden_states.view(-1, hidden_dim)  # (M, hidden_dim) token matrix
M = inputs.shape[0]
# Expert weights are stored as (num_experts, intermediate_dim, hidden_dim) and
# exposed through a transposed view of shape (num_experts, hidden_dim,
# intermediate_dim). NOTE(review): presumably this column-major layout is what
# the grouped-mm kernels expect for their right-hand operand; confirm against
# the torchao/torch docs.
w = torch.randn(
    num_experts, intermediate_dim, hidden_dim, dtype=torch.bfloat16, device='cuda'
).transpose(-1, -2)
# Router: one logit per expert for every token.
gate = nn.Linear(hidden_dim, num_experts, bias=False, device='cuda', dtype=torch.bfloat16)

In [12]:
# --- fp8 path: pad expert groups, run the fp8 rowwise scaled grouped GEMM ---
gate_proj_stacked_t = nn.Parameter(w.clone())

# Top-k routing. topk_weights mirrors an MoE layer's routing but is not used
# further in this cell.
router_logits = gate(inputs)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weights, topk_indices = torch.topk(routing_weights, top_k, dim=-1)
if norm_topk_prob:
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights = topk_weights.to(inputs.dtype)
# Sort the M * top_k flat (token, expert) assignments by expert so each
# expert's rows are contiguous; `// top_k` maps a flat assignment to its token.
sort_indices = topk_indices.view(-1).argsort()
sorted_pos = sort_indices // top_k
grouped_inputs = inputs[sorted_pos]

experts_count = topk_indices.view(-1).bincount(minlength=num_experts)

# Pad each expert group to the configured `alignment` from the setup cell.
padded_inputs, pad_ctx = pad_for_alignment(grouped_inputs, experts_count, alignment=alignment)
cu_experts_padded = pad_ctx['cu_padded']
cu_experts_original = pad_ctx['cu_original']

o = _to_fp8_rowwise_then_scaled_grouped_mm(
    padded_inputs,
    gate_proj_stacked_t,
    cu_experts_padded,
)

# Drop the padding rows. The gather uses on-device indices, so there is no
# CPU-GPU sync here.
o_unpadded = unpad_tensor(o, pad_ctx)
o_unpadded.backward(torch.ones_like(o_unpadded))

In [17]:
# --- bf16 reference: same routing, unpadded groups, torch._grouped_mm ---
gate_proj_stacked_t_ = nn.Parameter(w.clone())

router_logits = gate(inputs)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weights, topk_indices = torch.topk(routing_weights, top_k, dim=-1)
if norm_topk_prob:
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights = topk_weights.to(inputs.dtype)
# NOTE(review): comparing this output row-by-row with the fp8 cell assumes
# argsort returns the same row order in both cells. They use identical inputs,
# so this relies on the sort being deterministic.
sort_indices = topk_indices.view(-1).argsort()  # (M * topk,)
sorted_pos = sort_indices // top_k
grouped_inputs = inputs[sorted_pos]  # (M * topk, dim)

experts_count = topk_indices.view(-1).bincount(minlength=num_experts)
# Inclusive int32 end offset of each expert group, computed from this cell's
# own routing so the cell does not depend on variables from the fp8 cell.
cu_experts_count = experts_count.cumsum(dim=0).to(torch.int32)
o = torch._grouped_mm(
    grouped_inputs,
    gate_proj_stacked_t_,
    cu_experts_count,
)
o.backward(torch.ones_like(o))

In [18]:
# Largest absolute element-wise gap between the bf16 reference output `o` and
# the unpadded fp8 output, computed in bf16.
# TODO: judge it relative to the output magnitude, not as a raw number.
torch.max(torch.abs(o - o_unpadded))

tensor(6.2500, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MaxBackward1>)

In [19]:
# Largest absolute gap between the weight gradients of the bf16 reference path
# and the fp8 path (bf16 arithmetic).
torch.max(torch.abs(gate_proj_stacked_t_.grad - gate_proj_stacked_t.grad))

tensor(2.5000, device='cuda:0', dtype=torch.bfloat16)