<a href="https://colab.research.google.com/github/A7mita/SwitchTransformer/blob/main/Switch_Transformer_Layer_in_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# A simple Feed-Forward Network which will be our "expert"
class Expert(nn.Module):
    """
    Position-wise feed-forward "expert" used by the Switch layer.

    Computes ``fc2(relu(fc1(x)))``: the last dimension is expanded from
    ``d_model`` to ``d_ff``, passed through a ReLU, and projected back to
    ``d_model``.

    Args:
        d_model: Size of the last dimension of the input and of the output.
        d_ff: Width of the hidden layer between the two projections.
    """
    def __init__(self, d_model: int, d_ff: int) -> None:
        super().__init__()
        # Up-projection followed by down-projection; created in this order so
        # parameter names and random initialisation order stay fixed.
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the two-layer MLP to the last dimension of ``x``."""
        hidden = self.fc1(x)
        activated = F.relu(hidden)
        return self.fc2(activated)

class SwitchFeedForward(nn.Module):
    """
    The Switch Transformer layer (top-1 mixture of experts).

    Every token is scored by a linear router, sent to the single expert with
    the highest softmax probability, and the expert's output is scaled by that
    probability. Each expert accepts at most ``expert_capacity`` tokens per
    forward pass; with ``drop_tokens=True`` the overflow tokens are skipped and
    their output rows are all zeros. The layer also returns the auxiliary
    load-balancing loss computed from the routing decisions.

    Args:
        d_model: Size of the token embedding (last input dimension).
        d_ff: Hidden width passed to each expert.
        num_experts: Number of experts; must be at least 1.
        capacity_factor: Multiplier applied to ``num_tokens / num_experts``
            when computing the per-expert capacity.
        drop_tokens: If True, tokens beyond an expert's capacity are dropped.
        verbose: If True (the default, matching the original behaviour), the
            forward pass prints a step-by-step trace. Set it to False for real
            training loops: the trace calls ``.item()`` and ``.cpu()``, which
            force a host synchronisation on every call.

    Raises:
        ValueError: If ``num_experts`` is smaller than 1.
    """
    def __init__(self, d_model, d_ff, num_experts, capacity_factor=1.25, drop_tokens=True,
                 verbose=True):
        super().__init__()
        if num_experts < 1:
            raise ValueError(f"num_experts must be at least 1, got {num_experts}")

        self.d_model = d_model
        self.d_ff = d_ff
        self.num_experts = num_experts
        self.capacity_factor = capacity_factor
        self.drop_tokens = drop_tokens
        self.verbose = verbose

        # The router weights, mapping token embeddings to expert logits
        self.router = nn.Linear(d_model, num_experts)
        # A list of expert networks
        self.experts = nn.ModuleList([Expert(d_model, d_ff) for _ in range(num_experts)])

    def load_balancing_loss(self, router_probs, expert_mask):
        """
        Calculates the load balancing loss as described in the Switch Transformer paper.

        The loss is ``num_experts * sum_i(f_i * P_i)`` where ``f_i`` is the
        fraction of tokens whose top-1 choice is expert ``i`` and ``P_i`` is the
        mean router probability assigned to expert ``i``.

        Args:
            router_probs (torch.Tensor): Probabilities for each token-expert pair.
                                         Shape: (num_tokens, num_experts)
            expert_mask (torch.Tensor): A one-hot mask indicating which expert each token is routed to.
                                        Shape: (num_tokens, num_experts)

        Returns:
            torch.Tensor: The calculated load balancing loss (a scalar). For an
            empty batch (``num_tokens == 0``) the loss is 0 rather than NaN.
        """
        # Clamp the denominator so an empty batch yields 0 instead of 0/0 = NaN.
        denom = max(router_probs.shape[0], 1)

        # Fraction of tokens dispatched to each expert
        f_i = torch.sum(expert_mask, dim=0) / denom

        # Fraction of router probability allocated to each expert
        P_i = torch.sum(router_probs, dim=0) / denom

        # The loss is the scaled dot-product of f_i and P_i
        return self.num_experts * torch.sum(f_i * P_i)

    def forward(self, x):
        """
        Forward pass for the Switch Transformer layer.

        Args:
            x (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, d_model).
                The tensor does not need to be contiguous.

        Returns:
            torch.Tensor: The output tensor from the layer. Shape: (batch_size, seq_len, d_model).
                Rows of tokens dropped by the capacity check are zero.
            torch.Tensor: The auxiliary load balancing loss.
        """
        # Instances created before the `verbose` attribute existed keep tracing.
        verbose = getattr(self, "verbose", True)

        batch_size, seq_len, d_model = x.shape
        # Flatten the input to treat all tokens independently. `reshape` is used
        # instead of `view` so that non-contiguous inputs (e.g. the result of a
        # transpose) are accepted; it only copies when a view is impossible.
        x = x.reshape(-1, d_model)
        num_tokens = x.shape[0]

        if verbose:
            print("--- Input ---")
            print(f"Original input shape: ({batch_size}, {seq_len}, {d_model})")
            print(f"Reshaped input for routing (num_tokens, d_model): {x.shape}\n")

        # --- 1. Routing Logic ---
        # Get router logits and probabilities
        router_logits = self.router(x)
        router_probs = F.softmax(router_logits, dim=1)

        # Get the top-1 expert index and the corresponding gate value (probability)
        expert_gates, expert_indices = torch.max(router_probs, dim=1)

        # Create a one-hot mask for the chosen experts
        expert_mask = F.one_hot(expert_indices, self.num_experts).float()

        if verbose:
            print("--- Routing ---")
            print(f"Router logits shape: {router_logits.shape}")
            print(f"Router probabilities shape: {router_probs.shape}")
            print(f"Expert indices (top-1 choice for each token): {expert_indices.shape}\n{expert_indices}\n")
            print(f"Expert gates (probability of top-1 choice): {expert_gates.shape}\n")
            print(f"Expert mask (one-hot): {expert_mask.shape}\n")

        # --- 2. Load Balancing Loss ---
        # Computed from the routing decisions *before* any capacity-based dropping.
        aux_loss = self.load_balancing_loss(router_probs, expert_mask)
        if verbose:
            print("--- Load Balancing ---")
            print(f"Auxiliary Load Balancing Loss: {aux_loss.item():.4f}\n")

        # --- 3. Capacity Calculation and Token Dropping ---
        # Each expert has a fixed capacity
        expert_capacity = math.ceil((num_tokens / self.num_experts) * self.capacity_factor)

        if verbose:
            # Count how many tokens are assigned to each expert (trace only)
            tokens_per_expert = torch.sum(expert_mask, dim=0)
            print("--- Capacity ---")
            print(f"Expert capacity (max tokens per expert): {expert_capacity}")
            print(f"Tokens assigned to each expert: {tokens_per_expert.cpu().numpy()}\n")

        # Get the 1-based position of each token within its expert's queue
        # (0 in every column the token was not routed to).
        # This is used to drop tokens if an expert's capacity is exceeded
        position_in_expert = torch.cumsum(expert_mask, dim=0) * expert_mask

        # Keep only tokens that are within the expert's capacity
        within_capacity = position_in_expert <= expert_capacity

        if self.drop_tokens:
            expert_mask = expert_mask * within_capacity.float()

        # The final gate value is the original gate masked by the capacity:
        # a dropped token has an all-zero mask row, hence a gate of 0.
        final_expert_gates = expert_gates * expert_mask.sum(dim=1)

        if verbose:
            print("--- Token Dispatching (Pre-computation) ---")
            print(f"Position of each token in its expert's queue: {position_in_expert.shape}")
            print(f"Final expert mask (after capacity check): {expert_mask.shape}")
            dropped_tokens = num_tokens - expert_mask.sum()
            print(f"Tokens dropped due to capacity overflow: {dropped_tokens.item()}\n")

        # --- 4. Dispatch Tokens to Experts and Combine Results ---
        # Create a final output tensor of zeros
        y = torch.zeros_like(x)

        # This loop iterates through each expert, processes the tokens routed to it,
        # and scatters the results back to the correct positions in the output tensor.
        for i, expert in enumerate(self.experts):
            # Find which tokens are routed to this expert
            token_indices = (expert_mask[:, i] == 1).nonzero(as_tuple=True)[0]

            if token_indices.numel() > 0:
                # Gather the tokens for this expert and process them
                expert_output = expert(x[token_indices])

                # Scatter the expert's output back to the main tensor `y`.
                # index_add_ needs matching dtypes, so cast in case the expert
                # produced a different precision than `y`.
                y.index_add_(0, token_indices, expert_output.to(y.dtype))

                if verbose:
                    print(f"-> Expert {i}: Processed {token_indices.numel()} tokens.")

        # Scale the output by the final gate values
        y = y * final_expert_gates.unsqueeze(1)

        # Reshape the output back to the original input shape
        y = y.reshape(batch_size, seq_len, d_model)

        if verbose:
            print("\n--- Output ---")
            print(f"Final output shape (after combining experts): {y.shape}")

        return y, aux_loss


if __name__ == '__main__':
    # Toy hyper-parameters, kept deliberately tiny so the printed trace
    # stays short enough to read by eye.
    batch_size, seq_len = 2, 5
    d_model, d_ff = 8, 16      # embedding width / expert hidden width
    num_experts = 4
    capacity_factor = 1.0      # 1.0 for simplicity, can be > 1.0

    # Header banner followed by the configuration summary.
    print("="*50, "Switch Transformer Layer - Toy Example", "="*50, sep="\n")
    print(f"Configuration: batch_size={batch_size}, seq_len={seq_len}, d_model={d_model}, "
          f"num_experts={num_experts}, capacity_factor={capacity_factor}\n")

    # Build the layer first and sample the input second: both draw from the
    # global RNG, so this order determines the values produced under a seed.
    switch_ffn = SwitchFeedForward(d_model, d_ff, num_experts, capacity_factor)
    input_tensor = torch.randn(batch_size, seq_len, d_model)

    # One forward pass yields the layer output and the auxiliary loss.
    output_tensor, loss = switch_ffn(input_tensor)

    print(
        "\n--- Final Results ---",
        f"Final Output Tensor Shape: {output_tensor.shape}",
        f"Final Auxiliary Loss: {loss.item():.4f}",
        "="*50,
        sep="\n",
    )

Switch Transformer Layer - Toy Example
Configuration: batch_size=2, seq_len=5, d_model=8, num_experts=4, capacity_factor=1.0

--- Input ---
Original input shape: (2, 5, 8)
Reshaped input for routing (num_tokens, d_model): torch.Size([10, 8])

--- Routing ---
Router logits shape: torch.Size([10, 4])
Router probabilities shape: torch.Size([10, 4])
Expert indices (top-1 choice for each token): torch.Size([10])
tensor([1, 0, 1, 3, 1, 0, 0, 1, 1, 1])

Expert gates (probability of top-1 choice): torch.Size([10])

Expert mask (one-hot): torch.Size([10, 4])

--- Load Balancing ---
Auxiliary Load Balancing Loss: 1.1515

--- Capacity ---
Expert capacity (max tokens per expert): 3
Tokens assigned to each expert: [3. 6. 0. 1.]

--- Token Dispatching (Pre-computation) ---
Position of each token in its expert's queue: torch.Size([10, 4])
Final expert mask (after capacity check): torch.Size([10, 4])
Tokens dropped due to capacity overflow: 3.0

-> Expert 0: Processed 3 tokens.
-> Expert 1: Processed 