In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [3]:
n_embd = 512
n_head = 8
n_layer = 6
seq_len = 256

In [28]:
test_input = torch.randn((3, 1, 1024))

In [29]:
def precompute_theta_frequencies(
    head_dim: int, seq_len: int, device: str, theta: float = 10000.0
):

    theta = 1.0 / (theta ** ((torch.arange(0, head_dim, 2).float())/head_dim)).to(device)
    seq_idx = torch.arange(seq_len, device=device)
    freqs = torch.outer(seq_idx, theta).float()

    freq_complex = torch.polar(torch.ones_like(freqs), freqs)

    return freq_complex


def apply_rotary_embeddings(x: torch.Tensor, freq_complex: torch.Tensor, device: str):

    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    freq_complex_align = freq_complex.unsqueeze(0).unsqueeze(2)

    x_rotated = x_complex * freq_complex_align

    x_out = torch.view_as_real(x_rotated)
    x_out = x_out.reshape(*x.shape)

    return x_out.type_as(x).to(device)

In [42]:
freq_complex = precompute_theta_frequencies(
            32, 1, device="cpu"
        )

In [43]:
freq_complex.shape

torch.Size([1, 16])

In [44]:
test_input.shape

torch.Size([3, 1, 1024])

In [104]:
class MultiHeadedLatentAttention(nn.Module):

    def __init__(self,dim, head_dim, latent_kv_dim, latent_q_dim, n_heads, decop_rot_dim):
        super().__init__()

        self.dim = dim
        self.n_heads = n_heads
        self.latent_kv_dim = latent_kv_dim
        self.latent_q_dim = latent_kv_dim
        self.head_dim =head_dim
        self.decop_rot_dim = decop_rot_dim
        self.expert_dim = latent_q_dim

        self.latent_kv = nn.Linear(self.dim, latent_kv_dim, bias=False)
        self.latent_q = nn.Linear(self.dim, latent_q_dim, bias=False)

        self.query = nn.Linear(latent_q_dim, self.n_heads * self.head_dim, bias=False)
        self.key = nn.Linear(latent_kv_dim, self.n_heads * self.head_dim, bias=False)
        self.value = nn.Linear(latent_kv_dim, self.n_heads * self.head_dim, bias=False)

        self.decop_rot_q = nn.Linear(latent_q_dim, self.n_heads * self.decop_rot_dim)
        self.decop_rot_k = nn.Linear(self.dim, self.n_heads * self.decop_rot_dim)

        self.out_proj = nn.Linear(self.head_dim * self.n_heads, self.dim)


    def forward(self, x, freq_complex: torch.Tensor):
        x: torch.Tensor = x

        batch_size, seq_len, _ = x.shape

        cq = self.latent_q(x)
        ckv = self.latent_kv(x)

        q = self.query(cq)
        qr = self.decop_rot_q(cq)

        k = self.key(ckv)
        kr = self.decop_rot_k(x)

        v = self.value(ckv)
        

        qr = qr.view(batch_size, seq_len, self.n_heads, self.decop_rot_dim)
        qr = apply_rotary_embeddings(qr, freq_complex, device=x.device)
        q = q.view(batch_size, seq_len, self.n_heads, self.head_dim)
        q = torch.cat((q, qr), dim=-1)

        kr = kr.view(batch_size, seq_len, self.n_heads, self.decop_rot_dim)
        kr = apply_rotary_embeddings(kr, freq_complex, device=x.device)
        k = k.view(batch_size, seq_len, self.n_heads, self.head_dim)
        k = torch.cat((k, kr), dim=-1)

        v = v.view(batch_size, seq_len, self.n_heads, self.head_dim)

        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        print(f"query: {q.shape} key: {k.shape} value: {v.shape}")

        att_scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(
            self.head_dim + self.decop_rot_dim
        )
        att_scores = F.softmax(att_scores.float(), dim=-1).type_as(q)

        print(f"Att Scores: {att_scores.shape}")

        output = torch.matmul(att_scores, v)

        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)

        return self.out_proj(output)

In [105]:
mla = MultiHeadedLatentAttention(head_dim=64, n_heads=64, latent_kv_dim=256, latent_q_dim=768, dim=1024, decop_rot_dim=32)

In [106]:
mla_out = mla(test_input, freq_complex)

query: torch.Size([3, 64, 1, 96]) key: torch.Size([3, 64, 1, 96]) value: torch.Size([3, 64, 1, 64])
Att Scores: torch.Size([3, 64, 1, 1])


In [108]:
mla_out.shape

torch.Size([3, 1, 1024])

In [109]:
class Expert(nn.Module):

    def __init__(self, dim, expert_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, expert_dim),
            nn.Linear(expert_dim, dim),
            nn.GELU(),
        )

    def forward(self, x):
        return self.net(x)

In [110]:
e = Expert(1024, 768)

In [111]:
e(torch.randn((3, 1, 1024)))

tensor([[[-0.0509,  0.1320, -0.0619,  ...,  0.1179,  0.2809,  0.3111]],

        [[ 0.1347, -0.0212, -0.1377,  ..., -0.0383,  0.4494,  0.0418]],

        [[ 0.1068, -0.1219, -0.0937,  ..., -0.1031,  0.0563, -0.1148]]],
       grad_fn=<GeluBackward0>)

In [None]:
class DeepSeekMoE(nn.Module):
    def __init__(self, dim, expert_dim, n_s_experts, n_r_experts, topk):
        super().__init__()
        self.dim = dim
        self.n_s_experts = n_s_experts
        self.n_r_experts = n_r_experts
        self.top_k = topk
        
        self.shared_experts = nn.ModuleList([
            Expert(dim, expert_dim) for _ in range(n_s_experts)
        ])

        self.routed_experts = nn.ModuleList([
            Expert(dim, expert_dim) for _ in range(n_r_experts)
        ])

        self.centroids = nn.Parameter(torch.randn(n_r_experts, dim))



    def forward(self, x: torch.Tensor):
        batch_size, seq_len, _ = x.shape
        shared_out = torch.zeros_like(x)
        for expert in self.shared_experts:
            shared_out += expert(x)
        
        x_flat = x.view(-1, self.dim)
        
        affinity = torch.matmul(x_flat, self.centroids.T)
        affinity = F.softmax(affinity, dim = -1)

        topk_scores, topk_indices = torch.topk(affinity, self.top_k, dim=-1)
        mask = torch.zeros_like(affinity)
        mask.scatter_(-1, topk_indices, topk_scores)

        routed_out = torch.zeros_like(x_flat)
        for i in range(self.n_r_experts):
            expert_mask = mask[:, i].unsqueeze(-1)
            #print(f"Expert Mask: {expert_mask}, Shape: {expert_mask.shape}")
            expert_out = self.routed_experts[i](x_flat)
            routed_out += expert_mask * expert_out

        routed_out = routed_out.view(batch_size, seq_len, self.dim)

        return x + shared_out + routed_out, affinity

In [200]:
d = DeepSeekMoE(dim=1024, expert_dim=768, n_s_experts=2, n_r_experts=160, topk=4)

In [201]:
mla_out.shape

torch.Size([3, 1, 1024])

In [202]:
d(mla_out)[0].shape

torch.Size([3, 160])
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4635, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0658, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0159, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
       

torch.Size([3, 1, 1024])

In [76]:
out[1]

tensor([[[58, 51, 27, 60]],

        [[27, 39, 45, 55]],

        [[11, 22, 55, 14]]])

In [78]:
out[2].shape

torch.Size([3, 1, 80])