In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Tuple, Optional, List
import numpy as n

In [None]:
# SwiGLU expert: the feed-forward network used for every routed and shared expert

class SwiGLUExpert(nn.Module):
    """Single feed-forward expert built around a SwiGLU gate.

    The forward pass computes

        SwiGLU(x) = (SiLU(x W_gate) ⊙ (x W_up)) W_down

    where ⊙ is the element-wise product, and then applies dropout. The gate
    and up projections are bias-free; the down projection keeps the default
    ``nn.Linear`` bias. The output has the same last dimension as the input.
    """

    def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.0):
        """
        Args:
            input_dim: Size of the last dimension of the input (and output).
            hidden_dim: Width of the gated hidden representation.
            dropout: Dropout probability applied to the expert output.
        """
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # Two parallel projections into the hidden space (gate and value paths)
        self.gate_proj = nn.Linear(input_dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(input_dim, hidden_dim, bias=False)
        # Projection back to the input width
        self.down_proj = nn.Linear(hidden_dim, input_dim)
        self.dropout = nn.Dropout(dropout)

        self._init_weights()

    def _init_weights(self):
        """Overwrite the three projection weight matrices with Xavier-uniform values."""
        for projection in (self.gate_proj, self.up_proj, self.down_proj):
            nn.init.xavier_uniform_(projection.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transform to ``x`` (last dim = ``input_dim``)."""
        # SiLU-activated gate modulates the linear "up" path element-wise
        gated_hidden = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.dropout(self.down_proj(gated_hidden))





In [None]:
# router implementation

class TopKRouter(nn.Module):
    """
    Router that selects top-K experts for each token.

    A single linear layer scores every expert. The scores are converted to
    probabilities with a softmax over the expert dimension, the ``top_k``
    largest probabilities are kept, and the kept values are renormalised so
    they sum to 1 for each token.
    """

    def __init__(
        self,
        input_dim: int,
        num_experts: int,
        top_k: int = 6,
        use_bias: bool = False
    ):
        """
        Args:
            input_dim: Input dimension
            num_experts: Total number of experts
            top_k: Number of experts to select per token
            use_bias: Whether the scoring layer has a bias term

        Raises:
            ValueError: If ``top_k`` is not in the range ``[1, num_experts]``.
        """
        super().__init__()

        # Fail at construction time with a readable message instead of letting
        # torch.topk raise an opaque error on the first forward pass.
        if not 1 <= top_k <= num_experts:
            raise ValueError(
                f"top_k must be between 1 and num_experts ({num_experts}), got {top_k}"
            )

        self.input_dim = input_dim
        self.num_experts = num_experts
        self.top_k = top_k

        # Router linear layer: one logit per expert
        self.gate = nn.Linear(input_dim, num_experts, bias=use_bias)

        nn.init.normal_(self.gate.weight, std=0.02)
        if use_bias:
            nn.init.zeros_(self.gate.bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        training: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Route tokens to experts.

        Args:
            hidden_states: [..., input_dim], typically
                [batch_size, seq_len, input_dim]. Leading dimensions are
                carried through unchanged.
            training: Currently unused; accepted so existing callers that pass
                it keep working.

        Returns:
            expert_indices: [..., top_k] indices of the selected experts
            expert_weights: [..., top_k] renormalised probabilities (sum to 1)
            router_logits: [..., num_experts] raw scores before the softmax
        """
        # Compute router logits
        router_logits = self.gate(hidden_states)

        # Probabilities over all experts
        router_probs = F.softmax(router_logits, dim=-1)

        # Keep the K most probable experts per token
        expert_weights, expert_indices = torch.topk(
            router_probs,
            self.top_k,
            dim=-1
        )

        # Renormalise over the kept experts only
        expert_weights = expert_weights / expert_weights.sum(dim=-1, keepdim=True)

        return expert_indices, expert_weights, router_logits


In [None]:
# ============================================================================
# 3. DEEPSEEK-V2 MOE LAYER
# ============================================================================


class DeepSeekV2MoE(nn.Module):
    """
    DeepSeek-V2 style MoE layer with:
    - Fine-grained routed experts (64 by default)
    - Shared experts applied to every token (2 by default)
    - Top-K routing (K=6 by default)
    - Auxiliary losses for load balancing

    The layer output is the sum of every shared expert's output plus, for each
    token, the router-weighted outputs of the experts selected for that token.
    """

    def __init__(
        self,
        input_dim: int = 5120,
        num_routed_experts: int = 64,
        num_shared_experts: int = 2,
        expert_hidden_dim: int = 1536,
        top_k: int = 6,
        dropout: float = 0.0,
        aux_loss_alpha: float = 0.01,
        aux_loss_beta: float = 0.01
    ):
        """
        Args:
            input_dim: Model dimension
            num_routed_experts: Number of routed experts
            num_shared_experts: Number of shared (always-active) experts
            expert_hidden_dim: Hidden dimension for each expert
            top_k: Number of experts to activate per token
            dropout: Dropout probability
            aux_loss_alpha: Weight for load balancing loss
            aux_loss_beta: Weight for importance loss
        """
        super().__init__()

        self.input_dim = input_dim
        self.num_routed_experts = num_routed_experts
        self.num_shared_experts = num_shared_experts
        self.expert_hidden_dim = expert_hidden_dim
        self.top_k = top_k
        self.aux_loss_alpha = aux_loss_alpha
        self.aux_loss_beta = aux_loss_beta

        # Router for routed experts
        self.router = TopKRouter(
            input_dim=input_dim,
            num_experts=num_routed_experts,
            top_k=top_k,
            use_bias=False
        )

        # Routed experts
        self.routed_experts = nn.ModuleList([
            SwiGLUExpert(input_dim, expert_hidden_dim, dropout)
            for _ in range(num_routed_experts)
        ])

        # Shared experts (always active)
        self.shared_experts = nn.ModuleList([
            SwiGLUExpert(input_dim, expert_hidden_dim, dropout)
            for _ in range(num_shared_experts)
        ])

    def forward(
        self,
        hidden_states: torch.Tensor,
        return_aux_loss: bool = True
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass through MoE layer.

        Args:
            hidden_states: [batch_size, seq_len, input_dim]; does not need to
                be contiguous.
            return_aux_loss: Whether to compute auxiliary losses

        Returns:
            output: [batch_size, seq_len, input_dim]
            aux_loss: Scalar auxiliary loss when ``return_aux_loss`` is true
                and the module is in training mode, otherwise ``None``.
        """
        batch_size, seq_len, input_dim = hidden_states.size()

        # 1. Shared experts: every token goes through all of them
        shared_output = torch.zeros_like(hidden_states)
        for expert in self.shared_experts:
            shared_output = shared_output + expert(hidden_states)

        # 2. Routed experts
        expert_indices, expert_weights, router_logits = self.router(hidden_states)

        # reshape (not view) so non-contiguous inputs, e.g. transposed
        # tensors, are flattened instead of raising.
        flat_hidden = hidden_states.reshape(-1, input_dim)       # [batch*seq, input_dim]
        flat_indices = expert_indices.reshape(-1, self.top_k)    # [batch*seq, top_k]
        flat_weights = expert_weights.reshape(-1, self.top_k)    # [batch*seq, top_k]

        routed_output = self._run_routed_experts(flat_hidden, flat_indices, flat_weights)
        routed_output = routed_output.reshape(batch_size, seq_len, input_dim)

        # 3. Combine shared and routed outputs
        output = shared_output + routed_output

        # 4. Auxiliary loss (training only)
        aux_loss = None
        if return_aux_loss and self.training:
            aux_loss = self._compute_auxiliary_loss(router_logits, expert_indices)

        return output, aux_loss

    def _run_routed_experts(
        self,
        flat_hidden: torch.Tensor,
        flat_indices: torch.Tensor,
        flat_weights: torch.Tensor
    ) -> torch.Tensor:
        """Run each routed expert on the tokens assigned to it and sum the weighted results.

        Args:
            flat_hidden: [num_tokens, input_dim]
            flat_indices: [num_tokens, top_k] selected expert ids per token
            flat_weights: [num_tokens, top_k] routing weight for each selection

        Returns:
            [num_tokens, input_dim] weighted sum of the selected experts' outputs.
        """
        routed_output = torch.zeros_like(flat_hidden)

        for expert_idx in range(self.num_routed_experts):
            # (token, slot) pairs whose selected expert is this one
            token_idx, slot_idx = (flat_indices == expert_idx).nonzero(as_tuple=True)
            if token_idx.numel() == 0:
                continue

            expert_output = self.routed_experts[expert_idx](flat_hidden[token_idx])
            weights = flat_weights[token_idx, slot_idx].unsqueeze(-1)

            # index_add_ accumulates correctly even if a token index repeats
            routed_output.index_add_(
                0,
                token_idx,
                (expert_output * weights).to(routed_output.dtype)
            )

        return routed_output

    def _compute_auxiliary_loss(
        self,
        router_logits: torch.Tensor,
        expert_indices: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute auxiliary losses for load balancing.

        Two terms are combined:

        * balance loss ``N * sum_i f_i * P_i``, where ``f_i`` is the fraction
          of routing assignments that went to expert ``i`` and ``P_i`` is the
          mean router probability of expert ``i``. ``f_i`` comes from hard
          assignments and carries no gradient; the gradient reaches the router
          through ``P_i``.
        * importance loss ``N * sum_i P_i ** 2``.

        Both terms equal 1 when routing is perfectly uniform.

        Args:
            router_logits: [batch, seq, num_experts]
            expert_indices: [batch, seq, top_k]

        Returns:
            aux_loss: ``aux_loss_alpha * balance + aux_loss_beta * importance``
        """
        batch_size, seq_len, num_experts = router_logits.size()
        num_tokens = batch_size * seq_len

        # P_i: average gate probability per expert (differentiable)
        router_probs = F.softmax(router_logits, dim=-1)
        # max(..., 1) keeps a zero-token batch from dividing by zero
        importance = router_probs.sum(dim=[0, 1]) / max(num_tokens, 1)  # [num_experts]

        # f_i: share of the routing assignments received by each expert
        expert_counts = torch.bincount(
            expert_indices.reshape(-1),
            minlength=num_experts
        ).to(router_probs.dtype)
        expert_fractions = expert_counts / max(num_tokens * self.top_k, 1)

        # Pairing the hard fractions with the soft probabilities is what lets
        # this term produce a gradient for the router.
        balance_loss = num_experts * (expert_fractions * importance).sum()

        # Importance loss: penalises concentrated average probabilities
        importance_loss = num_experts * (importance ** 2).sum()

        return (
            self.aux_loss_alpha * balance_loss +
            self.aux_loss_beta * importance_loss
        )



In [None]:

# ============================================================================
# 4. DEEPSEEK-V3 MOE LAYER (AUXILIARY-LOSS-FREE)
# ============================================================================

class DeepSeekV3MoE(nn.Module):
    """
    DeepSeek-V3 style MoE layer with:
    - More routed experts (256 by default)
    - Single shared expert
    - More experts per token (K=8 by default)
    - No auxiliary loss returned

    The router is built with a bias term. This class does not implement any
    explicit balancing update; the bias is an ordinary parameter of the
    router's linear layer.
    """

    def __init__(
        self,
        input_dim: int = 7168,
        num_routed_experts: int = 256,
        expert_hidden_dim: int = 2048,
        top_k: int = 8,
        dropout: float = 0.0
    ):
        """
        Args:
            input_dim: Model dimension
            num_routed_experts: Number of routed experts
            expert_hidden_dim: Hidden dimension for each expert
            top_k: Number of experts to activate per token
            dropout: Dropout probability
        """
        super().__init__()

        self.input_dim = input_dim
        self.num_routed_experts = num_routed_experts
        self.expert_hidden_dim = expert_hidden_dim
        self.top_k = top_k

        # Router with bias
        self.router = TopKRouter(
            input_dim=input_dim,
            num_experts=num_routed_experts,
            top_k=top_k,
            use_bias=True
        )

        # Routed experts
        self.routed_experts = nn.ModuleList([
            SwiGLUExpert(input_dim, expert_hidden_dim, dropout)
            for _ in range(num_routed_experts)
        ])

        # Single shared expert
        self.shared_expert = SwiGLUExpert(input_dim, expert_hidden_dim, dropout)

    def forward(
        self,
        hidden_states: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass; returns only the layer output (no auxiliary loss).

        Args:
            hidden_states: [batch_size, seq_len, input_dim]; does not need to
                be contiguous.

        Returns:
            output: [batch_size, seq_len, input_dim]
        """
        batch_size, seq_len, input_dim = hidden_states.size()

        # 1. Shared expert sees every token
        shared_output = self.shared_expert(hidden_states)

        # 2. Routed experts
        expert_indices, expert_weights, _ = self.router(hidden_states)

        # reshape (not view) so non-contiguous inputs, e.g. transposed
        # tensors, are flattened instead of raising.
        flat_hidden = hidden_states.reshape(-1, input_dim)
        flat_indices = expert_indices.reshape(-1, self.top_k)
        flat_weights = expert_weights.reshape(-1, self.top_k)

        routed_output = self._run_routed_experts(flat_hidden, flat_indices, flat_weights)
        routed_output = routed_output.reshape(batch_size, seq_len, input_dim)

        # 3. Combine
        return shared_output + routed_output

    def _run_routed_experts(
        self,
        flat_hidden: torch.Tensor,
        flat_indices: torch.Tensor,
        flat_weights: torch.Tensor
    ) -> torch.Tensor:
        """Run each routed expert on the tokens assigned to it and sum the weighted results.

        Args:
            flat_hidden: [num_tokens, input_dim]
            flat_indices: [num_tokens, top_k] selected expert ids per token
            flat_weights: [num_tokens, top_k] routing weight for each selection

        Returns:
            [num_tokens, input_dim] weighted sum of the selected experts' outputs.
        """
        routed_output = torch.zeros_like(flat_hidden)

        for expert_idx in range(self.num_routed_experts):
            # (token, slot) pairs whose selected expert is this one
            token_idx, slot_idx = (flat_indices == expert_idx).nonzero(as_tuple=True)
            if token_idx.numel() == 0:
                continue

            expert_output = self.routed_experts[expert_idx](flat_hidden[token_idx])
            weights = flat_weights[token_idx, slot_idx].unsqueeze(-1)

            # index_add_ accumulates correctly even if a token index repeats
            routed_output.index_add_(
                0,
                token_idx,
                (expert_output * weights).to(routed_output.dtype)
            )

        return routed_output



In [None]:
# ============================================================================
# 5. COMPLETE TRANSFORMER BLOCK WITH MOE
# ============================================================================

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned per-feature scale."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        """
        Args:
            hidden_size: Size of the last (normalised) dimension.
            eps: Constant added to the mean square before the reciprocal square root.
        """
        super().__init__()
        self.variance_epsilon = eps
        # Scale starts at 1 so the layer initially only normalises
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        """Normalise ``hidden_states`` by its RMS along the last axis, then scale by ``weight``."""
        original_dtype = hidden_states.dtype
        # Statistics are computed in float32 whatever the incoming precision
        values = hidden_states.to(torch.float32)
        mean_square = values.pow(2).mean(-1, keepdim=True)
        normalised = values * torch.rsqrt(mean_square + self.variance_epsilon)
        # Cast back to the caller's dtype before applying the learned scale
        return self.weight * normalised.to(original_dtype)


class DeepSeekTransformerBlock(nn.Module):
    """
    Pre-norm transformer block: self-attention followed by an MoE feed-forward,
    each wrapped in a residual connection.

    Attention is a standard ``nn.MultiheadAttention`` used as a simplified
    stand-in for MLA.
    """

    def __init__(
        self,
        d_model: int = 5120,
        num_attention_heads: int = 128,
        moe_config: dict = None,
        dropout: float = 0.0
    ):
        """
        Args:
            d_model: Model dimension
            num_attention_heads: Number of attention heads
            moe_config: MoE settings. Must contain 'type',
                'num_routed_experts', 'expert_hidden_dim' and 'top_k'; the
                'v2' type additionally needs 'num_shared_experts'. Any type
                other than 'v2' builds the V3 layer. ``None`` selects a
                default V2 configuration.
            dropout: Dropout probability (attention and experts)
        """
        super().__init__()

        # One RMSNorm in front of each sub-layer
        self.input_layernorm = RMSNorm(d_model)
        self.post_attention_layernorm = RMSNorm(d_model)

        # Standard multi-head attention in place of MLA
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=num_attention_heads,
            dropout=dropout,
            batch_first=True
        )

        # Fall back to the default V2 settings when nothing was supplied
        if moe_config is None:
            moe_config = {
                'type': 'v2',
                'num_routed_experts': 64,
                'num_shared_experts': 2,
                'expert_hidden_dim': 1536,
                'top_k': 6
            }

        # Feed-forward sub-layer: everything that is not 'v2' is treated as V3
        if moe_config['type'] != 'v2':
            self.mlp = DeepSeekV3MoE(
                input_dim=d_model,
                num_routed_experts=moe_config['num_routed_experts'],
                expert_hidden_dim=moe_config['expert_hidden_dim'],
                top_k=moe_config['top_k'],
                dropout=dropout
            )
        else:
            self.mlp = DeepSeekV2MoE(
                input_dim=d_model,
                num_routed_experts=moe_config['num_routed_experts'],
                num_shared_experts=moe_config['num_shared_experts'],
                expert_hidden_dim=moe_config['expert_hidden_dim'],
                top_k=moe_config['top_k'],
                dropout=dropout
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_aux_loss: bool = True
    ):
        """
        Run attention and the MoE feed-forward, each with a residual connection.

        Args:
            hidden_states: [batch, seq_len, d_model]
            attention_mask: Optional mask, passed to the attention module as ``attn_mask``
            return_aux_loss: Forwarded to the V2 MoE layer; ignored for V3

        Returns:
            output: [batch, seq_len, d_model]
            aux_loss: Value returned by the V2 MoE layer, or ``None`` for V3
        """
        # Attention sub-layer (pre-norm + residual)
        normed = self.input_layernorm(hidden_states)
        attn_output, _ = self.self_attn(
            normed,
            normed,
            normed,
            attn_mask=attention_mask,
            need_weights=False
        )
        after_attention = hidden_states + attn_output

        # MoE sub-layer (pre-norm + residual)
        normed = self.post_attention_layernorm(after_attention)
        if isinstance(self.mlp, DeepSeekV2MoE):
            mlp_output, aux_loss = self.mlp(normed, return_aux_loss)
        else:
            # The V3 layer returns only its output
            mlp_output, aux_loss = self.mlp(normed), None

        return after_attention + mlp_output, aux_loss



In [None]:

# ============================================================================
# 6. ANALYSIS AND BENCHMARKING
# ============================================================================

class MoEAnalyzer:
    """Analyze MoE behavior and statistics."""

    @staticmethod
    def analyze_expert_utilization(
        expert_indices: torch.Tensor,
        num_experts: int
    ) -> dict:
        """
        Analyze how experts are being utilized.

        Args:
            expert_indices: Integer tensor of selected expert ids, typically
                [batch, seq_len, top_k]. Ids outside ``[0, num_experts)`` are
                not counted for any expert but still count towards the total
                number of selections.
            num_experts: Total number of experts

        Returns:
            statistics: Dictionary with per-expert counts and fractions plus
                the mean, standard deviation, max, min and coefficient of
                variation of the fractions. For an empty ``expert_indices``
                all fractions are 0.0 and the coefficient of variation is
                0.0 (no NaN values).
        """
        # One host transfer and a single bincount instead of one comparison
        # (and one .item() sync) per expert.
        flat_indices = expert_indices.detach().reshape(-1).cpu().to(torch.long)
        in_range = (flat_indices >= 0) & (flat_indices < num_experts)
        expert_counts = torch.bincount(
            flat_indices[in_range],
            minlength=num_experts
        ).to(torch.get_default_dtype())

        total_selections = expert_indices.numel()
        if total_selections > 0:
            expert_fractions = expert_counts / total_selections
        else:
            # Nothing was routed: report zero usage rather than 0/0
            expert_fractions = torch.zeros_like(expert_counts)

        mean_fraction = expert_fractions.mean()
        # The sample standard deviation is undefined for a single value
        if num_experts > 1:
            std_fraction = expert_fractions.std()
        else:
            std_fraction = expert_fractions.new_zeros(())

        if mean_fraction.item() > 0:
            coefficient_of_variation = (std_fraction / mean_fraction).item()
        else:
            coefficient_of_variation = 0.0

        return {
            'expert_counts': expert_counts.tolist(),
            'expert_fractions': expert_fractions.tolist(),
            'mean_fraction': mean_fraction.item(),
            'std_fraction': std_fraction.item(),
            'max_fraction': expert_fractions.max().item(),
            'min_fraction': expert_fractions.min().item(),
            'coefficient_of_variation': coefficient_of_variation
        }

    @staticmethod
    def count_parameters(model: nn.Module) -> dict:
        """Count parameters in an MoE model, split by expert group.

        Looks for a ``routed_experts`` attribute and for either
        ``shared_experts`` (iterable of modules) or ``shared_expert`` (single
        module) directly on ``model``; a missing attribute counts as 0.

        Returns:
            Dictionary with 'total', 'routed_experts', 'shared_experts' and
            'other' (total minus the two expert groups).
        """
        def _num_params(module: nn.Module) -> int:
            return sum(p.numel() for p in module.parameters())

        total = _num_params(model)

        routed = 0
        if hasattr(model, 'routed_experts'):
            routed = sum(_num_params(expert) for expert in model.routed_experts)

        shared = 0
        if hasattr(model, 'shared_experts'):
            shared = sum(_num_params(expert) for expert in model.shared_experts)
        elif hasattr(model, 'shared_expert'):
            shared = _num_params(model.shared_expert)

        return {
            'total': total,
            'routed_experts': routed,
            'shared_experts': shared,
            'other': total - routed - shared
        }


In [None]:

# ============================================================================
# 7. DEMONSTRATION EXAMPLES
# ============================================================================

def demo_v2_moe():
    """Build a small DeepSeek-V2 MoE layer, run one forward pass and print a summary."""
    rule = "="*80
    print(rule, "DEMO 1: DeepSeek-V2 MoE Layer", rule, sep="\n")

    d_model = 256  # Smaller for demo
    batch_size, seq_len = 2, 10

    # Reduced-size V2 layer so the demo runs quickly
    moe = DeepSeekV2MoE(
        input_dim=d_model,
        num_routed_experts=16,
        num_shared_experts=2,
        expert_hidden_dim=128,
        top_k=4,
    )

    # Random activations stand in for real hidden states
    hidden_states = torch.randn(batch_size, seq_len, d_model)
    output, aux_loss = moe(hidden_states)

    print(
        f"\nConfiguration:",
        f"  Input dim: {d_model}",
        f"  Routed experts: 16",
        f"  Shared experts: 2",
        f"  Top-K: 4",
        sep="\n",
    )

    print(
        f"\nShapes:",
        f"  Input: {hidden_states.shape}",
        f"  Output: {output.shape}",
        sep="\n",
    )

    print(f"\nAuxiliary Loss: {aux_loss.item():.6f}")

    # Parameter breakdown; 'other' is reported under the "Router" label
    param_counts = MoEAnalyzer.count_parameters(moe)
    print(
        f"\nParameters:",
        f"  Total: {param_counts['total']:,}",
        f"  Routed experts: {param_counts['routed_experts']:,}",
        f"  Shared experts: {param_counts['shared_experts']:,}",
        f"  Router: {param_counts['other']:,}",
        sep="\n",
    )


def demo_v3_moe():
    """Build a small DeepSeek-V3 MoE layer, run one forward pass and print a summary."""
    rule = "="*80
    print("\n" + rule)
    print("DEMO 2: DeepSeek-V3 MoE Layer (Auxiliary-Loss-Free)", rule, sep="\n")

    d_model = 256
    batch_size, seq_len = 2, 10

    # Reduced-size V3 layer so the demo runs quickly
    moe = DeepSeekV3MoE(
        input_dim=d_model,
        num_routed_experts=32,
        expert_hidden_dim=128,
        top_k=6,
    )

    # Random activations stand in for real hidden states
    hidden_states = torch.randn(batch_size, seq_len, d_model)

    # The V3 layer returns only the output tensor
    output = moe(hidden_states)

    print(
        f"\nConfiguration:",
        f"  Input dim: {d_model}",
        f"  Routed experts: 32",
        f"  Shared experts: 1",
        f"  Top-K: 6",
        sep="\n",
    )

    print(
        f"\nKey Innovation: No auxiliary loss needed!",
        f"  Intrinsic balancing through:",
        f"    - Router bias",
        f"    - More experts",
        f"    - Higher K",
        sep="\n",
    )

    param_counts = MoEAnalyzer.count_parameters(moe)
    print(f"\nParameters:", f"  Total: {param_counts['total']:,}", sep="\n")


def demo_expert_utilization():
    """Route random tokens through a V2 router and print how evenly the experts are used."""
    rule = "="*80
    print("\n" + rule)
    print("DEMO 3: Expert Utilization Analysis", rule, sep="\n")

    num_experts = 16
    d_model = 256
    batch_size, seq_len = 8, 100

    # Only the router of this layer is exercised below
    moe = DeepSeekV2MoE(
        input_dim=d_model,
        num_routed_experts=num_experts,
        num_shared_experts=2,
        expert_hidden_dim=128,
        top_k=4,
    )

    # Random activations stand in for real hidden states
    hidden_states = torch.randn(batch_size, seq_len, d_model)

    # Routing decisions for every token
    expert_indices, expert_weights, _ = moe.router(hidden_states)

    stats = MoEAnalyzer.analyze_expert_utilization(expert_indices, num_experts)

    print(
        f"\nExpert Utilization Statistics:",
        f"  Mean usage: {stats['mean_fraction']:.4f}",
        f"  Std deviation: {stats['std_fraction']:.4f}",
        f"  Coefficient of variation: {stats['coefficient_of_variation']:.4f}",
        f"  Max usage: {stats['max_fraction']:.4f}",
        f"  Min usage: {stats['min_fraction']:.4f}",
        sep="\n",
    )

    print(f"\nPer-Expert Usage:")
    for expert_id, fraction in enumerate(stats['expert_fractions']):
        # One block character per whole percentage point of usage
        bar = '█' * int(fraction * 100)
        print(f"  Expert {expert_id:2d}: {fraction:.4f} {bar}")


def demo_comparison():
    """Build one V2 and one V3 layer and print their sizes side by side."""
    rule = "="*80
    print("\n" + rule)
    print("DEMO 4: V2 vs V3 Comparison", rule, sep="\n")

    d_model = 512

    # V2 layer and its parameter counts
    v2_config = {
        'num_routed_experts': 64,
        'num_shared_experts': 2,
        'expert_hidden_dim': 256,
        'top_k': 6
    }
    v2_moe = DeepSeekV2MoE(input_dim=d_model, **v2_config)
    v2_params = MoEAnalyzer.count_parameters(v2_moe)

    # V3 layer and its parameter counts
    v3_config = {
        'num_routed_experts': 128,
        'expert_hidden_dim': 256,
        'top_k': 8
    }
    v3_moe = DeepSeekV3MoE(input_dim=d_model, **v3_config)
    v3_params = MoEAnalyzer.count_parameters(v3_moe)

    print(
        "\nDeepSeek-V2 MoE:",
        f"  Routed experts: {v2_config['num_routed_experts']}",
        f"  Shared experts: {v2_config['num_shared_experts']}",
        f"  Top-K: {v2_config['top_k']}",
        f"  Total params: {v2_params['total']:,}",
        f"  Has auxiliary loss: Yes",
        sep="\n",
    )

    print(
        "\nDeepSeek-V3 MoE:",
        f"  Routed experts: {v3_config['num_routed_experts']}",
        f"  Shared experts: 1",
        f"  Top-K: {v3_config['top_k']}",
        f"  Total params: {v3_params['total']:,}",
        f"  Has auxiliary loss: No (intrinsic balancing)",
        sep="\n",
    )

    # Estimate: (selected + shared experts) times the average routed-expert size
    v2_experts_per_token = v2_config['top_k'] + v2_config['num_shared_experts']
    v2_active = v2_experts_per_token * v2_params['routed_experts'] / v2_config['num_routed_experts']
    v3_experts_per_token = v3_config['top_k'] + 1
    v3_active = v3_experts_per_token * v3_params['routed_experts'] / v3_config['num_routed_experts']

    print(
        f"\nActive Parameters per Token:",
        f"  V2: ~{v2_active/1e6:.1f}M",
        f"  V3: ~{v3_active/1e6:.1f}M",
        sep="\n",
    )



In [None]:
# ============================================================================
# MAIN EXECUTION
# ============================================================================

if __name__ == "__main__":
    # Title banner
    print(
        "\n" + "="*80,
        "DEEPSEEK MIXTURE OF EXPERTS: COMPLETE IMPLEMENTATION",
        "="*80,
        sep="\n",
    )

    # Fixed seed so the demos are reproducible from a fresh kernel
    torch.manual_seed(42)

    # Demos run in order; each builds its own models and inputs
    demo_v2_moe()
    demo_v3_moe()
    demo_expert_utilization()
    demo_comparison()

    # Closing summary
    print(
        "\n" + "="*80,
        "All demonstrations completed successfully!",
        "="*80,
        sep="\n",
    )
    print(
        "\nKey Takeaways:",
        "1. Fine-grained experts (small, specialized)",
        "2. Shared experts (always active, common knowledge)",
        "3. V2: Auxiliary losses for balancing",
        "4. V3: Auxiliary-loss-free with intrinsic balancing",
        "5. Dramatic parameter efficiency gains",
        "="*80,
        sep="\n",
    )