## Mixture of Experts

In [None]:
import torch
from torch import nn
import numpy as np 
import math
import torch.distributions as dist

In [None]:
# Inputs are 3-D tensors of shape (batch, token_size, embedding_size)

In [None]:
# our experts are gonna be simple FFN with identical architectures 
class FFN(nn.Module):
    """Feed-forward expert: Linear -> ReLU -> Linear, all of width ``shape``.

    Args:
        shape: Size of the last (embedding) dimension; both linear layers map
            ``shape`` features to ``shape`` features.
        device: Device on which the layer parameters are created.
    """

    def __init__(self, shape, device):
        # Zero-argument super(): the one-argument form ``super(self)`` raises
        # TypeError because its first argument must be a class.
        super().__init__()

        self.layer = nn.Sequential(
            nn.Linear(shape, shape, device=device),
            nn.ReLU(),
            nn.Linear(shape, shape, device=device),
        )

    def forward(self, x):
        """Apply the expert to ``x``; the output has the same shape as ``x``."""
        return self.layer(x)


In [None]:
class Gating(nn.Module):
    """Noisy top-k gating network for a mixture of experts.

    For every token the gate computes

        H(x) = x @ W_g + StandardNormal() * Softplus(x @ W_noise)

    keeps the ``k`` largest entries of ``H(x)`` along the expert dimension and
    applies a softmax over the experts, so exactly ``k`` experts receive a
    non-zero weight per token.

    Args:
        shape: ``(batch, token_size, embedding_size)`` of the expected input;
            only ``embedding_size`` is used to size the parameters.
        num_experts: Number of experts to produce weights for.
        device: Device on which the parameters are created.
    """

    def __init__(self, shape, num_experts, device):
        super().__init__()

        assert len(shape) == 3, "shape must be (batch, token_size, embedding_size)"
        embedding_size = shape[-1]

        self.device = device
        self.num_experts = num_experts

        self.W_g = nn.Parameter(torch.zeros(embedding_size, num_experts).to(device))
        nn.init.xavier_uniform_(self.W_g)
        self.W_noise = nn.Parameter(torch.randn(embedding_size, num_experts).to(device))

        # Normalise over the expert (last) dimension only.
        self.softmax = nn.Softmax(dim=-1)
        self.softplus = nn.Softplus()

    def keepTopK(self, v, k):
        """Return ``v`` with everything but its top-``k`` entries set to -inf.

        The selection is made along the last dimension. Entries are chosen by
        the indices ``torch.topk`` returns rather than by comparing against
        the k-th value, so exactly ``k`` entries survive even when several
        values tie at the threshold.
        """
        topk_values, topk_idx = torch.topk(v, k, dim=-1)

        # Start from all -inf (softmax maps these to 0) and write the
        # selected values back at their original positions.
        masked = torch.full_like(v, float('-inf'))
        return masked.scatter(-1, topk_idx, topk_values)

    def forward(self, x, k=1):
        """Compute per-token expert weights.

        Args:
            x: Input whose last dimension is ``embedding_size``.
            k: Number of experts kept per token (``1 <= k <= num_experts``).

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by
            ``num_experts``; each row sums to 1 and has ``k`` non-zero entries.
        """
        assert 1 <= k <= self.num_experts, "k must be in [1, num_experts]"

        clean_logits = torch.matmul(x, self.W_g)

        # Softplus keeps the learned per-expert noise scale positive.
        noise_scale = self.softplus(torch.matmul(x, self.W_noise))

        # randn_like follows the logits' shape, dtype and device, so the
        # sample stays correct if the module is moved after construction.
        standard_normal = torch.randn_like(clean_logits)

        # H(x): the standard-normal sample is *scaled* by the learned term.
        h_x = clean_logits + standard_normal * noise_scale
        return self.softmax(self.keepTopK(h_x, k))

In [12]:
def test_gating():
    """Smoke-test Gating: the output must hold one weight per expert for every token."""
    batch, tokens, embed, n_experts = 1, 2, 5, 4
    in_shape = (batch, tokens, embed)

    gate = Gating(in_shape, num_experts=n_experts, device='cpu')
    weights = gate(torch.randn(in_shape), k=2)

    # The last dimension indexes the experts.
    assert weights.shape == (batch, tokens, n_experts)
    print("Gating test passed")


test_gating()

Gating test passed


In [None]:
class MixtureOfExperts(nn.Module):
    """Mixture-of-experts layer: ``y = sum_i g_i(x) * E_i(x)``.

    ``g`` is a ``Gating`` network that yields one weight per expert for every
    token, and each ``E_i`` is an ``FFN`` acting on the embedding dimension.

    Args:
        shape: ``(batch, token_size, embedding_size)`` of the expected input.
        number_of_experts: How many FFN experts to create.
        device: Device on which all parameters are created.
    """

    def __init__(self, shape, number_of_experts, device):
        # Zero-argument super(): ``super(self)`` raises TypeError.
        super().__init__()

        assert len(shape) == 3, "shape must be (batch, token_size, embedding_size)"
        embedding_size = shape[-1]

        self.number_of_experts = number_of_experts
        self.device = device

        self.gatings = Gating(shape, number_of_experts, device)
        # FFN feeds its argument to nn.Linear as a feature count, so it must
        # get the embedding size rather than the whole shape tuple.
        self.experts = nn.ModuleList(
            [FFN(embedding_size, device) for _ in range(number_of_experts)]
        )

    def forward(self, x, k=1):
        """Combine the expert outputs weighted by the gate.

        Args:
            x: 3-D input tensor ``(batch, token_size, embedding_size)``.
            k: Number of experts the gate keeps active per token.

        Returns:
            Tensor with the same shape as ``x``.
        """
        assert len(x.shape) == 3, "x must be (batch, token_size, embedding_size)"
        assert k <= self.number_of_experts, "k cannot exceed the number of experts"

        # One weight per expert for every token: (batch, token_size, experts).
        gate_weights = self.gatings(x, k=k)

        y_out = torch.zeros_like(x)
        for expert_idx, expert in enumerate(self.experts):
            # Trailing singleton axis so the weight broadcasts over embeddings.
            weight = gate_weights[..., expert_idx].unsqueeze(-1)
            if not bool((weight != 0).any()):
                # No token routed to this expert; skip evaluating it.
                continue
            y_out = y_out + weight * expert(x)

        return y_out
