Mixture of Experts is an architecture designed to replace standard dense Feed Forward layers. It is most widely used in natural language processing inside transformers. The idea is that instead of having a normal, dense FF layer, we have few "expert" FF layers instead. When a token enters a MOE layer, it is directed to one (or more) of the experts by a router module. Because of that, during a forward pass, the amount of compute needed is the same as if we were using a normal FF layer (plus the router calculation, but it is a very small cost). But our models can be sizably bigger. 

This allows us to train bigger models under the same compute budget, making better use of available memory. MOE models are faster to pre-train than dense models and are widely used in many modern transformers (eg Deep Seek uses a MOE architecture). There is also evidence that particular experts sometimes specialize in handling certain kinds of tokens, improving the overall efficiency of the model. 

Bellow is the implementation of Mixture of Experts layer. It contains:
* A naive, loop based implementation of MoE
* Vectorized and parallelizable implementation 
* A function that compares the outputs of both implementations, ensuring the correctness of the vectorized version.

This implementation of MOE was made as an assignment for the Natural Language Processing course at Machine Learning Masters degree at the University of Warsaw. We were given a general structure of how the code needs to look (eg. what classes do we have to implement) but the code was written by myself.

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from transformers import PretrainedConfig
import math

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size),
            nn.ReLU(),
            nn.Linear(config.intermediate_size, config.hidden_size),
        )

    def forward(self, x):
        return self.mlp(x)

In [None]:
# Input: [batch_size, seq_len, hidden_size] - input embeddings
# Output: [batch_size, seq_len, num_experts] - expert routing weights
class Router(nn.Module):
    """
    Router module for a Mixture of Experts (MoE) transformer layer.

    Given a tensor of token embeddings with shape [batch_size, seq_len, hidden_size],
    the router computes a distribution over experts for each token.

    Returns:
        A tensor of shape [batch_size, seq_len, num_experts] containing
        routing weights (after softmax), where each token is assigned to
        a subset of experts.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts_per_token = config.num_experts_per_token
        self.hidden_size = config.hidden_size
        self.num_experts = config.num_experts

        self.expert_embeddings = nn.Parameter(torch.randn(self.num_experts, self.hidden_size))
        torch.nn.init.kaiming_uniform_(self.expert_embeddings, nonlinearity='linear')

    def chose_top_k_and_replace_rest_with_minus_inf(self, tensor, k, dim):
        """
        Sets all but top-k values along the specified dimension to -inf,
        so they are effectively ignored after softmax.
        """
        _, ind = torch.topk(-tensor, k=k, dim = dim)
        tensor.scatter_(index=ind, dim=dim, value=float('-inf'))
        return tensor

    def forward(self, x):
        result = torch.einsum("BSH,EH -> BSE", x, self.expert_embeddings)
        result = self.chose_top_k_and_replace_rest_with_minus_inf(result,self.num_experts -  self.num_experts_per_token, 2)
        result = F.softmax(result, dim = 2)
        return result


###TESTING### (if the dimensions in the output match)

def test_router(num_experts_per_token, hidden_size, num_experts, seq_len, batch_size):
    config = PretrainedConfig(
        num_experts_per_token=num_experts_per_token,
        hidden_size=hidden_size,
        num_experts=num_experts,
        batch_size = batch_size,
        seq_len = seq_len,
    )
    router = Router(config)
    x = torch.randn(config.batch_size, config.seq_len, config.hidden_size)
    print('Input: [batch_size, seq_len, hidden_size]: ', x.shape)
    result = router(x)
    print('Output: [batch_size, seq_len, num_experts]: ', result.shape)

test_router(num_experts_per_token = 3, hidden_size = 9, num_experts = 7, seq_len = 5, batch_size = 1)

Input: [batch_size, seq_len, hidden_size]:  torch.Size([1, 5, 9])
Output: [batch_size, seq_len, num_experts]:  torch.Size([1, 5, 7])


In [None]:
# Input: [batch_size, seq_len, hidden_size] - input embeddings
# Output: [batch_size, seq_len, hidden_size] - output embeddings
class Naive_MoE(nn.Module):
    '''
    Naive, loop based implementation of Mixture of Experts
    Num exerts per token controls how many experts are ascribed to every token
    Capacity factor controls how many tokens can go to a single expert
    Both input and output are tensors of shape [batch size, sequence length, hidden dim]
    '''
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        self.hidden_size = config.hidden_size
        self.num_experts_per_token = config.num_experts_per_token
        self.capacity_factor = config.capacity_factor

        self.expert1 = nn.Parameter(torch.randn(config.num_experts, config.hidden_size, config.intermediate_size))
        torch.nn.init.kaiming_uniform_(self.expert1, nonlinearity='linear')

        self.expert2 = nn.Parameter(torch.randn(config.num_experts,config.intermediate_size, config.hidden_size))
        torch.nn.init.kaiming_uniform_(self.expert2, nonlinearity='linear')

        self.router = Router(config)

    def forward(self, x):
        batch_size, seq_len, hidden_size = x.shape
        expert_capacity = math.ceil(batch_size * seq_len / self.num_experts * self.capacity_factor)
        result = torch.zeros((batch_size, seq_len, hidden_size))
        weights = self.router(x)

        for expert in range(self.num_experts):
          token_count = 0
          full = False
          for batch in range(batch_size):
            for token in range(seq_len):
              if weights[batch, token, expert] > 0 and full == False:
                expert_result = torch.einsum('H, HI -> I', x[batch, token, :], self.expert1[expert])
                expert_result = torch.nn.functional.relu(expert_result)
                expert_result = torch.einsum('I, IH -> H', expert_result, self.expert2[expert])

                result[batch, token, :] += expert_result * weights[batch, token, expert]
                token_count += 1
                if token_count == expert_capacity:
                  full = True
        return result

###TESTING### (if the dimensions in the output match)

def test_Naive_Moe(num_experts_per_token, hidden_size, num_experts, seq_len, batch_size, capacity_factor):
    config = PretrainedConfig(
        num_experts_per_token=num_experts_per_token,
        hidden_size=hidden_size,
        num_experts=num_experts,
        batch_size = batch_size,
        seq_len = seq_len,
        capacity_factor = capacity_factor,
        intermediate_size=512,
    )
    moe = Naive_MoE(config)
    x = torch.randn(config.batch_size, config.seq_len, config.hidden_size)
    print('Input: [batch_size, seq_len, hidden_size]: ', x.shape)
    result = moe(x)
    print('Output: [batch_size, seq_len, hidden_size]: ', result.shape)

test_Naive_Moe(num_experts_per_token = 3, hidden_size = 9, num_experts = 5, seq_len = 7, batch_size = 1, capacity_factor = 2)

Input: [batch_size, seq_len, hidden_size]:  torch.Size([1, 7, 9])
Output: [batch_size, seq_len, hidden_size]:  torch.Size([1, 7, 9])


In [None]:
# Input: [batch_size, seq_len, hidden_size] - input embeddings
# Output: [batch_size, seq_len, hidden_size] - output embeddings
class MoE(nn.Module):
    '''
    Vectorized implementation of Mixture of Experts
    Both input and output are tensors of shape [batch size, sequence length, hidden dim]
    '''
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        self.hidden_size = config.hidden_size
        self.num_experts_per_token = config.num_experts_per_token
        self.capacity_factor = config.capacity_factor

        self.expert1 = nn.Parameter(torch.randn(config.num_experts, config.hidden_size, config.intermediate_size))
        torch.nn.init.kaiming_uniform_(self.expert1, nonlinearity='linear')

        self.expert2 = nn.Parameter(torch.randn(config.num_experts,config.intermediate_size, config.hidden_size))
        torch.nn.init.kaiming_uniform_(self.expert2, nonlinearity='linear')

        self.router = Router(config)
        self.helper_tensor = torch.tensor(range(config.seq_len * config.batch_size, 0, -1))

    def forward(self, x):
        batch_size, seq_len, hidden_size = x.shape
        expert_capacity = math.ceil(batch_size * seq_len / self.num_experts * self.capacity_factor)

        result = torch.zeros((batch_size, seq_len, hidden_size))
        weights = self.router(x)
        weights = torch.reshape(weights, (batch_size * seq_len, self.num_experts))
        tokens = torch.reshape(x, (batch_size * seq_len, hidden_size))
        indexes = torch.where(weights > 0, 1, 0)
        indexes_segregated = indexes * (self.helper_tensor).unsqueeze(dim = 1)

        indexes = chose_top_k_and_replace_rest_with_value(indexes_segregated,self.config.seq_len * self.config.batch_size -  expert_capacity, 0, 0)
        #print('po wybraniu top k, \n', indexes)
        indexes = torch.where(indexes > 0, 1, 0)
        #print("1 tam gdzie dany ekspert bedzie przetwarzal: \n",indexes)
        weights_result = weights * indexes
        #print('Wagi tensorów które ekspert będzie przetważał: \n', weights_result)
        desired_values, desired_indexes = torch.topk(weights_result, expert_capacity, dim = 0)
        #print('indexy chcianych przez danego eksperta tensorów: \n', desired_indexes)
        #print('wagi chcianych przez danego eksperta tensorów: \n', desired_values)
        wektor_indeksow_do_index_select = desired_indexes.transpose(0, 1).flatten()
        wektor_wag_plaski = desired_values.transpose(0, 1).flatten()
        #print('sflatenowane indexy: \n', wektor_indeksow_do_index_select)
        #print('sflatenowane wagi: \n', wektor_wag_plaski)
        experciXcapacityXtokeny = torch.index_select(tokens, 0, wektor_indeksow_do_index_select)
        #print('wejsciowe tokeny: \n', tokeny)
        #print('experciXcapacityXtokeny: \n', experciXcapacityXtokeny)
        experciXcapacityXtokeny = torch.reshape(experciXcapacityXtokeny, (self.num_experts, expert_capacity, hidden_size))
        #print('experciXcapacityXtokeny trójwymiarowa macierz: \n', experciXcapacityXtokeny)
        #print('experci capacity tokeny shape: ', experciXcapacityXtokeny.shape)
        intermidiet_values = torch.einsum("ECH, EHI -> ECI", experciXcapacityXtokeny, self.expert1)
        intermidiet_values = torch.nn.functional.relu(intermidiet_values)
        intermidiet_values = torch.einsum("ECI, EIH -> ECH", intermidiet_values, self.expert2)
        #print('ksztalt po przepuszczeniu przez ekspertow: ', intermidiet_values.shape)
        #print('Trójwymiarowa macierz wypluta przez ekspertow: \n', intermidiet_values)
        tokensXhidden = torch.reshape(intermidiet_values, (self.config.num_experts * expert_capacity, hidden_size))
        #print('Zreshejpowane po przepuszczeniu przez ekspertow: \n', tokensXhidden)
        #print(tokensXhidden.shape, wektor_wag_plaski.unsqueeze(dim = 1).shape)
        tokensXhidden_po_pomnozeniu_przez_wagi = tokensXhidden * wektor_wag_plaski.unsqueeze(dim = 1)
        #print('Ksztalt po pomnozeniu przez wektor wag plaski: ', tokensXhidden.shape)
        result = torch.zeros((seq_len * batch_size, hidden_size))
        result.index_add_(0, wektor_indeksow_do_index_select, tokensXhidden_po_pomnozeniu_przez_wagi)
        #print('ostateczny wynik w formie ekspertXcapacity na hidden dim: \n', result)
        #print('result shape ', result.shape)
        result = torch.reshape(result, (batch_size, seq_len, hidden_size))



        #indexes == macierz (Token x Experci) mówiąca, które tokeny przetwarza który exper
        #where = torch.argwhere(indexes)
        #print(where)

        return result

###TESTING###

def test_Moe(num_experts_per_token, hidden_size, num_experts, seq_len, batch_size, capacity_factor):
    config = PretrainedConfig(
        num_experts_per_token=num_experts_per_token,
        hidden_size=hidden_size,
        num_experts=num_experts,
        batch_size = batch_size,
        seq_len = seq_len,
        capacity_factor = capacity_factor,
        intermediate_size=512,
    )
    moe = MoE(config)
    x = torch.randn(config.batch_size, config.seq_len, config.hidden_size)
    print('Input: [batch_size, seq_len, hidden_size]: ', x.shape)
    result = moe(x)
    print('Output: [batch_size, seq_len, hidden_size]: ', result.shape)

test_Moe(num_experts_per_token = 2, hidden_size = 7, num_experts = 3, seq_len = 5, batch_size = 2, capacity_factor = 1)

Input: [batch_size, seq_len, hidden_size]:  torch.Size([2, 5, 7])
Output: [batch_size, seq_len, hidden_size]:  torch.Size([2, 5, 7])


In [10]:
def compare_two_implementations(num_experts_per_token = 2, hidden_size = 7, num_experts = 3, seq_len = 5, batch_size = 2, capacity_factor = 1):
    config = PretrainedConfig(
      num_experts_per_token=num_experts_per_token,
      hidden_size=hidden_size,
      num_experts=num_experts,
      batch_size = batch_size,
      seq_len = seq_len,
      capacity_factor = capacity_factor,
      intermediate_size=512,
  )
    naive_moe = Naive_MoE(config)
    moe = MoE(config)
    router = Router(config)
    naive_moe.router = router
    moe.router = router
    moe.expert1 = naive_moe.expert1
    moe.expert2 = naive_moe.expert2
    input = torch.rand((batch_size, seq_len, hidden_size))
    result_moe = moe(input)
    result_naive = naive_moe(input)
    print(torch.equal(result_moe, result_naive))
    return(torch.max(abs(result_moe - result_naive)))
print(compare_two_implementations())

False
tensor(8.3819e-09, grad_fn=<MaxBackward1>)
