In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [28]:
torch.manual_seed(42)

B, T, C = 4, 8, 32
h = C // 2

x = torch.randn(B, T, C)

Wq = nn.Linear(C, h, bias=False)
Wk = nn.Linear(C, h, bias=False)
Wv = nn.Linear(C, h, bias=False)

q = Wq(x)
k = Wk(x)
attn = q @ k.transpose(-2, -1) * h**-0.5
attn = attn.masked_fill(torch.tril(torch.ones(T, T)) == 0, float('-inf'))
attn = F.softmax(attn, dim=-1)

v = Wv(x)
out = attn @ v
out.shape

torch.Size([4, 8, 16])

Below we implement MoE

In [40]:
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self):
        super().__init__()
        self.key = nn.Linear(64, 16, bias=False)
        self.query = nn.Linear(64, 16, bias=False)
        self.value = nn.Linear(64, 16, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(8, 8)))
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   # (B,T,C)
        q = self.query(x) # (B,T,C)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,C)
        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
        return out

head = Head()
head.parameters

<bound method Module.parameters of Head(
  (key): Linear(in_features=64, out_features=16, bias=False)
  (query): Linear(in_features=64, out_features=16, bias=False)
  (value): Linear(in_features=64, out_features=16, bias=False)
  (dropout): Dropout(p=0.1, inplace=False)
)>

In [41]:
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self):
        super().__init__()
        self.heads = nn.ModuleList([Head(16) for _ in range(4)])
        self.proj = nn.Linear(64, 64)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

In [42]:
class Expert(nn.Module):
    """ An MLP is a simple linear layer followed by a non-linearity i.e. each Expert """

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(64, 4 * 64),
            nn.ReLU(),
            nn.Linear(4 * 64, 64),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

In [48]:
mha_out = torch.randn(2, 8, 64)
gate = nn.Linear(64, 4)
logits = gate(mha_out)
top_k_logits, top_k_indices = logits.topk(2, dim=-1)
top_k_logits.shape, top_k_indices

(torch.Size([2, 8, 2]),
 tensor([[[1, 2],
          [2, 0],
          [1, 2],
          [1, 3],
          [1, 3],
          [3, 0],
          [2, 3],
          [0, 2]],
 
         [[1, 2],
          [2, 0],
          [1, 3],
          [1, 2],
          [0, 3],
          [3, 0],
          [1, 2],
          [2, 3]]]))

In [57]:
zeros = torch.full_like(logits, float('-inf'))
sparse_logits = zeros.scatter(-1, top_k_indices, top_k_logits)
sparse_logits

tensor([[[   -inf,  0.5624, -0.2478,    -inf],
         [ 0.9851,    -inf,  1.0902,    -inf],
         [   -inf, -0.2416, -0.2549,    -inf],
         [   -inf,  0.8075,    -inf,  0.4231],
         [   -inf,  0.9113,    -inf,  0.2213],
         [-0.1508,    -inf,    -inf,  0.5258],
         [   -inf,    -inf,  1.7084,  0.9629],
         [ 0.2472,    -inf,  0.1298,    -inf]],

        [[   -inf,  0.3157,  0.1252,    -inf],
         [ 0.0088,    -inf,  0.1553,    -inf],
         [   -inf,  1.2441,    -inf,  0.1329],
         [   -inf,  1.1404,  0.8150,    -inf],
         [ 0.2541,    -inf,    -inf, -0.1712],
         [ 0.0091,    -inf,    -inf,  0.1471],
         [   -inf,  0.7702,  0.6067,    -inf],
         [   -inf,    -inf,  0.2182, -0.0720]]], grad_fn=<ScatterBackward0>)

In [None]:
gate_out = F.softmax(sparse_logits, dim=-1)
gate_out

tensor([[[0.0000, 0.6922, 0.3078, 0.0000],
         [0.4737, 0.0000, 0.5263, 0.0000],
         [0.0000, 0.5033, 0.4967, 0.0000],
         [0.0000, 0.5949, 0.0000, 0.4051],
         [0.0000, 0.6660, 0.0000, 0.3340],
         [0.3370, 0.0000, 0.0000, 0.6630],
         [0.0000, 0.0000, 0.6782, 0.3218],
         [0.5293, 0.0000, 0.4707, 0.0000]],

        [[0.0000, 0.5475, 0.4525, 0.0000],
         [0.4634, 0.0000, 0.5366, 0.0000],
         [0.0000, 0.7524, 0.0000, 0.2476],
         [0.0000, 0.5806, 0.4194, 0.0000],
         [0.6048, 0.0000, 0.0000, 0.3952],
         [0.4655, 0.0000, 0.0000, 0.5345],
         [0.0000, 0.5408, 0.4592, 0.0000],
         [0.0000, 0.0000, 0.5720, 0.4280]]], grad_fn=<SoftmaxBackward0>)

In [None]:
class TopkRouter(nn.Module):
    def __init__(self, top_k=2):
        super(TopkRouter, self).__init__()
        self.top_k = top_k
        self.linear = nn.Linear(64, 4)
    
    def forward(self, mha_out):
        logits = self.linear(mha_out)
        top_k_logits, indices = logits.topk(self.top_k, dim=-1)
        zeros = torch.full_like(logits, float('-inf'))
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)
        router_output = F.softmax(sparse_logits, dim=-1)
        return router_output, indices

In [73]:
mha_output = torch.randn(4, 8, 64) 
top_k_gate = TopkRouter()
gating_output, indices = top_k_gate(mha_output)
gating_output.shape, gating_output, indices

(torch.Size([4, 8, 4]),
 tensor([[[0.4396, 0.0000, 0.5604, 0.0000],
          [0.0000, 0.5184, 0.4816, 0.0000],
          [0.5368, 0.4632, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.5307, 0.4693],
          [0.4503, 0.0000, 0.5497, 0.0000],
          [0.8090, 0.0000, 0.1910, 0.0000],
          [0.2159, 0.0000, 0.0000, 0.7841],
          [0.5806, 0.0000, 0.4194, 0.0000]],
 
         [[0.0000, 0.4851, 0.5149, 0.0000],
          [0.0000, 0.4580, 0.0000, 0.5420],
          [0.0000, 0.4375, 0.5625, 0.0000],
          [0.0000, 0.0000, 0.6797, 0.3203],
          [0.0000, 0.0000, 0.6767, 0.3233],
          [0.0000, 0.0000, 0.5312, 0.4688],
          [0.7180, 0.2820, 0.0000, 0.0000],
          [0.0000, 0.5857, 0.0000, 0.4143]],
 
         [[0.0000, 0.4748, 0.5252, 0.0000],
          [0.4327, 0.5673, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.5812, 0.4188],
          [0.6178, 0.0000, 0.0000, 0.3822],
          [0.0000, 0.5805, 0.4195, 0.0000],
          [0.5259, 0.0000, 0.0000, 0.4741],
  

In [63]:
class NoisyTopkRouter(nn.Module):
    def __init__(self, top_k=2):
        super(NoisyTopkRouter, self).__init__()
        self.top_k = top_k
        self.topkroute_linear = nn.Linear(64, 4)
        self.noise_linear = nn.Linear(64, 4)

    
    def forward(self, mha_output):
        logits = self.topkroute_linear(mha_output)
        noise_logits = self.noise_linear(mha_output)

        #Adding scaled unit gaussian noise to the logits
        noise = torch.randn_like(logits) * F.softplus(noise_logits)
        noisy_logits = logits + noise

        top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)
        zeros = torch.full_like(noisy_logits, float('-inf'))
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)
        router_output = F.softmax(sparse_logits, dim=-1)
        return router_output, indices

In [67]:
nn.ModuleList([Expert() for _ in range(4)])

ModuleList(
  (0-3): 4 x Expert(
    (net): Sequential(
      (0): Linear(in_features=64, out_features=256, bias=True)
      (1): ReLU()
      (2): Linear(in_features=256, out_features=64, bias=True)
      (3): Dropout(p=0.1, inplace=False)
    )
  )
)

In [72]:
x = torch.randn(4, 8, 64) 
x.view(-1, x.size(-1)).shape

torch.Size([32, 64])

In [83]:
router = NoisyTopkRouter()
gating_output, indices = router(x)
gating_output.shape

torch.Size([4, 8, 4])

In [78]:
final_output = torch.zeros_like(x)
final_output.shape

torch.Size([4, 8, 64])

In [95]:
flat_gating_output = gating_output.view(-1, gating_output.size(-1))
flat_gating_output.shape, flat_gating_output[0], indices[0][0]

(torch.Size([32, 4]),
 tensor([0.4768, 0.0000, 0.0000, 0.5232], grad_fn=<SelectBackward0>),
 tensor([3, 0]))

In [None]:
(indices == 3).shape
(indices == 3).any(dim=-1) # B, T -> True means expert 3 is activated

tensor([[ True, False, False,  True,  True,  True,  True, False],
        [ True, False,  True,  True, False,  True, False,  True],
        [ True, False, False, False,  True,  True,  True,  True],
        [ True,  True,  True, False,  True, False,  True, False]])

In [117]:
(indices == 3).any(dim=-1).view(-1).shape

torch.Size([32])

In [115]:
flat_mask = (indices == 0).any(dim=-1).view(-1)
flat_x = x.view(-1, x.size(-1))
flat_x.shape

torch.Size([32, 64])

In [125]:
flat_x[flat_mask == True].shape

torch.Size([14, 64])

In [156]:
gating_scores = flat_gating_output[flat_mask, 3].unsqueeze(1)
gating_scores

tensor([[0.5232],
        [0.0000],
        [0.3629],
        [0.0000],
        [0.4361],
        [0.0000],
        [0.7600],
        [0.0000],
        [0.6714],
        [0.8067],
        [0.2371],
        [0.0000],
        [0.5308],
        [0.0000]], grad_fn=<UnsqueezeBackward0>)

In [151]:
x = torch.tensor([[3, 4, 5, 6], [0, 1, 2, 1]])
x[[True,False], 2]

tensor([5])

In [158]:
expert = Expert()
expert_output = expert(flat_x[flat_mask])
expert_output.shape

torch.Size([14, 64])

In [166]:
(expert_output * gating_scores).squeeze(1).shape

torch.Size([14, 64])

In [167]:
x = torch.randn(4, 8, 64) 
final_out = torch.zeros_like(x)
final_out.shape

torch.Size([4, 8, 64])

In [168]:
expert_mask = (indices == 0).any(dim=-1)
expert_mask

tensor([[ True,  True, False,  True, False, False, False, False],
        [False,  True, False, False, False,  True,  True,  True],
        [False, False,  True, False,  True,  True, False, False],
        [False,  True, False, False, False,  True,  True,  True]])

In [None]:
final_out[expert_mask].shape

torch.Size([14, 64])

In [171]:
class SparseMoE(nn.Module):
    def __init__(self):
        super(SparseMoE, self).__init__()
        self.router = NoisyTopkRouter(64, 4, 2)
        self.experts = nn.ModuleList([Expert() for _ in range(4)])
        self.top_k = 2

    def forward(self, x):
        gating_output, indices = self.router(x)
        final_output = torch.zeros_like(x)

        # Reshape inputs for batch processing
        flat_x = x.view(-1, x.size(-1))
        flat_gating_output = gating_output.view(-1, gating_output.size(-1))

        # Process each expert in parallel
        for i, expert in enumerate(self.experts):
            # Create a mask for the inputs where the current expert is in top-k
            expert_mask = (indices == i).any(dim=-1)
            flat_mask = expert_mask.view(-1)

            if flat_mask.any():
                expert_input = flat_x[flat_mask]
                expert_output = expert(expert_input)

                # Extract and apply gating scores
                gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)
                weighted_output = expert_output * gating_scores

                # Update final output additively by indexing and adding
                final_output[expert_mask] += weighted_output.squeeze(1)

        return final_output

In [172]:
class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.mha = MultiHeadAttention()
        self.smoe = SparseMoE()
        self.ln1 = nn.LayerNorm(64)
        self.ln2 = nn.LayerNorm(64)

    def foward(self, x):
        x = x + self.mha(self.ln1(x))
        x = x + self.smoe(self.ln2(x))
        return x