In [13]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class Expert(nn.Module):
    def __init__(self,n_embd,dropout):
        super().__init__()
        self.net = (
            nn.Linear(n_embd,4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd,n_embd),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)    

In [5]:
num_experts = 4
top_k = 2
n_embed = 4
batch = 2
num_tokens = 3
mh_output = torch.randn(batch,num_tokens,n_embed)

topkgate_linear = nn.Linear(n_embed,num_experts)

logits = topkgate_linear(mh_output)

top_k_logits,top_k_indices = logits.topk(top_k,dim = -1)



In [11]:
top_k_logits , top_k_indices , logits

(tensor([[[ 0.0014, -0.0103],
          [ 0.7816,  0.2050],
          [ 1.2231,  0.9853]],
 
         [[ 0.9302,  0.9230],
          [ 0.3702, -0.0118],
          [ 0.5589,  0.3120]]], grad_fn=<TopkBackward0>),
 tensor([[[3, 2],
          [2, 1],
          [2, 3]],
 
         [[1, 2],
          [2, 1],
          [1, 2]]]),
 tensor([[[-6.0454e-02, -2.4748e-01, -1.0280e-02,  1.3634e-03],
          [-8.4506e-02,  2.0497e-01,  7.8162e-01, -5.4297e-01],
          [ 4.1327e-01, -2.2426e+00,  1.2231e+00,  9.8526e-01]],
 
         [[-2.9083e-01,  9.3017e-01,  9.2298e-01, -9.4899e-01],
          [-3.3395e-01, -1.1830e-02,  3.7025e-01, -4.5657e-01],
          [-5.0865e-01,  5.5890e-01,  3.1198e-01, -1.4030e-03]]],
        grad_fn=<ViewBackward0>))

In [9]:
zeros = torch.full_like(logits , float("-inf"))
sparse_logits = zeros.scatter(-1,top_k_indices,top_k_logits)

In [10]:
sparse_logits

tensor([[[   -inf,    -inf, -0.0103,  0.0014],
         [   -inf,  0.2050,  0.7816,    -inf],
         [   -inf,    -inf,  1.2231,  0.9853]],

        [[   -inf,  0.9302,  0.9230,    -inf],
         [   -inf, -0.0118,  0.3702,    -inf],
         [   -inf,  0.5589,  0.3120,    -inf]]], grad_fn=<ScatterBackward0>)

In [15]:
gating_outputs = F.softmax(sparse_logits,dim=-1)

In [16]:
gating_outputs

tensor([[[0.0000, 0.0000, 0.4971, 0.5029],
         [0.0000, 0.3597, 0.6403, 0.0000],
         [0.0000, 0.0000, 0.5592, 0.4408]],

        [[0.0000, 0.5018, 0.4982, 0.0000],
         [0.0000, 0.4056, 0.5944, 0.0000],
         [0.0000, 0.5614, 0.4386, 0.0000]]], grad_fn=<SoftmaxBackward0>)

In [22]:
class TopKROuter(nn.Module):
    def __init__(self,emb_dim, num_experts, top_k):
        super().__init__()
        self.emb_dim = emb_dim
        self.linear = nn.Linear(n_embed,num_experts)

    def forward(self,mh_output):
        logits = self.linear(mh_output)
        top_k_logits,top_k_indices = logits.topk(top_k,dim=-1)
        zeros = torch.full_like(logits,float("-inf"))

        sparse_logits = zeros.scatter(-1,top_k_indices,top_k_logits)

        router_output = F.softmax(sparse_logits,dim=-1)

        return router_output,top_k_indices

In [23]:
num_experts = 4
top_k = 2
n_embed = 32


mh_output = torch.randn(3,5,n_embed)
top_k_gate = TopKROuter(n_embed,num_experts=num_experts,top_k=top_k)
output,indices = top_k_gate(mh_output)

In [75]:
class NoisyTopkRouter(nn.Module):
    def __init__(self,n_embed,num_experts ,top_k):
        super().__init__()
        self.top_k = top_k
        self.topkroute_linear  = nn.Linear(n_embed, num_experts)
        self.noise_linear  = nn.Linear(n_embed,num_experts)
    def forward(self,mh_output):
        logits = self.topkroute_linear(mh_output)
        noise_logits = self.noise_linear(mh_output)
        gaussian_noise = torch.randn_like(logits)

        final_noise = F.softplus(noise_logits) * gaussian_noise
        final_logits = final_noise + logits

        final_logits,indices = final_logits.topk(self.top_k,dim=-1)
        zeros = torch.full_like(logits,float("-inf"))

        sparse_logits = zeros.scatter(-1,indices,final_logits)

        router_output = F.softmax(sparse_logits,dim=-1)

        return router_output,indices
    

In [76]:
num_experts = 8
top_k = 2 
n_embed = 16

mh_output = torch.randn(2, 4, n_embed)
noisy_top_k_gate = NoisyTopkRouter(n_embed, num_experts, top_k)
gating_output, indices = noisy_top_k_gate(mh_output)
gating_output.shape,gating_output,indices.shape

(torch.Size([2, 4, 8]),
 tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.2990, 0.0000, 0.7010, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.6048, 0.3952, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5774, 0.4226, 0.0000],
          [0.0000, 0.4575, 0.0000, 0.5425, 0.0000, 0.0000, 0.0000, 0.0000]],
 
         [[0.0000, 0.0000, 0.0000, 0.0000, 0.4817, 0.0000, 0.5183, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3991, 0.6009],
          [0.0000, 0.0000, 0.4223, 0.5777, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.4879, 0.0000, 0.5121, 0.0000, 0.0000, 0.0000, 0.0000]]],
        grad_fn=<SoftmaxBackward0>),
 torch.Size([2, 4, 2]))

In [85]:
import torch
import torch.nn as nn

class SparseMoE(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(SparseMoE, self).__init__()
        self.router = NoisyTopkRouter(n_embed, num_experts, top_k)
        self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])
        self.top_k = top_k

    def forward(self, x):
        print("Input x:", x.shape)
        gating_output, indices = self.router(x)
        print("gating_output:", gating_output.shape, "\n", gating_output)
        print("indices:", indices.shape, "\n", indices)
        
        final_output = torch.zeros_like(x)
        print("final_output:", final_output.shape, "\n", final_output)
        
        # Reshape inputs for batch processing
        flat_x = x.view(-1, x.size(-1))
        print("flat_x:", flat_x.shape, "\n", flat_x)
        
        flat_gating_output = gating_output.view(-1, gating_output.size(-1))
        print("flat_gating_output:", flat_gating_output.shape, "\n", flat_gating_output)
        
        # Process each expert in parallel
        for i, expert in enumerate(self.experts):
            expert_mask = (indices == i).any(dim=-1)
            print(f"expert_mask for expert {i}:", expert_mask.shape, "\n", expert_mask)
            
            flat_mask = expert_mask.view(-1)
            print(f"flat_mask for expert {i}:", flat_mask.shape, "\n", flat_mask)
            
            if flat_mask.any():
                expert_input = flat_x[flat_mask]
                print(f"expert_input for expert {i}:", expert_input.shape, "\n", expert_input)
                
                expert_output = expert(expert_input)
                print(f"expert_output for expert {i}:", expert_output.shape, "\n", expert_output)
                
                gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)
                print(f"gating_scores for expert {i}:", gating_scores.shape, "\n", gating_scores)
                
                weighted_output = expert_output * gating_scores
                print(f"weighted_output for expert {i}:", weighted_output.shape, "\n", weighted_output)
                
                final_output[expert_mask] += weighted_output.squeeze(1)
        
        print("final_output after processing:", final_output.shape, "\n", final_output)
        return final_output


In [86]:
import torch
import torch.nn as nn


num_experts = 4
top_k = 2
n_embd = 6
dropout=0.1

mh_output = torch.randn(4, 8, n_embd)  # Example multi-head attention output
sparse_moe = SparseMoE(n_embd, num_experts, top_k)
final_output = sparse_moe(mh_output)
print("Shape of the final output:", final_output.shape)

Input x: torch.Size([4, 8, 6])
gating_output: torch.Size([4, 8, 4]) 
 tensor([[[0.0000, 0.2498, 0.7502, 0.0000],
         [0.7749, 0.0000, 0.2251, 0.0000],
         [0.0000, 0.0000, 0.3653, 0.6347],
         [0.4675, 0.0000, 0.0000, 0.5325],
         [0.0000, 0.3782, 0.6218, 0.0000],
         [0.0000, 0.6314, 0.3686, 0.0000],
         [0.6414, 0.3586, 0.0000, 0.0000],
         [0.0000, 0.8138, 0.1862, 0.0000]],

        [[0.0000, 0.3821, 0.6179, 0.0000],
         [0.6282, 0.0000, 0.3718, 0.0000],
         [0.0000, 0.5740, 0.0000, 0.4260],
         [0.0000, 0.0000, 0.4473, 0.5527],
         [0.4817, 0.0000, 0.5183, 0.0000],
         [0.0000, 0.0000, 0.2734, 0.7266],
         [0.4092, 0.0000, 0.5908, 0.0000],
         [0.0000, 0.3890, 0.0000, 0.6110]],

        [[0.0000, 0.7543, 0.2457, 0.0000],
         [0.0000, 0.0000, 0.1928, 0.8072],
         [0.9380, 0.0000, 0.0000, 0.0620],
         [0.0000, 0.3392, 0.6608, 0.0000],
         [0.7222, 0.0000, 0.2778, 0.0000],
         [0.4829, 0.517

In [87]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x).view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        queries = self.W_query(x).view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        values = self.W_value(x).view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        attn_scores = queries @ keys.transpose(-2, -1)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context_vec = attn_weights @ values
        context_vec = context_vec.transpose(1, 2).contiguous().view(b, num_tokens, self.d_out)
        return self.out_proj(context_vec)

In [89]:
class MOEBLOCK(nn.Module):

    def __init__(self, n_embed, n_head, num_experts, top_k , context_length):
        super().__init__()
        head_size = n_embed // n_head
        self.sa = MultiHeadAttention(n_embed,n_embed,context_length,dropout=None,num_heads=n_head)
        self.smoe = SparseMoE(n_embed, num_experts, top_k)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self,x):
        shortcut = x 
        x = self.sa(self.ln1(x))
        x = x + shortcut
        shortcut = x
        x = self.smoe(self.ln2(x))
        x = x + shortcut

        return x 

In [90]:
class SparseMoELM(nn.Module):
    def __init__(self,cfg):
        super().__init__()
        self.embeddings = nn.Embedding(cfg["vocab_size"],cfg["n_embed"])
        self.positional_embedding = nn.Embedding(cfg["vocab_size"],cfg["n_embed"])
        self.blocks = nn.Sequential(*[MOEBLOCK(cfg["n_embed"],cfg["n_head"],cfg["num_experts"],cfg["top_k"],cfg["context_length"])for _ in range(cfg["n_layers"])])
        self.finalNorm = nn.LayerNorm(cfg["n_embed"])
        self.lm_head = nn.Linear(cfg["n_embed"],cfg["vocab_size"])

    def forward(self,idx):
        B, T = idx.shape

        tok_emb = self.embeddings(idx)
        pos_emb = self.positional_embedding(idx)
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.finalNorm(x)
        logits = self.lm_head(x)