## Mixture of Experts

In [2]:
import torch
from torch import nn
import numpy as np 
import math
import torch.distributions as dist


In [None]:
# Input shape is (batch, token_size, embedding_size).

In [None]:
# Our experts are simple feed-forward networks (FFNs) with identical architectures.
class FFN(nn.Module):
    """Single expert: a two-layer feed-forward network with a SiLU activation.

    Both linear layers map ``shape`` features to ``shape`` features, so the
    output has the same trailing dimension as the input.

    Args:
        shape: Size of the last (feature/embedding) dimension of the input.
        device: Device on which the layer parameters are allocated.
    """

    def __init__(self, shape, device):
        # Zero-argument super(); ``super(self)`` raises TypeError.
        super().__init__()

        self.layer = nn.Sequential(
            nn.Linear(shape, shape, device=device),
            nn.SiLU(),
            nn.Linear(shape, shape, device=device),
        )

    def forward(self, x):
        """Apply Linear -> SiLU -> Linear to the last dimension of ``x``."""
        return self.layer(x)


In [None]:
class Gating(nn.Module):
    """Noisy top-k gating network for a mixture of experts.

    For every token the gate computes

        H(x) = x @ W_g + N(0, 1) * softplus(x @ W_noise)

    keeps the ``k`` largest entries of ``H(x)`` along the expert axis (all
    others are set to ``-inf``) and applies a softmax over the experts, so
    each token gets exactly ``k`` non-zero weights that sum to 1.

    Args:
        shape: ``(batch, token_size, embedding_size)`` of the expected input.
            Only ``embedding_size`` is used; a bare ``int`` is also accepted
            and interpreted as the embedding size.
        num_experts: Number of experts to produce gate weights for.
        device: Device on which the gate parameters are allocated.
    """

    def __init__(self, shape, num_experts, device):
        # Zero-argument super(); ``super(self)`` raises TypeError.
        super().__init__()

        if isinstance(shape, int):
            embedding_size = shape
        else:
            # shape is (batch, token_size, embedding_size)
            assert len(shape) == 3
            _, _, embedding_size = shape

        self.num_experts = num_experts
        self.device = device

        self.W_g = nn.Parameter(torch.empty(embedding_size, num_experts, device=device))
        nn.init.xavier_uniform_(self.W_g)
        self.W_noise = nn.Parameter(torch.randn(embedding_size, num_experts, device=device))

        # Normalise over the expert axis (last dim), not over batch or tokens.
        self.softmax = nn.Softmax(dim=-1)
        self.softplus = nn.Softplus()

    def keepTopK(self, v, k):
        """Return a copy of ``v`` with everything but the top ``k`` experts masked.

        Args:
            v: Tensor whose last dimension indexes experts.
            k: Number of largest entries to keep along the last dimension.

        Returns:
            A new tensor shaped like ``v``; the ``k`` largest values per row
            are preserved and every other entry is ``-inf``. ``v`` itself is
            not modified.
        """
        top_values, top_indices = torch.topk(v, k, dim=-1)
        masked = torch.full_like(v, float("-inf"))
        # Out-of-place scatter keeps the op differentiable w.r.t. the kept values.
        return masked.scatter(-1, top_indices, top_values)

    def forward(self, x, k=1):
        """Compute sparse gate weights for ``x``.

        Args:
            x: Input whose last dimension is ``embedding_size``
                (e.g. ``(batch, token_size, embedding_size)``).
            k: Number of experts to keep per token; ``1 <= k <= num_experts``.

        Returns:
            Tensor of shape ``(*x.shape[:-1], num_experts)`` with ``k``
            non-zero weights per token that sum to 1.
        """
        # k == 0 would mask every expert and make the softmax return NaN.
        assert 1 <= k <= self.num_experts

        clean_logits = torch.matmul(x, self.W_g)
        noise_scale = self.softplus(torch.matmul(x, self.W_noise))

        # Sample on the same device/dtype as the logits instead of relying on
        # a stored device attribute.
        standard_normal = torch.randn_like(clean_logits)

        # The learned softplus term scales the Gaussian noise per expert.
        h_x = clean_logits + standard_normal * noise_scale
        return self.softmax(self.keepTopK(h_x, k))

In [None]:
def test_gating():
    """Smoke-test ``Gating``: build it and check the shape of its output."""
    batch, token_size, embedding_size, num_experts = 2, 4, 10, 10

    # Gating validates ``shape`` with ``len(shape) == 3``, so it must be a
    # (batch, token_size, embedding_size) sequence rather than a bare int.
    gating = Gating((batch, token_size, embedding_size), num_experts, 'cpu')

    gates = gating(torch.randn(batch, token_size, embedding_size))
    # One gate weight per expert for every token.
    assert gates.shape == (batch, token_size, num_experts)

In [None]:
class MixtureOfExperts(nn.Module):
    """Container module for a mixture-of-experts layer (work in progress).

    Only the configuration is stored so far; no experts or gate are created
    and ``forward`` is still a stub.

    Args:
        shape: Accepted for API symmetry with the other modules in this file;
            currently unused.
        number_of_experts: Number of experts the layer is meant to hold.
        device: Device the layer is meant to live on.
    """

    def __init__(self, shape, number_of_experts, device):
        # Zero-argument super(); ``super(self)`` raises TypeError.
        super().__init__()

        self.number_of_experts = number_of_experts
        self.device = device

    def forward(self, x):
        """Placeholder forward pass; not implemented yet and returns ``None``."""
        # TODO: route ``x`` through the gate and combine the expert outputs.
        ...