# In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MOEConfig:
    """Plain container for the mixture-of-experts hyper-parameters.

    Attributes:
        hidden_dim: width of the token representations fed to the router and experts.
        expert_number: number of routed experts.
        top_k: number of experts the router selects per token.
        shared_expert_number: defaults to 2; not read by the layers in this
            chunk (presumably consumed by a shared-expert variant — verify).
    """

    def __init__(self, hidden_dim, expert_number, top_k, shared_expert_number=2):
        # Values are stored exactly as given; no validation happens here.
        self.hidden_dim, self.expert_number = hidden_dim, expert_number
        self.top_k, self.shared_expert_number = top_k, shared_expert_number

class BasicExpert(nn.Module):
    """Feed-forward expert: Linear -> GELU -> Dropout(0.1) -> Linear -> Dropout(0.1).

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.
        intermediate_dim: hidden width; any falsy value (``None`` or ``0``)
            falls back to ``4 * input_dim``.
    """

    def __init__(self, input_dim, output_dim, intermediate_dim=None):
        super().__init__()
        if not intermediate_dim:
            intermediate_dim = input_dim * 4

        # Keep this order: the indices inside `self.net` appear in state_dict keys.
        layers = [
            nn.Linear(input_dim, intermediate_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(intermediate_dim, output_dim),
            nn.Dropout(0.1),
        ]
        self.net = nn.Sequential(*layers)
        self._init_weights()

    def _init_weights(self):
        """Re-initialise every Linear: Xavier-uniform weights, N(0, 1e-6) biases."""
        linear_layers = (layer for layer in self.net if isinstance(layer, nn.Linear))
        for linear in linear_layers:
            nn.init.xavier_uniform_(linear.weight)
            nn.init.normal_(linear.bias, std=1e-6)

    def forward(self, x):
        """Apply the expert MLP along the last dimension of ``x``."""
        return self.net(x)

class MOERouter(nn.Module):
    """Top-k gating network for a mixture-of-experts layer.

    Args:
        config: object exposing ``hidden_dim``, ``expert_number`` and ``top_k``.
    """

    def __init__(self, config):
        super().__init__()
        self.gate = nn.Linear(config.hidden_dim, config.expert_number)
        self.expert_number = config.expert_number
        self.top_k = config.top_k
        self.layer_norm = nn.LayerNorm(config.hidden_dim, eps=1e-6)

    def forward(self, x):
        """Score every expert for every token and select the top-k.

        Args:
            x: tensor of shape (..., hidden_dim); ``SparseMOE`` passes
                (num_tokens, hidden_dim).

        Returns:
            Tuple ``(router_logits, router_weights, selected_expert_indices,
            expert_mask, router_probs)``:

            - router_logits: (..., expert_number) raw gate scores.
            - router_weights: (..., top_k) top-k probabilities renormalised to
              sum to 1, cast to ``x.dtype``.
            - selected_expert_indices: (..., top_k) chosen expert ids.
            - expert_mask: one-hot mask of shape (expert_number, top_k, ...);
              for 2-D input this is (expert_number, top_k, num_tokens).
            - router_probs: (..., expert_number) float32 softmax over experts.
        """
        x_norm = self.layer_norm(x)
        router_logits = self.gate(x_norm)
        # Normalise over the expert axis (the last one), in float32 regardless of x.dtype.
        # The original used dim=1, which only coincides with the expert axis for 2-D input.
        router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float)

        router_weights, selected_expert_indices = torch.topk(
            router_probs,
            self.top_k,
            dim=-1,
        )

        # Renormalise so the selected experts' weights sum to 1 for each token.
        router_weights = router_weights / router_weights.sum(dim=-1, keepdim=True)
        router_weights = router_weights.to(x.dtype)

        expert_mask = F.one_hot(
            selected_expert_indices,
            num_classes=self.expert_number,
        )
        # (..., top_k, E) -> (E, top_k, ...); for 2-D input this is exactly permute(2, 1, 0).
        ndim = expert_mask.dim()
        expert_mask = expert_mask.permute(ndim - 1, ndim - 2, *range(ndim - 2))

        return router_logits, router_weights, selected_expert_indices, expert_mask, router_probs

class SparseMOE(nn.Module):
    """Sparse mixture-of-experts layer with top-k routing and a learnable residual.

    Each token is processed by the ``config.top_k`` experts chosen by
    ``MOERouter``. The expert outputs are mixed with the renormalised router
    weights and added to ``residual_weight * x``.

    Args:
        config: object exposing ``hidden_dim``, ``expert_number`` and ``top_k``
            (e.g. ``MOEConfig``).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.top_k
        self.hidden_dim = config.hidden_dim
        self.expert_number = config.expert_number

        self.experts = nn.ModuleList([
            BasicExpert(config.hidden_dim, config.hidden_dim)
            for _ in range(config.expert_number)
        ])

        self.router = MOERouter(config)
        # Learnable scale on the skip connection, initialised to a plain residual (1.0).
        self.residual_weight = nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        """Route every token through its top-k experts.

        Args:
            x: tensor of shape (batch_size, seq_len, hidden_dim); it does not
                need to be contiguous.

        Returns:
            Tuple ``(hidden_states, router_logits, aux_loss)``:

            - hidden_states: same shape as ``x``.
            - router_logits: (batch_size, seq_len, expert_number).
            - aux_loss: scalar load-balancing loss.
        """
        batch_size, seq_len, hidden_dim = x.size()
        # reshape (not view): view raises on non-contiguous input, e.g. after a transpose.
        hidden_states = x.reshape(-1, hidden_dim)

        router_logits, router_weights, _, expert_masks, router_probs = self.router(hidden_states)

        final_hidden_states = torch.zeros_like(hidden_states)
        # Number of (token, top-k slot) pairs dispatched to each expert; carries no gradient.
        expert_usage = torch.zeros(self.expert_number, device=x.device)

        for expert_idx, expert_layer in enumerate(self.experts):
            # expert_masks[e] is (top_k, num_tokens): row = top-k slot, column = token.
            slot_idx, token_indices = torch.where(expert_masks[expert_idx])
            if token_indices.numel() == 0:
                continue

            expert_usage[expert_idx] = token_indices.numel()

            expert_output = expert_layer(hidden_states[token_indices])
            weights = router_weights[token_indices, slot_idx].unsqueeze(1)
            # index_add_ accumulates, so a token's contributions from several experts sum up.
            final_hidden_states.index_add_(0, token_indices, expert_output * weights)

        final_hidden_states = final_hidden_states + self.residual_weight * hidden_states
        final_hidden_states = final_hidden_states.reshape(batch_size, seq_len, hidden_dim)
        router_logits = router_logits.reshape(batch_size, seq_len, -1)

        # Compute the auxiliary load-balancing loss.
        aux_loss = self._load_balancing_loss(router_probs, expert_usage)

        return final_hidden_states, router_logits, aux_loss

    def _load_balancing_loss(self, router_probs, expert_usage):
        """Return ``expert_number * sum_i(f_i * P_i)``.

        ``f_i`` is expert ``i``'s share of the dispatched slots, taken from
        ``expert_usage`` (no gradient). ``P_i`` is the mean router probability
        for expert ``i``, and it carries the gradient.

        Returns 0 instead of NaN when there are no tokens or nothing was routed.
        """
        if router_probs.shape[0] == 0:
            # mean() over zero rows would be NaN.
            return router_probs.new_zeros(())
        # Usage values are integer counts, so clamp(min=1) only changes the 0/0 case.
        expert_prob = expert_usage / expert_usage.sum().clamp(min=1)
        router_prob = router_probs.mean(dim=0)
        return self.expert_number * (expert_prob * router_prob).sum()