### DeepSeekMoE architecture recreation attempt

inpired by paper: https://arxiv.org/pdf/2401.06066

In [92]:
import torch
import torch.nn as nn
import torch.nn.functional as F 

### Direct computation attempt

before we try to build a class in `nn.Module` lets try to grasp what computation we are trying o achive

In [67]:
batch_size = 8
sequence_len = 10
embedding_dim = 32
num_experts = 14
m = 2
k = 3
dim_experts = embedding_dim//m
num_shared_exp = 3

# assume outputs from selfattention layer
attention_output = torch.rand([batch_size,
                               sequence_len,
                               embedding_dim], dtype=torch.float32)

attention_output.shape

torch.Size([8, 10, 32])

In [24]:
centroids = torch.rand([dim_experts,
                        num_experts*m], dtype=torch.float32)

scores = F.softmax(torch.einsum("btd,wn->btn", attention_output, centroids), dim=-1)
scores.shape

torch.Size([8, 10, 28])

In [None]:
# get topk elements

topk_values, topk_indices = torch.topk(scores, k=m*k, dim=-1)

mask = torch.zeros_like(scores).scatter_(-1, topk_indices, 1.0)

mask.shape

torch.Size([8, 10, 28])

In [38]:
shared_experts = torch.rand([num_shared_exp,
                             dim_experts,
                             embedding_dim], dtype=torch.float32)

shared_out = torch.einsum("btd,nwd->btnd", attention_output, shared_experts)
shared_out.shape

torch.Size([8, 10, 3, 32])

In [42]:
# compute expert weight multiplication

experts_weights = torch.rand([m*num_experts,
                            dim_experts,
                            embedding_dim], dtype=torch.float32)

no_filter_experts_out = torch.einsum("btd,nwd->btnd", attention_output, experts_weights)

no_filter_experts_out.shape

torch.Size([8, 10, 28, 32])

In [47]:
# filter
experts_out = mask.unsqueeze(-1) * no_filter_experts_out
experts_out.shape

torch.Size([8, 10, 28, 32])

In [48]:
total_experts_out = torch.cat((shared_out, experts_out), dim=2)
total_experts_out.shape

torch.Size([8, 10, 31, 32])

In [49]:
compressed = total_experts_out.sum(dim=2)
compressed.shape

torch.Size([8, 10, 32])

In [50]:
out = attention_output + compressed

out.shape

torch.Size([8, 10, 32])

In [86]:
class Moe(nn.Module):
    def __init__(self, num_experts, num_shared, k, expert_dim, num_tokens, embedding_dim):
        super(Moe, self).__init__()

        self.k = k

        self.centroids = torch.rand([embedding_dim,
                                     num_experts], dtype=torch.float32)
        
        self.shared_experts = torch.rand([num_shared,
                                     expert_dim,
                                     embedding_dim], dtype=torch.float32)
        
        self.expert_weights = torch.rand([num_experts,
                                     expert_dim,
                                     embedding_dim], dtype=torch.float32)
        
    def forward(self, x):
        # assume x is the output of a attention layer
        scores = F.softmax(torch.einsum("btd,dn->btn", x, self.centroids), dim=-1)

        topk_values, topk_indices = torch.topk(scores, k=self.k, dim=-1)
        
        mask = torch.zeros_like(scores).scatter_(-1, topk_indices, 1.0)

        shared_out = torch.einsum("btd,nwd->btnd", x, self.shared_experts)

        experts_out = torch.einsum("btd,nwd->btnd", x, self.expert_weights)

        routed_experts = mask.unsqueeze(-1) * experts_out

        total_experts_out = torch.cat((shared_out, routed_experts), dim=2)

        return total_experts_out.sum(dim=2), (scores, mask)

In [87]:
moe = Moe(14, 6, 3, 16, 10, 32)

moe(torch.rand([8, 10, 32], dtype=torch.float32))[0].shape

torch.Size([8, 10, 32])

### Expert Load Balancing Loss

In [88]:
# direct computation attempt
mask = moe(torch.rand([8, 10, 32], dtype=torch.float32))[1][1]

# calculating fi
sum_vec = mask.sum(1)
fi = (num_experts / k) * sum_vec
fi[0]


tensor([ 4.6667,  0.0000, 42.0000,  0.0000,  4.6667, 14.0000,  9.3333,  0.0000,
        14.0000, 32.6667,  0.0000,  0.0000,  0.0000, 18.6667])

In [93]:
# calculating pi
scores = moe(torch.rand([8, 10, 32], dtype=torch.float32))[1][0]

sum_vec = scores.sum(1)
pi = sum_vec / sequence_len
pi[0]

tensor([0.0441, 0.0244, 0.1539, 0.0340, 0.0643, 0.0730, 0.1039, 0.0457, 0.1177,
        0.1361, 0.0364, 0.0259, 0.0335, 0.1069])

In [97]:
alpha = 0.03

exp_bal_loss = alpha * (pi * fi).sum()

exp_bal_loss

tensor(4.2619)

#### in future notebooks we'll create a full loss function for a deepseek architecture language model

### for now, this is all we need to know about deepseek moe