In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Tuple, Optional, Literal

- block_size:

    * In the DeepSeek-V3 reference code, block_size is the tile size used for block-wise FP8 quantization: scaling factors are stored per block of 128 elements. It is not a sequence-chunk length or an attention-window size.

    * None of the MoE modules in this notebook read block_size. It has no effect on how inputs are routed to experts and is kept only so the global settings match the reference implementation.

- rank and world_size:

These terms are related to distributed training:

    * world_size: The total number of processes or devices participating in distributed training. For example, if you are using 4 GPUs, world_size would be 4.
        * world_size is used to divide the total number of experts (n_routed_experts) among the available devices.

    * rank: The unique identifier for each process or device in the distributed system. For example, in a 4-GPU setup, rank would be 0, 1, 2, or 3.
        * rank determines which contiguous subset of experts is handled by the current device. For example, if world_size = 4 and rank = 0, the device handles the first quarter of the experts.
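
The partitioning described above can be sketched in a few lines. This is a minimal illustration, not part of the notebook's code: `local_expert_range` is a hypothetical helper name, and it mirrors the `experts_start_idx` / `experts_end_idx` arithmetic used later in `MoE.__init__`.

```python
def local_expert_range(n_routed_experts: int, world_size: int, rank: int) -> range:
    """Return the global expert indices owned by `rank` (contiguous, equal-sized shards)."""
    if n_routed_experts % world_size != 0:
        raise ValueError("n_routed_experts must be divisible by world_size")
    n_local = n_routed_experts // world_size
    start = rank * n_local
    return range(start, start + n_local)


# 64 routed experts spread over 4 devices: 16 experts per device.
for r in range(4):
    owned = local_expert_range(64, 4, r)
    print(f"rank {r}: experts {owned.start}..{owned.stop - 1}")
```

With `world_size = 1` and `rank = 0`, as configured in this notebook, the single device owns all 64 experts.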



   

1. Input Routing:

    - The input tensor is passed through the Gate module, which computes routing scores and selects the top-k experts for each input.

2. Expert Computation:

    - The input is routed to the selected experts, and each expert processes its assigned inputs independently.

3. Shared Experts:

    - In addition to the routed experts, shared experts are applied to all inputs. This ensures that common features are captured across all inputs.

4. Output Aggregation:

    - The outputs from the routed experts and shared experts are combined to produce the final output.

5. Distributed Training:

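    - If world_size > 1, each device only computes the contribution of its local experts, so the partial outputs must be summed across devices (the reference implementation uses dist.all_reduce). The forward method in this notebook runs with world_size = 1 and omits this step.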
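
To see why summing across devices gives the right answer, the sketch below simulates two ranks inside one process. Everything in it is a toy assumption made for illustration: the "experts" are scalar multipliers, the routing table is written by hand, and no real `torch.distributed` call is made.

```python
# Toy setup: 4 scalar "experts" (each multiplies its input by a constant), 2 ranks.
expert_scale = [1.0, 2.0, 3.0, 4.0]
world_size = 2
n_local = len(expert_scale) // world_size

# Hand-written routing for 3 scalar tokens: (expert index, weight) pairs, top-2 each.
tokens = [1.0, 10.0, 100.0]
routes = [
    [(0, 0.75), (3, 0.25)],
    [(1, 0.5), (2, 0.5)],
    [(2, 0.25), (3, 0.75)],
]


def partial_output(rank):
    """Contribution of the experts owned by `rank`; all other experts contribute zero."""
    start, end = rank * n_local, (rank + 1) * n_local
    out = [0.0] * len(tokens)
    for t, (x, chosen) in enumerate(zip(tokens, routes)):
        for expert_id, weight in chosen:
            if start <= expert_id < end:
                out[t] += weight * expert_scale[expert_id] * x
    return out


per_rank = [partial_output(r) for r in range(world_size)]
# Element-wise sum across ranks: what a sum all-reduce would leave on every rank.
combined = [sum(vals) for vals in zip(*per_rank)]

print(per_rank)   # [[0.75, 10.0, 0.0], [1.0, 15.0, 375.0]]
print(combined)   # [1.75, 25.0, 375.0]
```

Each rank produces zeros for tokens whose experts live elsewhere, so the element-wise sum reconstructs exactly the output a single device holding all experts would compute.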

In [2]:
world_size = 1
rank = 0
block_size = 128
gemm_impl: Literal["bf16", "fp8"] = "bf16"
attn_impl: Literal["naive", "absorb"] = "absorb"

The Mixture of Experts (MoE) architecture is a neural network design that uses multiple specialized sub-networks (called "experts") to process different parts of the input data. The key idea is to route each input to a small subset of experts, so the model's total capacity can grow without a matching growth in per-input compute. The sections below explain the DeepSeek MoE implementation in this notebook.

In [3]:
@dataclass
class ModelArgs:
    """
    Data class for defining model arguments and hyperparameters.

    Attributes:
        max_batch_size (int): Maximum batch size.
        max_seq_len (int): Maximum sequence length.
        dtype (Literal["bf16", "fp8"]): Data type for computations.
        vocab_size (int): Vocabulary size.
        dim (int): Model dimension.
        inter_dim (int): Intermediate dimension for MLP layers.
        moe_inter_dim (int): Intermediate dimension for MoE layers.
        n_layers (int): Number of transformer layers.
        n_dense_layers (int): Number of dense layers in the model.
        n_heads (int): Number of attention heads.
        n_routed_experts (int): Number of routed experts for MoE layers.
        n_shared_experts (int): Number of shared experts for MoE layers.
        n_activated_experts (int): Number of activated experts in MoE layers.
        n_expert_groups (int): Number of expert groups.
        n_limited_groups (int): Number of limited groups for MoE routing.
        score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing.
        route_scale (float): Scaling factor for routing scores.
        q_lora_rank (int): LoRA rank for query projections.
        kv_lora_rank (int): LoRA rank for key-value projections.
        qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings.
        qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings.
        v_head_dim (int): Dimension for value projections.
        original_seq_len (int): Original sequence length.
        rope_theta (float): Base for rotary positional encoding.
        rope_factor (float): Scaling factor for extended sequence lengths.
        beta_fast (int): Fast beta correction factor.
        beta_slow (int): Slow beta correction factor.
        mscale (float): Scaling factor for extended attention.
    """
    max_batch_size: int = 8
    max_seq_len: int = 4096 * 4
    dtype: Literal["bf16", "fp8"] = "bf16"
    vocab_size: int = 102400
    dim: int = 2048
    inter_dim: int = 10944
    moe_inter_dim: int = 1408
    n_layers: int = 27
    n_dense_layers: int = 1
    n_heads: int = 16
    # moe
    n_routed_experts: int = 64
    n_shared_experts: int = 2
    n_activated_experts: int = 6
    n_expert_groups: int = 1
    n_limited_groups: int = 1
    score_func: Literal["softmax", "sigmoid"] = "softmax"
    route_scale: float = 1.
    # mla
    q_lora_rank: int = 0
    kv_lora_rank: int = 512
    qk_nope_head_dim: int = 128
    qk_rope_head_dim: int = 64
    v_head_dim: int = 128
    # yarn
    original_seq_len: int = 4096
    rope_theta: float = 10000.0
    rope_factor: float = 40
    beta_fast: int = 32
    beta_slow: int = 1
    mscale: float = 1.
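
As a rough check on what these defaults imply, the following back-of-the-envelope sketch counts the feed-forward weights in one MoE layer. It assumes three weight matrices per expert (as in the Expert and MLP classes below) and ignores bias terms; the variable names are local to this example.

```python
dim, moe_inter_dim = 2048, 1408
n_routed_experts, n_shared_experts, n_activated_experts = 64, 2, 6

# Each expert holds three weight matrices: w1 and w3 (dim x inter) and w2 (inter x dim).
# Bias terms are ignored to keep the arithmetic simple.
params_per_expert = 3 * dim * moe_inter_dim
# The shared experts are one MLP whose hidden size is n_shared_experts * moe_inter_dim.
shared_params = 3 * dim * (n_shared_experts * moe_inter_dim)

total_params = n_routed_experts * params_per_expert + shared_params
active_params = n_activated_experts * params_per_expert + shared_params

print(f"per expert:      {params_per_expert:,}")                 # 8,650,752
print(f"total per layer: {total_params:,}")                      # 570,949,632
print(f"active per token:{active_params:,}")                     # 69,206,016
print(f"active fraction: {active_params / total_params:.3f}")    # 0.121
```

Under these assumptions, each token touches about 12% of the layer's expert weights, which is the efficiency argument behind activating only `n_activated_experts` of the routed experts.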

### Gate:

The Gate module is responsible for routing inputs to the appropriate experts. It computes scores for each expert and selects the top-k experts for each input.

Key attributes:

- dim: Input feature dimensionality.

- topk: Number of experts activated for each input.

- n_groups: Number of groups for routing.

- score_func: Scoring function (softmax or sigmoid) to compute routing weights.

- route_scale: Scaling factor for routing weights.

The forward method computes routing weights and indices for the selected experts.
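
The core of that computation, without groups or bias, can be sketched on toy tensors. The sizes and the random `gate_weight` below are assumptions chosen for illustration; they stand in for `Gate.weight` and are not the notebook's configuration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_tokens, dim, n_experts, topk = 4, 8, 6, 2

x = torch.randn(n_tokens, dim)
gate_weight = torch.randn(n_experts, dim)   # toy stand-in for Gate.weight

# One score per (token, expert); softmax makes each row sum to 1.
scores = F.linear(x, gate_weight).softmax(dim=-1)

# Keep the topk highest-scoring experts per token and look up their scores.
indices = scores.topk(topk, dim=-1)[1]
weights = scores.gather(1, indices)

print(indices.shape, weights.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
```

Each row of `indices` names the experts chosen for one token, and the matching row of `weights` holds the scores used to mix their outputs.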

**1.  self.topk = args.n_activated_experts**
 - Meaning:

    * self.topk specifies the number of experts that will be activated for each input. In other words, for every input, the gating mechanism selects the top-k experts (where k = n_activated_experts) to process that input.

 - Purpose:

   * Instead of sending every input to all experts (which would be computationally expensive), the MoE model only activates a small subset of experts for each input. This makes the model more efficient while still allowing it to leverage specialized experts.

 -  Example:

   * If n_activated_experts = 6, then for each input, the gating mechanism will select the 6 most relevant experts (based on the routing scores) to process that input.

**2.  self.n_groups = args.n_expert_groups**
 - Meaning:

   * self.n_groups defines the number of expert groups used in the routing mechanism. Experts can be divided into groups, and the gating mechanism can route inputs to specific groups of experts.

 - Purpose:

   * Grouping experts allows for more structured and efficient routing. For example, experts in the same group might specialize in similar types of inputs, and the gating mechanism can route inputs to the most relevant group(s).

 - Example:

   * If n_expert_groups = 4, the experts are divided into 4 groups. The gating mechanism will first decide which group(s) to route the input to, and then select the top-k experts within those groups.

**3.  self.topk_groups = args.n_limited_groups**
 - Meaning:

   * self.topk_groups specifies the number of expert groups that will be activated for each input. In other words, for each input, the gating mechanism will select the top-k groups (where k = n_limited_groups) to route the input to.

 - Purpose:

   * This parameter further refines the routing process by limiting the number of groups that can be activated for each input. This helps reduce computational overhead and ensures that inputs are routed to the most relevant groups of experts.

 - Example:

   * If n_limited_groups = 2, then for each input, the gating mechanism will select the 2 most relevant groups of experts and route the input to those groups. Within those groups, the top-k experts (as defined by n_activated_experts) will be activated.

#### How These Parameters Work Together
**1. Routing Process:**

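- For each input, the gating mechanism computes a routing score for every expert.

- If n_groups > 1, the experts are divided into equal-sized groups, and each group gets a score derived from its experts' scores: the maximum expert score in the group or, when the gate has a bias term, the sum of the group's two highest expert scores.

- The gating mechanism keeps the top-k groups (topk_groups), then selects the top-k experts (topk) from among the experts in those groups.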

**2. Example Workflow:**

- Suppose:

    * n_expert_groups = 4 (experts are divided into 4 groups).

    * n_limited_groups = 2 (for each input, the top 2 groups are selected).

    * n_activated_experts = 6 (for each input, the top 6 experts are activated).

- For a given input:

    * The gating mechanism computes routing scores for all experts and derives one score for each of the 4 groups.

    * It selects the top 2 groups with the highest scores.

    * Within those 2 groups, it selects the top 6 experts (in total) with the highest scores.

    * The input is routed to those 6 experts for processing.
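
A smaller version of this workflow can be traced by hand. The sketch below uses made-up scores for one token, 8 experts in 4 groups of 2, `topk_groups = 2`, and `topk = 3`; it follows the same steps as the Gate code for the case without a bias term.

```python
import torch

n_groups, topk_groups, topk = 4, 2, 3
# Made-up routing scores for one token over 8 experts (groups: [0,1] [2,3] [4,5] [6,7]).
scores = torch.tensor([[0.10, 0.05, 0.30, 0.02, 0.08, 0.25, 0.15, 0.05]])

grouped = scores.view(scores.size(0), n_groups, -1)        # (tokens, groups, experts_per_group)
group_scores = grouped.amax(dim=-1)                        # best expert score in each group
top_groups = group_scores.topk(topk_groups, dim=-1)[1]     # groups 1 and 2 win here

# True marks a group to drop; the selected groups are flipped to False.
drop = torch.ones(scores.size(0), n_groups, dtype=torch.bool).scatter_(1, top_groups, False)
masked = grouped.masked_fill(drop.unsqueeze(-1), float("-inf")).flatten(1)

indices = masked.topk(topk, dim=-1)[1]
weights = scores.gather(1, indices)

print(indices[0].tolist())                     # [2, 5, 4]
print(scores.topk(topk, dim=-1)[1][0].tolist())  # [2, 5, 6] without the group limit
```

Expert 6 has the third-highest score overall, but its group is not among the two selected groups, so expert 4 is chosen instead.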

#### Key Relationship
* n_groups defines how the experts are organized into groups before routing.

* topk_groups determines how many of these groups will be activated for each input.

* topk determines how many experts will be activated within the selected groups.

In [4]:
class Gate(nn.Module):
    """
    Gating mechanism for routing inputs in a mixture-of-experts (MoE) model.

    Attributes:
        dim (int): Dimensionality of input features.
        topk (int): Number of top experts activated for each input.
        n_groups (int): Number of groups for routing.
        topk_groups (int): Number of groups to route inputs to.
        score_func (str): Scoring function ('softmax' or 'sigmoid').
        route_scale (float): Scaling factor for routing weights.
        weight (torch.nn.Parameter): Learnable weights for the gate.
        bias (Optional[torch.nn.Parameter]): Optional bias term for the gate.
    """
    def __init__(self, args: ModelArgs):
        """
        Initializes the Gate module.

        Args:
            args (ModelArgs): Model arguments containing gating parameters.
        """
        super().__init__()
        self.dim = args.dim
        self.topk = args.n_activated_experts
        self.n_groups = args.n_expert_groups
        self.topk_groups = args.n_limited_groups
        self.score_func = args.score_func
        self.route_scale = args.route_scale
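        # torch.empty leaves the values uninitialized; they must be initialized or loaded before use.
        self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
        # The load-balancing bias is only created when dim == 7168; with the ModelArgs defaults (dim = 2048) it is None.
        self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None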

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for the gating mechanism.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Routing weights and selected expert indices.
        """
        # Token-to-expert affinity logits: u_t · e_i for every expert i
        scores = F.linear(x, self.weight)

        # s_{i,t} = softmax(logits) or s_{i,t} = sigmoid(logits)
        if self.score_func == "softmax":
            scores = scores.softmax(dim=-1, dtype=torch.float32)
        else:
            scores = scores.sigmoid()
        original_scores = scores

        # Auxiliary-Loss-Free Load Balancing
        if self.bias is not None:
            scores = scores + self.bias
        # Equation 14 of the DeepSeek-V3 paper
        if self.n_groups > 1:
            # Reshape the affinity scores to (tokens, groups, experts_per_group).
            scores = scores.view(x.size(0), self.n_groups, -1)
            # 1. Group scores
            if self.bias is None:
                # Maximum expert score within each group
                group_scores = scores.amax(dim=-1)
            else:
                # Sum of the two highest expert scores within each group
                group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
            # 2. Select the top-k groups
            indices = group_scores.topk(self.topk_groups, dim=-1)[1]
            # 3. Masking: mark unselected groups (True = drop) and set their scores to -inf
            #    so that the expert top-k below can never pick them.
            mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False)
            # 4. Flatten back to (tokens, experts). masked_fill is out-of-place so that
            #    original_scores, which may share memory with scores, is left untouched.
            scores = scores.masked_fill(mask.unsqueeze(-1), float("-inf")).flatten(1)

        indices = torch.topk(scores, self.topk, dim=-1)[1]

        # Weights come from the unbiased scores; the bias only affects which experts are selected.
        weights = original_scores.gather(1, indices)

        # Sigmoid scores are not normalized, so rescale the selected weights to sum to 1.
        # Softmax scores are used as-is.
        if self.score_func == "sigmoid":
            weights /= weights.sum(dim=-1, keepdim=True)

        # Last part of equation 12 in the DeepSeek-V3 paper
        weights *= self.route_scale
        
        return weights.type_as(x), indices
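
The sigmoid branch at the end of `forward` can be checked on a tiny example. The logits and the `route_scale` value below are arbitrary numbers chosen for the demonstration, not values from the notebook.

```python
import torch

route_scale = 2.5                               # arbitrary value chosen for the demo
logits = torch.tensor([[2.0, -1.0, 0.5, 0.0]])  # one token, four experts

scores = logits.sigmoid()                       # each score lies in (0, 1); rows do not sum to 1
indices = scores.topk(2, dim=-1)[1]
weights = scores.gather(1, indices)

weights = weights / weights.sum(dim=-1, keepdim=True)   # selected weights now sum to 1
weights = weights * route_scale                          # and then to route_scale

print(indices.tolist())   # [[0, 2]]
print(weights.sum().item())
```

After both steps the selected weights for a token sum to `route_scale` (up to floating-point rounding), regardless of how large the raw sigmoid scores were.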

### Expert:

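- The Expert module represents a single expert in the MoE model. It is a gated feed-forward block with three linear layers: w1 and w3 project the input from dim to inter_dim, and w2 projects the result back to dim.

- The forward method computes w2(silu(w1(x)) * w3(x)): the SiLU-activated w1 branch is multiplied element-wise with the w3 branch before the output projection.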
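
Written out with explicit matrices, the same computation looks like the sketch below. It is a simplified, bias-free illustration with toy sizes; the notebook's Expert uses nn.Linear layers, which include bias terms by default, and `gated_ffn` is a name introduced only for this example.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dim, inter_dim = 4, 6
w1 = torch.randn(inter_dim, dim)   # input -> hidden (activated branch)
w3 = torch.randn(inter_dim, dim)   # input -> hidden (linear branch)
w2 = torch.randn(dim, inter_dim)   # hidden -> output


def gated_ffn(x: torch.Tensor) -> torch.Tensor:
    """Bias-free version of w2(silu(w1(x)) * w3(x))."""
    activated = F.silu(x @ w1.T)   # silu(a) = a * sigmoid(a)
    linear = x @ w3.T
    return (activated * linear) @ w2.T


x = torch.randn(3, dim)
y = gated_ffn(x)
print(y.shape)  # torch.Size([3, 4])
```

The output has the same shape as the input, which is what lets the MoE layer add expert outputs back into a tensor shaped like `x`.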

In [5]:
class Expert(nn.Module):
    """
    Expert layer for Mixture-of-Experts (MoE) models.

    Attributes:
        w1 (nn.Module): Linear layer for input-to-hidden transformation.
        w2 (nn.Module): Linear layer for hidden-to-output transformation.
        w3 (nn.Module): Additional linear layer for feature transformation.
    """
    def __init__(self, dim: int, inter_dim: int):
        """
        Initializes the Expert layer.

        Args:
            dim (int): Input and output dimensionality.
            inter_dim (int): Hidden layer dimensionality.
        """
        super().__init__()
        self.w1 = nn.Linear(dim, inter_dim)
        self.w2 = nn.Linear(inter_dim, dim)
        self.w3 = nn.Linear(dim, inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the Expert layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after expert computation.
        """
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

In [6]:
class MLP(nn.Module):
    """
    Multi-Layer Perceptron (MLP) used as a feed-forward layer.

    Attributes:
        w1 (nn.Module): Linear layer for input-to-hidden transformation.
        w2 (nn.Module): Linear layer for hidden-to-output transformation.
        w3 (nn.Module): Additional linear layer for feature transformation.
    """
    def __init__(self, dim: int, inter_dim: int):
        """
        Initializes the MLP layer.

        Args:
            dim (int): Input and output dimensionality.
            inter_dim (int): Hidden layer dimensionality.
        """
        super().__init__()
        self.w1 = nn.Linear(dim, inter_dim)
        self.w2 = nn.Linear(inter_dim, dim)
        self.w3 = nn.Linear(dim, inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the MLP layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after MLP computation.
        """
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

### MoE (Mixture of Experts):

- The MoE module combines multiple experts and a gating mechanism to route inputs to the appropriate experts.

- Key attributes:

    * n_routed_experts: Total number of routed experts across all devices.

    * n_local_experts: Number of experts handled locally in distributed systems.

    * n_activated_experts: Number of experts activated for each input.

    * gate: Gating mechanism for routing.

    * experts: List with one slot per routed expert. Only the slots owned by the current rank hold an Expert module; the others are None.

    * shared_experts: Shared experts applied to all inputs.

The forward method routes inputs to experts, computes their outputs, and combines them with shared expert outputs.
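
The dispatch loop in that forward method can be checked against a straightforward per-token computation. The sketch below is a toy reconstruction under stated assumptions: each "expert" is a single random matrix, the routing scores are random, and there are no shared experts or multiple devices.

```python
import torch

torch.manual_seed(0)
n_tokens, dim, n_experts, topk = 5, 4, 3, 2
x = torch.randn(n_tokens, dim)
expert_mats = [torch.randn(dim, dim) for _ in range(n_experts)]  # toy experts: one matrix each

scores = torch.randn(n_tokens, n_experts).softmax(dim=-1)
weights, indices = scores.topk(topk, dim=-1)

# Sparse dispatch: visit each expert once and process only the tokens routed to it.
y = torch.zeros_like(x)
counts = torch.bincount(indices.flatten(), minlength=n_experts).tolist()
for i in range(n_experts):
    if counts[i] == 0:
        continue
    # idx: token rows that chose expert i; top: which top-k slot chose it.
    idx, top = torch.where(indices == i)
    y[idx] += (x[idx] @ expert_mats[i]) * weights[idx, top, None]

# Reference: loop over tokens and their selected experts one at a time.
reference = torch.zeros_like(x)
for t in range(n_tokens):
    for k in range(topk):
        e = indices[t, k].item()
        reference[t] += weights[t, k] * (x[t] @ expert_mats[e])

print(torch.allclose(y, reference, atol=1e-5))
```

Grouping tokens by expert gives the same result as the per-token loop while calling each expert only once per batch, which is why the MoE forward pass is organized around experts and not around tokens.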

In [None]:
class MoE(nn.Module):
    """
    Mixture-of-Experts (MoE) module.

    Attributes:
        dim (int): Dimensionality of input features.
        n_routed_experts (int): Total number of experts in the model.
        n_local_experts (int): Number of experts handled locally in distributed systems.
        n_activated_experts (int): Number of experts activated for each input.
        gate (nn.Module): Gating mechanism to route inputs to experts.
        experts (nn.ModuleList): List of expert modules.
        shared_experts (nn.Module): Shared experts applied to all inputs.
    """
    def __init__(self, args: ModelArgs):
        """
        Initializes the MoE module.

        Args:
            args (ModelArgs): Model arguments containing MoE parameters.
        """
        super().__init__()
        self.dim = args.dim
        assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})"
        self.n_routed_experts = args.n_routed_experts
        self.n_local_experts = args.n_routed_experts // world_size
        self.n_activated_experts = args.n_activated_experts
        self.experts_start_idx = rank * self.n_local_experts
        self.experts_end_idx = self.experts_start_idx + self.n_local_experts
        self.gate = Gate(args)
        self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None for i in range(self.n_routed_experts)])
        self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the MoE module.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after expert routing and computation.
        """
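        shape = x.size()
        # Flatten all leading dimensions so that each row is one token of size dim.
        x = x.view(-1, self.dim)
        weights, indices = self.gate(x)
        y = torch.zeros_like(x)
        # Number of assignments per expert; experts that received no tokens are skipped.
        counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
        for i in range(self.experts_start_idx, self.experts_end_idx):
            if counts[i] == 0:
                continue
            expert = self.experts[i]
            # idx: rows routed to expert i; top: which of the top-k slots selected it.
            idx, top = torch.where(indices == i)
            y[idx] += expert(x[idx]) * weights[idx, top, None]
        # Shared experts process every token.
        z = self.shared_experts(x)
        # With world_size > 1, y would first have to be summed across ranks (omitted here).

        return (y + z).view(shape)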