# Mixtral MoE with Transformer Engine

## Step 1: Wrap MoE Layers with TE Modules

This notebook demonstrates wrapping Mixtral's MoE FFN layers with Transformer Engine's `GroupedLinear` for efficient expert processing.

Reference: `src/transformers/models/mixtral/modular_mixtral.py`

In [None]:
import inspect
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformer_engine.pytorch as te
from transformer_engine.pytorch import GroupedLinear

class TEMixtralSparseMoeBlock(nn.Module):
    """
    Transformer Engine optimized MoE block using GroupedLinear for parallel expert processing.

    Key improvements:
    1. Use te.GroupedLinear to run every expert's projection in one grouped call
    2. Use te.moe_permute/unpermute for efficient token routing

    Routing follows the HF Mixtral block: softmax over the router logits, top-k
    selection, then renormalisation so each token's selected weights sum to 1.

    Args:
        config: object exposing ``hidden_size``, ``intermediate_size``,
            ``num_local_experts`` and ``num_experts_per_tok``.
    """
    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size
        self.num_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok

        # Keep HuggingFace router (not in critical path for performance)
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

        # Replace individual expert layers with GroupedLinear.
        # For SwiGLU: w1 (gate) and w3 (up) are fused into one projection, then w2 (down).

        # w1 and w3 combined (gate_proj + up_proj); forward() splits the output with
        # chunk(2), so the first ffn_dim output features are gate, the rest are up.
        self.experts_gate_up = GroupedLinear(
            num_gemms=self.num_experts,
            in_features=self.hidden_dim,
            out_features=2 * self.ffn_dim,  # 2x for gate and up proj combined
            bias=False,
            device='cuda'
        )

        # w2 (down_proj)
        self.experts_down = GroupedLinear(
            num_gemms=self.num_experts,
            in_features=self.ffn_dim,
            out_features=self.hidden_dim,
            bias=False,
            device='cuda'
        )

        # Extra kwargs so moe_permute/moe_unpermute read `selected_experts` as
        # per-token expert indices ([num_tokens, top_k]) rather than a dense mask.
        self._permute_kwargs = self._index_map_kwargs()

    @staticmethod
    def _index_map_kwargs() -> dict:
        """Return the kwargs that select index-style routing maps in TE's permute ops.

        Newer Transformer Engine releases add a ``map_type`` argument whose default
        ("mask") expects a [num_tokens, num_experts] mask. Older releases only accept
        [num_tokens, top_k] indices and have no such argument.
        NOTE(review): confirm against the installed TE version.
        """
        if "map_type" in inspect.signature(te.moe_permute).parameters:
            return {"map_type": "index"}
        return {}

    @staticmethod
    def _count_tokens_per_expert(selected_experts: torch.Tensor, num_experts: int) -> list:
        """Return the number of (token, top-k slot) routing entries per expert.

        moe_permute emits one row per routing entry, grouped by expert index, so these
        counts are the per-expert row splits GroupedLinear needs. They always sum to
        ``selected_experts.numel()``. A single ``tolist()`` means one host sync.
        """
        counts = torch.bincount(selected_experts.reshape(-1), minlength=num_experts)
        return counts.tolist()

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: [batch_size, sequence_length, hidden_dim]

        Returns:
            final_hidden_states: [batch_size, sequence_length, hidden_dim]
            router_logits: [batch_size * sequence_length, num_experts]
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states_flat = hidden_states.view(-1, hidden_dim)  # [num_tokens, hidden_dim]
        num_tokens = hidden_states_flat.shape[0]

        # Router: softmax in fp32, keep the top-k experts, renormalise their weights.
        router_logits = self.gate(hidden_states_flat)
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # Weights stay fp32: float32 is the dtype TE recommends for unpermute merge
        # probabilities, so no extra rounding to the activation dtype is introduced.

        # Permute tokens so that rows routed to the same expert are contiguous.
        permuted_tokens, row_id_map = te.moe_permute(
            hidden_states_flat,
            selected_experts.to(torch.int32),
            num_out_tokens=-1,  # -1: no tokens dropped (documented default sentinel)
            max_token_num=num_tokens,
            **self._permute_kwargs,
        )

        # Rows per expert in permuted order; must sum to permuted_tokens.shape[0].
        m_splits = self._count_tokens_per_expert(selected_experts, self.num_experts)

        # Gate and Up projection (combined) for all experts at once
        intermediate = self.experts_gate_up(permuted_tokens, m_splits=m_splits)

        # Apply SwiGLU activation: silu(gate) * up
        gate, up = intermediate.chunk(2, dim=-1)
        intermediate_act = F.silu(gate) * up

        # Down projection
        expert_outputs = self.experts_down(intermediate_act, m_splits=m_splits)

        # Unpermute back to token order, merging the top-k outputs with routing weights.
        # The weights are passed positionally because the parameter was renamed across
        # TE versions (probs -> merging_probs).
        final_hidden_states = te.moe_unpermute(
            expert_outputs,
            row_id_map,
            routing_weights,
            **self._permute_kwargs,
        )

        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits

### Test the Implementation

In [None]:
# Create a mock config for testing
class MixtralConfig:
    hidden_size = 4096
    intermediate_size = 14336
    num_local_experts = 8
    num_experts_per_tok = 2

config = MixtralConfig()

# Initialize TE-optimized MoE block on the GPU in bf16. The cast matters: the plain
# nn.Linear router would otherwise keep fp32 weights, and F.linear rejects a bf16
# input against fp32 weights.
te_moe_block = TEMixtralSparseMoeBlock(config).to(device='cuda', dtype=torch.bfloat16)

# Test with sample input
batch_size, seq_len = 2, 16
hidden_states = torch.randn(batch_size, seq_len, config.hidden_size, device='cuda', dtype=torch.bfloat16)

# Forward pass
with torch.no_grad():
    output, router_logits = te_moe_block(hidden_states)

# Verify the block's contract instead of only printing it.
assert output.shape == hidden_states.shape
assert router_logits.shape == (batch_size * seq_len, config.num_local_experts)
assert output.dtype == hidden_states.dtype
assert torch.isfinite(output).all()

print(f"Input shape: {hidden_states.shape}")
print(f"Output shape: {output.shape}")
print(f"Router logits shape: {router_logits.shape}")
print(f"Output dtype: {output.dtype}")
print("✓ TE-optimized MoE block working correctly!")

### Next: Weight Mapping and Integration

To integrate with HuggingFace Mixtral models, you need to:

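1. Map weights from HF `MixtralSparseMoeBlock` to `TEMixtralSparseMoeBlock`:
   - Copy the router `gate`.
   - For each expert, concatenate the `w1` (gate) and `w3` (up) weights along the output dimension into `experts_gate_up`. Put `w1` first, because `forward` splits the fused output with `chunk(2)` into gate, then up.
   - Copy each expert's `w2` into `experts_down`.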
2. Use monkey-patching to replace HF layers during model loading
3. Implement weight loading from HF checkpoints