In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(42)

<torch._C.Generator at 0x111a2c390>

In [3]:
class Expert(nn.Module):
    def __init__(self, n_embd, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
        

In [4]:
class TopKRouter(nn.Module):
    def __init__(self, n_embd, n_experts, top_k):
        super().__init__()
        self.n_embd = n_embd
        self.n_experts = n_experts
        self.top_k = top_k
        self.linear = nn.Linear(n_embd, n_experts)

    def forward(self, x):
        logits = self.linear(x)
        top_k_logits, top_k_indices = torch.topk(logits, self.top_k, dim=-1)
        negative_inf = torch.full_like(logits, float("-inf"))
        sparse_logits = negative_inf.scatter(dim=-1, index=top_k_indices, src=top_k_logits)
        router_output = F.softmax(sparse_logits, dim=-1)
        return router_output, top_k_indices

In [5]:
num_experts = 3 
top_k = 2
n_embd = 8

mh_output = torch.randn(1,4,n_embd)
top_k_router = TopKRouter(n_embd, num_experts, top_k)
router_output, top_k_indices = top_k_router(mh_output)
print(router_output)
print(top_k_indices)








tensor([[[0.5747, 0.4253, 0.0000],
         [0.3194, 0.0000, 0.6806],
         [0.3203, 0.0000, 0.6797],
         [0.5498, 0.4502, 0.0000]]], grad_fn=<SoftmaxBackward0>)
tensor([[[0, 1],
         [2, 0],
         [2, 0],
         [0, 1]]])


In [7]:
class SparseMoE(torch.nn.Module):
    def __init__(self, n_embd, n_experts, top_k, dropout):
        super().__init__()
        self.router = TopKRouter(n_embd, n_experts, top_k)
        self.experts = nn.ModuleList([Expert(n_embd, dropout) for _ in range(n_experts)])
        self.top_k = top_k
    
    def forward(self, x):
        gating_output, top_k_indices = self.router(x)
        final_output = torch.zeros_like(x) # (batch_size, seq_len, n_embd)
        # gating output -> (batch_size, seq_len, n_experts)
        # top_k_indices -> (batch_size, seq_len, top_k)
        flat_x = x.view(-1,x.size(-1)) # (batch_size * seq_len, n_embd)
        flat_gating_output = gating_output.view(-1,gating_output.size(-1)) # (batch_size * seq_len, n_experts)

        for i, expert in enumerate(self.experts):
            mask = (top_k_indices == i).any(dim=-1) # (batch_size,seq_len)
            flat_mask = mask.view(-1) # (batch_size * seq_len)
            # let current_expert_selected be the number of tokens selected for current expert
            if flat_mask.any():
                expert_input = flat_x[flat_mask] # (current_expert_selected, n_embd)
                expert_output = expert(expert_input) # (current_expert_selected, n_embd)

                gating_scores = flat_gating_output[flat_mask,i].unsqueeze(-1) # (current_expert_selected,1)
                weighted_expert_output = expert_output * gating_scores # (current_expert_selected, n_embd)

                final_output[mask] = weighted_expert_output
        
        return final_output
        


In [8]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(42)

num_experts = 3 
top_k = 2
n_embd = 8
dropout = 0.1

mh_output = torch.randn(1,4,n_embd)
sparse_moe = SparseMoE(n_embd, num_experts, top_k, dropout)
output = sparse_moe(mh_output)
print(f"Output shape: {output.shape}")

print(output)

Output shape: torch.Size([1, 4, 8])
tensor([[[ 0.0000e+00, -2.0952e-01,  1.1119e-01, -3.4127e-01,  5.1096e-02,
           0.0000e+00,  1.0927e-01,  1.1165e-01],
         [ 0.0000e+00, -2.2709e-04,  2.1775e-01, -1.5330e-01,  6.7753e-02,
          -2.0473e-01, -2.4894e-01,  0.0000e+00],
         [ 1.3526e-01,  0.0000e+00,  7.6696e-01,  1.0333e-01,  6.1043e-02,
          -0.0000e+00, -0.0000e+00,  3.0101e-01],
         [ 3.8247e-02, -2.0839e-01,  3.1646e-01, -0.0000e+00,  2.6404e-02,
           0.0000e+00,  1.2235e-01,  1.0671e-01]]],
       grad_fn=<IndexPutBackward0>)
