### Toy MoE (Mixture-of-Experts) Transformer

By Claude

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoELayer(nn.Module):
    """Sparse mixture-of-experts feed-forward layer with top-k token routing.

    A bias-free linear router scores every token against every expert. Each
    token is sent to its ``num_experts_per_tok`` highest-probability experts
    (small GELU MLPs) and the expert outputs are summed, weighted by the
    router probabilities.

    Args:
        input_dim: Size of the last dimension of the input (and the output).
        hidden_dim: Hidden width of each expert MLP.
        num_experts: Number of expert MLPs.
        num_experts_per_tok: How many experts process each token (top-k).
        router_aux_loss_coef: Multiplier applied to the auxiliary
            load-balancing loss.
        norm_topk_prob: If True, renormalise the selected top-k probabilities
            so they sum to 1 for every token.

    Raises:
        ValueError: If ``num_experts_per_tok`` is not in ``[1, num_experts]``.
    """

    def __init__(self, 
                 input_dim: int,
                 hidden_dim: int,
                 num_experts: int = 8,
                 num_experts_per_tok: int = 2,
                 router_aux_loss_coef: float = 0.01,
                 norm_topk_prob: bool = False):
        super().__init__()
        # Fail early with a clear message instead of a cryptic topk error
        # (or a silently all-zero output) at forward time.
        if not 1 <= num_experts_per_tok <= num_experts:
            raise ValueError(
                f"num_experts_per_tok must be in [1, {num_experts}], "
                f"got {num_experts_per_tok}"
            )

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.router_aux_loss_coef = router_aux_loss_coef
        self.norm_topk_prob = norm_topk_prob

        # Router network - maps input to expert selection logits
        self.router = nn.Linear(input_dim, num_experts, bias=False)

        # Create experts - each is a simple MLP
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, input_dim)
            ) for _ in range(num_experts)
        ])

    def _load_balancing_loss(self,
                             flat_probs: torch.Tensor,
                             flat_indices: torch.Tensor) -> torch.Tensor:
        """Return the scaled load-balancing loss as a 0-dim tensor.

        Computes ``coef * num_experts * sum_i(f_i * P_i)`` where ``f_i`` is the
        fraction of (token, slot) assignments dispatched to expert ``i`` and
        ``P_i`` is the mean router probability of expert ``i``. Gradients flow
        through ``P_i`` only; ``f_i`` is a count.

        Args:
            flat_probs: Router probabilities, ``[num_tokens, num_experts]``.
            flat_indices: Selected expert ids, ``[num_tokens, top_k]``.
        """
        if flat_probs.shape[0] == 0:
            # No tokens: avoid the NaN a mean over an empty tensor would give.
            return flat_probs.new_zeros(())

        assignment = F.one_hot(flat_indices, num_classes=self.num_experts)
        dispatch_fraction = assignment.to(flat_probs.dtype).mean(dim=(0, 1))
        mean_prob = flat_probs.mean(dim=0)
        balance = self.num_experts * torch.sum(dispatch_fraction * mean_prob)
        return self.router_aux_loss_coef * balance

    def forward(self, x: torch.Tensor):
        """Route every token through its top-k experts.

        Args:
            x: Tensor whose last dimension is ``input_dim``; any number of
                leading dimensions is accepted (e.g. ``[batch, seq, dim]``).

        Returns:
            A tuple ``(outputs, aux_loss)``. ``outputs`` has the same shape as
            ``x``. ``aux_loss`` is the scaled load-balancing loss (0-dim
            tensor) in training mode and a zero tensor of shape ``(1,)`` in
            eval mode.
        """
        top_k = self.num_experts_per_tok
        router_logits = self.router(x)  # [..., num_experts]

        # Get routing probabilities
        router_probs = F.softmax(router_logits, dim=-1)

        # Select top-k experts per token
        topk_probs, topk_indices = torch.topk(router_probs, top_k, dim=-1)

        if self.norm_topk_prob:
            topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)

        # Work on a flat [num_tokens, ...] view of everything.
        flat_x = x.reshape(-1, self.input_dim)
        flat_indices = topk_indices.reshape(-1, top_k)
        flat_weights = topk_probs.reshape(-1, top_k)

        if self.training:
            aux_loss = self._load_balancing_loss(
                router_probs.reshape(-1, self.num_experts), flat_indices
            )
        else:
            aux_loss = torch.zeros(1, device=x.device)

        # Accumulate into a freshly allocated flat buffer. Accumulating via
        # `zeros_like(x).reshape(...)` is unsafe: for a non-contiguous `x`
        # the reshape yields a copy and every update would be lost.
        flat_out = x.new_zeros(flat_x.shape)

        # Run each expert once over all tokens that selected it in any slot.
        # top-k indices are distinct per token, so a token appears at most
        # once per expert.
        for expert_id, expert in enumerate(self.experts):
            token_idx, slot_idx = torch.nonzero(
                flat_indices == expert_id, as_tuple=True
            )
            if token_idx.numel() == 0:
                continue

            expert_out = expert(flat_x[token_idx])
            gate = flat_weights[token_idx, slot_idx].unsqueeze(-1)
            flat_out.index_add_(
                0, token_idx, (expert_out * gate).to(flat_out.dtype)
            )

        return flat_out.reshape(x.shape), aux_loss

class ToyMoETransformer(nn.Module):
    """Small transformer stack whose feed-forward sublayers are MoE layers.

    Token ids are embedded, scaled by ``sqrt(d_model)``, given positional
    encodings, passed through ``num_layers`` ``TransformerMoELayer`` blocks,
    layer-normalised and projected to vocabulary logits.

    Args:
        vocab_size: Number of rows in the embedding table / output logits.
        d_model: Model width.
        nhead: Attention heads per layer.
        num_layers: Number of stacked MoE transformer layers.
        num_experts: Experts per MoE layer.
        num_experts_per_tok: Experts selected for each token.
        expert_hidden_multiplier: Expert hidden width as a multiple of
            ``d_model``.
        dropout: Dropout probability passed to the positional encoder and to
            every layer.
    """

    def __init__(self,
                 vocab_size: int,
                 d_model: int = 256,
                 nhead: int = 4,
                 num_layers: int = 2,
                 num_experts: int = 8,
                 num_experts_per_tok: int = 2,
                 expert_hidden_multiplier: int = 4,
                 dropout: float = 0.1):
        super().__init__()

        self.d_model = d_model
        self.vocab_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)

        # Every layer in the stack is built from the same configuration.
        layer_config = dict(
            d_model=d_model,
            nhead=nhead,
            num_experts=num_experts,
            num_experts_per_tok=num_experts_per_tok,
            expert_hidden_dim=d_model * expert_hidden_multiplier,
            dropout=dropout,
        )
        stack = []
        for _ in range(num_layers):
            stack.append(TransformerMoELayer(**layer_config))
        self.layers = nn.ModuleList(stack)

        self.final_norm = nn.LayerNorm(d_model)
        self.output_proj = nn.Linear(d_model, vocab_size)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Return ``(logits, router_aux_loss)`` for a batch of token ids.

        Args:
            src: Integer token ids (presumably ``[seq_len, batch]``, since the
                attention layers are sequence-first by default - confirm
                against callers).
            src_mask: Optional attention mask forwarded to every layer.
            src_key_padding_mask: Optional key padding mask forwarded to every
                layer.

        Returns:
            The vocabulary logits and the sum of the per-layer router losses
            (the integer ``0`` when the model has no layers).
        """
        embed_scale = math.sqrt(self.d_model)
        hidden = self.pos_encoder(self.vocab_embedding(src) * embed_scale)

        per_layer_losses = []
        for block in self.layers:
            hidden, block_loss = block(
                hidden,
                src_mask=src_mask,
                src_key_padding_mask=src_key_padding_mask,
            )
            per_layer_losses.append(block_loss)
        router_aux_loss = sum(per_layer_losses)

        logits = self.output_proj(self.final_norm(hidden))
        return logits, router_aux_loss

class TransformerMoELayer(nn.Module):
    """Pre-norm transformer block with an MoE layer as its feed-forward part.

    Both sublayers (self-attention, then ``MoELayer``) are applied to a
    layer-normalised copy of the input and added back through a shared
    dropout module.

    Args:
        d_model: Model width.
        nhead: Number of attention heads.
        num_experts: Experts in the MoE sublayer.
        num_experts_per_tok: Experts selected for each token.
        expert_hidden_dim: Hidden width of each expert MLP.
        dropout: Dropout probability for attention weights and both residual
            branches.
    """

    def __init__(self, 
                 d_model: int,
                 nhead: int,
                 num_experts: int,
                 num_experts_per_tok: int,
                 expert_hidden_dim: int,
                 dropout: float = 0.1):
        super().__init__()

        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        moe_config = {
            "input_dim": d_model,
            "hidden_dim": expert_hidden_dim,
            "num_experts": num_experts,
            "num_experts_per_tok": num_experts_per_tok,
        }
        self.moe = MoELayer(**moe_config)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, src_mask=None, src_key_padding_mask=None):
        """Apply attention and MoE sublayers; return ``(x, router_loss)``.

        Args:
            x: Input activations.
            src_mask: Optional mask passed to attention as ``attn_mask``.
            src_key_padding_mask: Optional mask passed to attention as
                ``key_padding_mask``.
        """
        # Self attention block (attention weights are discarded)
        normed = self.norm1(x)
        attn_out, _ = self.self_attn(
            normed, normed, normed,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )
        x = x + self.dropout(attn_out)

        # MoE block
        moe_out, router_loss = self.moe(self.norm2(x))
        return x + self.dropout(moe_out), router_loss

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding for sequence-first inputs.

    Even feature indices hold ``sin(pos * w_i)`` and odd indices hold
    ``cos(pos * w_i)`` with ``w_i = 10000 ** (-2i / d_model)``. The table is
    stored as the non-trainable buffer ``pe`` of shape
    ``[max_len, 1, d_model]``.

    Args:
        d_model: Feature dimension; both even and odd values are supported.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Longest sequence length the table covers.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine columns,
        # so drop the surplus frequency instead of failing on a shape mismatch.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add positional encodings to ``x`` and apply dropout.

        Args:
            x: Tensor indexed as ``[seq_len, batch, d_model]``; the first
                dimension selects the position.

        Raises:
            ValueError: If ``x`` is longer than the precomputed table.
        """
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            # Without this check the addition fails with an opaque broadcast
            # error (or silently broadcasts when max_len == 1).
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(0)}"
            )
        x = x + self.pe[:seq_len]
        return self.dropout(x)

# Example usage:
if __name__ == "__main__":
    # Sizes for a quick smoke test of the model
    vocab_size = 1000
    seq_len, batch_size = 10, 32

    # Build a toy model
    model = ToyMoETransformer(
        vocab_size=vocab_size,
        d_model=256,
        nhead=4,
        num_layers=2,
        num_experts=8,
        num_experts_per_tok=2,
    )

    # Random token ids laid out sequence-first: [seq_len, batch_size]
    src = torch.randint(0, vocab_size, (seq_len, batch_size))

    # Single forward pass; report logits shape and the summed router loss
    output, router_loss = model(src)
    print(f"Output shape: {output.shape}")  # [10, 32, 1000]
    print(f"Router loss: {router_loss}")