In [1]:
# Notebook banner: prints the title of the technique explored in the cells below.
print("Soft Mixture of Experts")

Soft Mixture of Experts


In [13]:
import torch
import torch.nn as nn
from torch import Tensor

In [12]:
from dataclasses import dataclass
import math

In [11]:
@dataclass
class MoEConfig:
    """Hyper-parameter bundle for a mixture-of-experts layer.

    The declaration order below is part of the interface: the generated
    ``__init__`` accepts the fields positionally in exactly this order.
    Field descriptions follow the field names; confirm the precise semantics
    against the code that consumes the config.
    """

    # --- sizes ---------------------------------------------------------
    d_model: int = 512  # model / token-embedding width
    num_experts: int = 8  # number of experts in the layer
    expert_hidden_dim: int = 2048  # hidden width inside each expert

    # --- capacity and overflow handling ---------------------------------
    capacity_factor: float = 1.25  # multiplier used when sizing per-expert capacity
    min_capacity: int = 4  # intended lower bound on per-expert capacity
    drop_tokens: bool = True  # whether to drop overflow tokens

    # --- expert network options -----------------------------------------
    use_residual: bool = True  # presumably toggles a residual connection -- TODO confirm
    activation: str = "gelu"  # name of the activation function
    dropout: float = 0.0  # dropout probability
    bias: bool = True  # presumably whether linear layers carry a bias -- TODO confirm

    # --- losses -----------------------------------------------------------
    aux_loss_coef: float = 1e-2  # weight of the auxiliary (load-balancing) loss

In [4]:
def soft_moe_layer(X, phi, experts):
    """Apply a Soft Mixture-of-Experts layer to a sequence of tokens.

    Every expert processes ``p`` "slots"; each slot is a convex combination
    of all input tokens, and each output token is a convex combination of
    all ``n * p`` processed slots.

    Args:
        X: token matrix of shape ``[m, d]`` (m tokens, width d).
        phi: slot parameters of shape ``[d, n, p]`` (n experts, p slots each).
        experts: iterable of ``n`` callables; the i-th is applied to the
            ``[p, d]`` slot matrix of expert i and must return ``[p, d_out]``.

    Returns:
        Tensor of shape ``[m, d_out]`` (``d_out == d`` for shape-preserving
        experts).
    """
    logits = torch.einsum("md,dnp->mnp", X, phi)

    # Dispatch weights: for each slot, normalise over the m tokens.
    D = torch.softmax(logits, dim=0)

    # Combine weights: for each token, normalise jointly over ALL n*p slots.
    # A softmax over dim=1 alone would only normalise across experts, leaving
    # each token's weights summing to p instead of 1.
    C = torch.softmax(logits.flatten(start_dim=1), dim=-1).reshape(logits.shape)

    # Input slots: weighted averages of the tokens.
    Xs = torch.einsum("md,mnp->npd", X, D)

    # Each expert processes its own block of p slots.
    Ys = torch.stack([f(Xs[i]) for i, f in enumerate(experts)], dim=0)

    # Output tokens: weighted averages of the processed slots.
    Y = torch.einsum("npd,mnp->md", Ys, C)

    return Y

In [17]:
class Expert(nn.Module):
    """Feed-forward expert: Linear -> GELU -> Dropout -> Linear.

    Args:
        d_model: input and output feature width.
        d_hidden: width of the hidden layer.
        dropout: dropout probability applied after the activation.
    """

    def __init__(self, d_model, d_hidden, dropout: float = 0.0):
        super().__init__()

        # The two projections are created in network order so parameter
        # initialisation draws from the RNG in the same sequence.
        up_projection = nn.Linear(d_model, d_hidden)
        down_projection = nn.Linear(d_hidden, d_model)

        stages = [
            up_projection,
            nn.GELU(),
            nn.Dropout(dropout),
            down_projection,
        ]
        self.expert = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the expert network; the last dimension is preserved."""
        out = self.expert(x)
        return out

In [23]:
class Router(nn.Module):
    """Top-1 router with per-expert capacity and load-balancing loss.

    Each token is assigned to a single expert. In training mode the choice is
    made on Gumbel-perturbed logits; in eval mode it is a plain argmax.
    Every expert accepts at most ``capacity`` tokens; tokens beyond that are
    dropped in token order (earlier tokens win).

    Args:
        d_model: width of the incoming token features.
        num_experts: number of experts to route between.
        z_loss_coef: weight of the router z-loss term.
        aux_loss_coef: weight of the load-balancing term.
        k: stored on the instance but not used by the routing logic
            (routing is always top-1).
        capacity_factor: multiplier on the even share ``num_tokens / num_experts``.
        epsilon: lower clamp for the uniform samples used to build Gumbel noise.
        min_capacity: lower bound on the per-expert capacity.
        training: initial train/eval mode of the module (eval by default).
    """

    def __init__(self, d_model, num_experts, z_loss_coef: float, aux_loss_coef: float, k=2, capacity_factor=1.0, epsilon=1e-6, min_capacity: int = 4, training=False):
        super().__init__()

        self.d_model = d_model
        self.num_experts = num_experts
        self.k = k
        self.capacity_factor = capacity_factor
        self.epsilon = epsilon
        self.min_capacity = min_capacity
        self.z_loss_coef = z_loss_coef
        self.aux_loss_coef = aux_loss_coef

        self.gate = nn.Linear(d_model, num_experts, bias=False)
        # Initialise the weight tensor (not the module) with the in-place initialiser.
        nn.init.normal_(self.gate.weight, std=0.02)

        # Set the mode through the Module API, after the sub-module exists,
        # so the gate's flag stays consistent with the router's.
        self.train(bool(training))

    def _compute_capacity(self, num_tokens: int) -> int:
        """Return the number of slots per expert for ``num_tokens`` tokens.

        The capacity is the scaled even share, rounded up, and never smaller
        than ``min_capacity``.
        """
        even_share = math.ceil(self.capacity_factor * num_tokens / self.num_experts)
        return max(int(self.min_capacity), even_share)

    def forward(self, x: Tensor):
        """Route a batch of tokens.

        Args:
            x: tensor of shape ``[B, T, D]``.

        Returns:
            A tuple ``(dispatch_mask, combine_weights, aux_loss, probs)``:

            * ``dispatch_mask``: bool ``[B*T, E, capacity]``; True where a token
              occupies a slot of an expert.
            * ``combine_weights``: ``[B*T, E, capacity]``; the router
              probability of the chosen expert at occupied slots, 0 elsewhere.
            * ``aux_loss``: scalar, ``aux_loss_coef * E * sum(f_i * p_i)
              + z_loss_coef * z_loss``.
            * ``probs``: detached router probabilities, shape ``[B, T, E]``.
        """
        B, T, D = x.shape

        num_tokens = B * T
        E = self.num_experts

        # Compute logits
        logits = self.gate(x)  # [B, T, E]

        # Calc probabilities
        probs = torch.softmax(logits, dim=-1)  # [B, T, E]

        # Routing works on a flat token axis so it lines up with the
        # [num_tokens, E, capacity] outputs.
        flat_logits = logits.reshape(num_tokens, E)
        flat_probs = probs.reshape(num_tokens, E)

        capacity = self._compute_capacity(num_tokens)

        # Compute z-loss
        z_loss = torch.logsumexp(flat_logits, dim=-1).square().mean()

        if self.training:
            # Gumbel noise; clamp away from 0 so log(u) stays finite.
            u = torch.rand_like(flat_logits).clamp(min=self.epsilon)
            g = -torch.log(-torch.log(u))

            # Select on the *noisy* logits, otherwise the noise has no effect.
            expert_index = (flat_logits + g).argmax(dim=-1)
        else:
            expert_index = flat_logits.argmax(dim=-1)

        # Probability the router assigned to each token's chosen expert.
        token_ids = torch.arange(num_tokens, device=x.device)
        routing_weights = flat_probs[token_ids, expert_index]  # [num_tokens]

        # Auxiliary loss
        tokens_per_expert = torch.zeros(E, device=x.device)
        tokens_per_expert.scatter_add_(0, expert_index, torch.ones(num_tokens, device=x.device))

        f_i = tokens_per_expert / num_tokens  # fraction of tokens dispatched to expert i
        p_i = flat_probs.mean(dim=0)  # fraction of the router probability allocated to expert i

        aux_loss = (self.aux_loss_coef * E * (f_i * p_i).sum() + z_loss * self.z_loss_coef)

        dispatch_mask = torch.zeros(num_tokens, E, capacity, dtype=torch.bool, device=x.device)
        combine_weights = torch.zeros(num_tokens, E, capacity, dtype=probs.dtype, device=x.device)

        for e in range(E):
            # torch.where(cond) returns a tuple of index tensors; take the first.
            token_indices = torch.where(expert_index == e)[0]

            if token_indices.numel() == 0:
                continue

            # Overflow tokens (beyond capacity) are dropped.
            token_indices = token_indices[:capacity]

            # The j-th kept token of this expert goes into slot j.
            slots = torch.arange(token_indices.numel(), device=x.device)

            dispatch_mask[token_indices, e, slots] = True
            combine_weights[token_indices, e, slots] = routing_weights[token_indices].to(combine_weights.dtype)

        return dispatch_mask, combine_weights, aux_loss, probs.detach()

In [None]:
class MixtureOfExperts(nn.Module):
    """Top-1 Mixture of Experts Layer (skeleton, not yet implemented).

    The constructor builds the expert networks and a layer norm. Dispatch,
    combine and the forward pass are still stubs and raise
    ``NotImplementedError`` so that an accidental call fails loudly instead
    of silently returning ``None``.

    Args:
        d_model: feature width of the tokens (expert input/output size).
        d_hidden: hidden width of each expert.
        n_experts: number of experts to create.
    """

    def __init__(self, d_model, d_hidden, n_experts):
        super().__init__()

        self.n_experts = n_experts
        self.experts = nn.ModuleList([
            Expert(d_model, d_hidden) for _ in range(n_experts)
        ])

        self.layer_norm = nn.LayerNorm(d_model)

    def _dispatch(self):
        """Gather the tokens routed to each expert (not implemented yet)."""
        # TODO, planned steps:
        #   expert input [C, D]
        #   Extract token_idx, slot_idx from dispatch mask
        #   Slice selected tokens from x
        #   add to expert input
        # A body made only of comments does not parse, hence the explicit raise.
        raise NotImplementedError("MixtureOfExperts._dispatch is not implemented yet")

    def _combine(self):
        """Scatter expert outputs back to token positions (not implemented yet)."""
        raise NotImplementedError("MixtureOfExperts._combine is not implemented yet")

    def forward(self, x):
        """Apply the mixture-of-experts layer to ``x`` (not implemented yet)."""
        raise NotImplementedError("MixtureOfExperts.forward is not implemented yet")
            
            
            