## Mixture of Experts

In [1]:
import torch
from torch import nn
import numpy as np 
import math
import torch.distributions as dist

In [2]:
# Input shape is batch, token_size, embedding_size

In [15]:
# Our experts are simple feed-forward networks (FFNs), all with the same architecture.
class FFN(nn.Module):
    """Feed-forward expert: Linear -> ReLU -> Linear, width preserved.

    Every expert in the mixture uses this same architecture. Both linear
    layers map ``embedding_size -> embedding_size``, so the output has the
    same shape as the input. The layers act on the last dimension only.

    Args:
        shape: 3-tuple ``(batch, token_size, embedding_size)``. Only the last
            entry is used; the length check mirrors the other modules.
        device: device on which the linear layers' parameters are created.
    """

    def __init__(self, shape, device):
        super().__init__()
        assert len(shape) == 3

        width = shape[-1]
        stages = [
            nn.Linear(width, width, device=device),
            nn.ReLU(),
            nn.Linear(width, width, device=device),
        ]
        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the two-layer MLP along the last dimension of ``x``."""
        out = self.layer(x)
        return out


In [16]:
def test_ffn():
    """Smoke test: an FFN expert preserves the input tensor's shape."""
    dims = (1, 2, 5)
    expert = FFN(dims, device='cpu')
    result = expert(torch.ones(dims))
    assert tuple(result.shape) == dims
    print("FFN test passed")

test_ffn()

FFN test passed


In [17]:
class Gating(nn.Module):
    """Noisy top-k gating network for a mixture of experts.

    Implements noisy top-k gating (Shazeer et al., 2017):

        H(x)_i = (x @ W_g)_i + StandardNormal() * Softplus((x @ W_noise)_i)
        G(x)   = Softmax(KeepTopK(H(x), k))

    The noise term is only added while the module is in training mode, so
    routing is deterministic after ``.eval()``.

    Args:
        shape: ``(batch, token_size, embedding_size)``. Only
            ``embedding_size`` is used to size the weight matrices.
        num_experts: number of experts to route between.
        device: device on which the gating parameters are created.
    """

    def __init__(self, shape, num_experts, device):
        super().__init__()

        # shape is batch, token_size, embedding_size
        assert len(shape) == 3
        _, _, embedding_size = shape

        self.device = device
        self.num_experts = num_experts

        # Clean gating weights.
        self.W_g = nn.Parameter(torch.zeros(embedding_size, num_experts).to(device))
        nn.init.xavier_uniform_(self.W_g)
        # Weights producing the per-expert noise scale (passed through softplus).
        self.W_noise = nn.Parameter(torch.randn(embedding_size, num_experts).to(device))

        # Normalise over the expert (last) dimension only.
        self.softmax = nn.Softmax(dim=-1)
        self.softplus = nn.Softplus()

    def keepTopK(self, v, k):
        """Keep the ``k`` largest entries of ``v`` along the last dim and set the rest to -inf.

        Uses the indices returned by ``torch.topk``, so exactly ``k`` entries
        survive even when values are tied. A ``>=`` threshold comparison would
        keep every value tied with the k-th largest. The result has the same
        shape, dtype and device as ``v``, and gradients flow to the kept entries.
        """
        topk_values, topk_indices = torch.topk(v, k, dim=-1)
        masked = torch.full_like(v, float('-inf'))
        return masked.scatter(-1, topk_indices, topk_values)

    def forward(self, x, k=1):
        """Return gate weights of shape ``(*x.shape[:-1], num_experts)``.

        Only the ``k`` selected experts per token can receive non-zero weight,
        and the weights along the last dimension sum to 1.

        Args:
            x: input whose last dimension is ``embedding_size``.
            k: number of experts selected per token, ``1 <= k <= num_experts``.
        """
        assert 1 <= k <= self.num_experts

        # compute H(x)
        h_x = torch.matmul(x, self.W_g)
        if self.training:
            # Tunable Gaussian noise: standard-normal samples *scaled* by
            # softplus(x @ W_noise). randn_like matches h_x's shape, dtype and
            # device, so this also works after the module has been moved.
            noise_std = self.softplus(torch.matmul(x, self.W_noise))
            h_x = h_x + torch.randn_like(h_x) * noise_std

        return self.softmax(self.keepTopK(h_x, k))

In [18]:
def test_gating():
    """Check the gate output shape and its sparse, normalised top-k structure."""
    num_experts, k = 4, 2
    gating = Gating((1, 2, 5), num_experts=num_experts, device='cpu')

    x = torch.randn((1, 2, 5))

    output = gating(x, k=k)
    assert output.shape == (1, 2, num_experts)  # 3rd dimension is the number of experts
    # each token is routed to exactly k experts ...
    assert ((output > 0).sum(dim=-1) == k).all()
    # ... and its gate weights form a probability distribution
    assert torch.allclose(output.sum(dim=-1), torch.ones(1, 2))
    print("Gating test passed")
test_gating()

Gating test passed


In [21]:
class MixtureOfExperts(nn.Module):
    """Sparsely-gated mixture-of-experts layer.

    A ``Gating`` network picks ``k`` experts per token. Each ``FFN`` expert
    processes only the tokens routed to it. The expert outputs are summed,
    weighted by their gate values.

    Args:
        shape: ``(batch, token_size, embedding_size)``, forwarded to the
            gating network and to every expert.
        number_of_experts: how many ``FFN`` experts to create.
        device: device on which all parameters are created.
    """

    def __init__(self, shape, number_of_experts, device):
        super().__init__()

        assert len(shape) == 3

        self.number_of_experts = number_of_experts
        self.device = device

        # Attribute name kept as ``gatings`` so the state_dict keys
        # (``gatings.*``) stay unchanged.
        self.gatings = Gating(shape, number_of_experts, device)
        self.experts = nn.ModuleList([FFN(shape, device) for _ in range(number_of_experts)])

    def forward(self, x, k=1):
        """Route each token to its top-``k`` experts and combine their outputs.

        Args:
            x: tensor of shape ``(batch, token_size, embedding_size)``.
            k: experts per token, ``1 <= k <= number_of_experts``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        assert len(x.shape) == 3  # (batch, token_size, embedding_size)
        assert 1 <= k <= self.number_of_experts

        gates = self.gatings(x, k=k)  # (batch, token_size, number_of_experts)

        # Flatten batch and token dims so that tokens can be routed individually.
        flat_x = x.reshape(-1, x.shape[-1])
        flat_gates = gates.reshape(-1, self.number_of_experts)
        flat_out = torch.zeros_like(flat_x)

        for i, expert in enumerate(self.experts):
            # Tokens whose gate for expert i is non-zero, i.e. that selected it.
            token_idx = torch.nonzero(flat_gates[:, i] > 0, as_tuple=True)[0]
            if token_idx.numel() == 0:
                continue
            # Run the expert only on its routed tokens; this conditional
            # computation is the point of sparse gating. Weight the result by
            # the gate value and add it back at the tokens' original positions.
            expert_out = expert(flat_x[token_idx])
            weights = flat_gates[token_idx, i].unsqueeze(-1)
            flat_out.index_add_(0, token_idx, expert_out * weights)

        return flat_out.reshape(x.shape)


In [22]:
# test the MoE

def test_moe():
    """Smoke test: the MoE layer maps (batch, tokens, width) to the same shape."""
    batch, tokens, width = 1, 24, 512
    layer = MixtureOfExperts(shape=(batch, tokens, width), number_of_experts=20, device='cpu')

    inputs = torch.randn((batch, tokens, width))
    result = layer(inputs, k=2)

    assert result.shape == (batch, tokens, width)
    print("Moe test passed")

test_moe()

Moe test passed


## Review notes

Claude was used to review the model implementation; its suggestions are listed below.

**Minor improvement suggestions:**

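- **Parameter naming consistency:** consider renaming `self.gatings` to `self.gating` in `MixtureOfExperts` for clarity. Note that this changes the `state_dict` keys, so saved checkpoints would need their keys remapped.
- **Formatting:** the reviewed copy appeared to have asterisks around parameter names in the `Gating` and `FFN` classes (*self*, *shape*, etc.). This is likely a Markdown-rendering artifact (for example, of the double underscores in `__init__`) rather than a problem in the code itself.
- **Load-balancing loss:** the noisy top-k gating paper (Shazeer et al., 2017, "Outrageously Large Neural Networks") recommends an auxiliary loss to encourage balanced expert utilization.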

**Advanced enhancements (if needed):**

- **Expert capacity:** implement per-expert capacity limits to prevent any single expert from being overloaded.
- **Batching optimizations:** group inputs by their selected experts for more efficient batch processing.
- **Layer normalization:** consider adding layer norm before and/or after applying the experts.