In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Expert Layer

In [2]:
class Expert(nn.Module):
    """Two-layer feed-forward expert computing ``fc2(gelu(fc1(x)))``.

    Args:
        in_size: feature size of the input (last dimension of ``x``).
        h_size: width of the hidden layer.
        o_size: feature size of the output.
    """

    def __init__(self, in_size, h_size, o_size) -> None:
        super().__init__()
        self.in_size = in_size
        self.h_size = h_size
        self.o_size = o_size
        self.fc1 = nn.Linear(in_size, h_size)
        self.fc2 = nn.Linear(h_size, o_size)

    def forward(self, x):
        """Map a tensor of shape ``(..., in_size)`` to ``(..., o_size)``."""
        # The non-linearity must sit *between* the two projections: with it
        # only after fc2, fc2(fc1(x)) is a single affine map and the hidden
        # layer adds parameters but no expressive power.
        return self.fc2(F.gelu(self.fc1(x)))

## Gate Layer

In [3]:
class Gate(nn.Module):
    """Router mapping each input position to a probability distribution over experts.

    Args:
        in_size: feature size of the input (last dimension of ``x``).
        h_size: width of the router's hidden layer.
        n_experts: number of experts to score.
    """

    def __init__(self, in_size, h_size, n_experts) -> None:
        super().__init__()
        self.in_size = in_size
        self.h_size = h_size
        self.n_experts = n_experts
        self.fc1 = nn.Linear(in_size, h_size, bias=False)
        self.fc2 = nn.Linear(h_size, n_experts, bias=False)

    def forward(self, x):
        """Return routing probabilities of shape ``(..., n_experts)``.

        The probabilities sum to 1 over the last dimension (softmax).
        """
        # GELU goes between the projections; without it fc2(fc1(x)) is one
        # linear map. No activation on the logits themselves: GELU there would
        # squash every negative logit into roughly [-0.17, 0] before softmax,
        # limiting how decisively the router can down-weight an expert.
        logits = self.fc2(F.gelu(self.fc1(x)))
        return F.softmax(logits, dim=-1)

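## Dense MoE

Every expert processes every token; the gate's softmax probabilities weight the expert outputs, which are then summed.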

In [4]:
class DenseMOE(nn.Module):
    """Dense mixture of experts: all experts run on every input position.

    The gate yields one probability per expert for each position, and the
    result is the probability-weighted sum of every expert's output.

    Args:
        in_size: feature size of the input (last dimension of ``x``).
        h_size: hidden width passed to each expert and to the gate.
        o_size: feature size of each expert's output (and of the result).
        n_experts: number of experts.
    """

    def __init__(self, in_size, h_size, o_size, n_experts) -> None:
        super().__init__()
        self.in_size, self.h_size = in_size, h_size
        self.o_size, self.n_experts = o_size, n_experts
        # Experts are constructed before the gate so parameter initialisation
        # consumes the RNG in the same order.
        expert_list = []
        for _ in range(n_experts):
            expert_list.append(Expert(in_size, h_size, o_size))
        self.experts = nn.ModuleList(expert_list)
        self.gate = Gate(in_size, h_size, n_experts)

    def forward(self, x):
        # Trailing singleton axis lets the per-expert weight broadcast over
        # the output features.
        weights = self.gate(x)[..., None]
        # Stack along a new expert axis just before the feature axis.
        stacked = torch.stack([expert(x) for expert in self.experts], dim=-2)
        return (weights * stacked).sum(dim=-2)

In [5]:
torch.manual_seed(0)  # fixed seed so weights and sample are reproducible on Restart & Run All

batch = 2
seq_len = 4
in_size = 16
h_size = 64
o_size = 16
n_experts = 6
sample = torch.randn(batch, seq_len, in_size)
dense_model = DenseMOE(in_size, h_size, o_size, n_experts)
final_out = dense_model(sample)
assert final_out.shape == (batch, seq_len, o_size)
final_out.shape

torch.Size([2, 4, 16])


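## Sparse MoE

Each token is routed only to its `top_k` highest-probability experts; their outputs are scaled by the corresponding gate probabilities and summed.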

In [59]:
class SparseMOE(nn.Module):
    """Sparse (top-k routed) mixture of experts.

    For every position the gate scores all experts, and only the ``top_k``
    highest-scoring experts are run on that position. Each selected expert's
    output is scaled by its gate probability (not renormalised over the
    top-k) and the ``top_k`` contributions are summed.

    Args:
        in_size: feature size of the input (last dimension of ``x``).
        h_size: hidden width passed to each expert and to the gate.
        o_size: feature size of each expert's output and of the result.
        n_experts: number of experts.
        top_k: experts selected per position; must satisfy
            ``1 <= top_k <= n_experts``.

    Raises:
        ValueError: if ``top_k`` is outside ``[1, n_experts]``.
    """

    def __init__(self, in_size, h_size, o_size, n_experts, top_k=1) -> None:
        super().__init__()
        if not 1 <= top_k <= n_experts:
            raise ValueError(f"top_k must be in [1, {n_experts}], got {top_k}")
        self.in_size = in_size
        self.h_size = h_size
        self.o_size = o_size
        self.n_experts = n_experts
        self.top_k = top_k
        self.experts = nn.ModuleList([Expert(self.in_size, self.h_size, self.o_size) for _ in range(self.n_experts)])
        self.gate = Gate(self.in_size, self.h_size, self.n_experts)

    def forward(self, x):
        """Route every position of ``x`` to its top-k experts.

        Args:
            x: tensor of shape ``(..., in_size)``; any number of leading dims.

        Returns:
            Tensor of shape ``(..., o_size)``.
        """
        lead_shape = x.shape[:-1]
        gate_probs = self.gate(x)  # (..., n_experts)
        token_probs, probs_idx = torch.topk(gate_probs, self.top_k, dim=-1)  # (..., top_k)

        # Flatten the leading dims so each row is one token, whatever x's rank.
        flat_x = x.reshape(-1, x.shape[-1])
        flat_probs = token_probs.reshape(-1, self.top_k)
        flat_idx = probs_idx.reshape(-1, self.top_k)

        # One slot per (token, k) pair, sized by o_size (not in_size) and
        # allocated with the dtype/device of the gate probabilities that scale
        # every expert output. topk indices are distinct per token, so each
        # slot is written exactly once; a fresh (non-expanded) buffer keeps
        # those writes from aliasing each other.
        output = flat_probs.new_zeros(flat_x.shape[0], self.top_k, self.o_size)
        for i, expert in enumerate(self.experts):
            token_pos, slot = torch.where(flat_idx == i)
            if token_pos.numel() == 0:
                # Unused expert: skip it but keep going. Breaking here would
                # leave every later expert's tokens at zero.
                continue
            weights = flat_probs[token_pos, slot].unsqueeze(-1)
            output[token_pos, slot] = expert(flat_x[token_pos]) * weights
        return output.sum(dim=1).reshape(*lead_shape, self.o_size)

In [62]:
torch.manual_seed(0)  # fixed seed so weights and sample are reproducible on Restart & Run All

batch = 1
seq_len = 4
in_size = 16
h_size = 64
o_size = 16
n_experts = 6
top_k = 1
sample = torch.randn(batch, seq_len, in_size)
sparse_model = SparseMOE(in_size, h_size, o_size, n_experts, top_k)
out = sparse_model(sample)
assert out.shape == (batch, seq_len, o_size)
out.shape

torch.Size([1, 4, 16])