<a href="https://colab.research.google.com/github/Papa-Panda/Paper_reading/blob/main/deepseek_v3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Bilibili video: https://www.bilibili.com/video/BV1uUPieDEK1/?spm_id_from=333.788.videopod.sections&vd_source=1fecee762931e992c96e5e166be13b76
# ChatGPT conversation: https://chatgpt.com/c/67f3f548-7238-800e-9856-a834f9003957

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Rotary positional embeddings
def apply_rotary_pos_emb(x, sin, cos):
    """Rotate pairs of features of ``x`` by the angles encoded in ``sin``/``cos``.

    The last axis of ``x`` is split into its even-indexed and odd-indexed
    entries; each (even, odd) pair is treated as a 2-D point and rotated.

    Args:
        x: Tensor whose last axis holds the features to rotate.
        sin: Sine table, broadcastable against ``x[..., ::2]``.
        cos: Cosine table, broadcastable against ``x[..., ::2]``.

    Returns:
        Tensor with all rotated "even" components first, followed by all
        rotated "odd" components, concatenated along the last axis (the
        result is not re-interleaved).
    """
    even = x[..., 0::2]
    odd = x[..., 1::2]
    rotated_even = even * cos - odd * sin
    rotated_odd = even * sin + odd * cos
    return torch.cat((rotated_even, rotated_odd), dim=-1)

def get_rotary_emb(seq_len, dim, device):
    """Build the sine and cosine tables used for rotary position embeddings.

    Args:
        seq_len: Number of positions (rows) to generate, starting at 0.
        dim: Width of the features being rotated; one frequency is produced
            for every index in ``range(0, dim, 2)``.
        device: Device on which the returned tables live.

    Returns:
        A ``(sin, cos)`` tuple; each tensor has one row per position and one
        column per frequency.
    """
    # Frequencies 10000^(-i/dim) for i = 0, 2, 4, ...; computed on the default
    # device first and then moved to the requested one.
    exponents = -torch.arange(0, dim, 2).float() / dim
    inv_freq = torch.pow(10000, exponents).to(device)

    positions = torch.arange(seq_len, device=device).float()
    # angles[p, j] = position p * frequency j
    angles = torch.outer(positions, inv_freq)
    return angles.sin(), angles.cos()

# Multi-query Attention
class MultiQueryAttention(nn.Module):
    """Multi-query self-attention with rotary position embeddings.

    Each of the ``num_heads`` query heads has its own slice of ``q_proj``,
    while a single key head and a single value head (each ``head_dim`` wide)
    are shared by all query heads.

    NOTE(review): no attention mask is applied in ``forward``; the softmax
    runs over every position, so each token also attends to later tokens.

    Args:
        dim: Model width. Must be divisible by ``num_heads``.
        num_heads: Number of query heads. Must be positive.

    Raises:
        ValueError: If ``num_heads`` is not positive, ``dim`` is not a
            multiple of ``num_heads``, or the per-head width is odd (the
            rotary embedding rotates features in pairs).
    """

    def __init__(self, dim, num_heads):
        super().__init__()
        # Fail early with a clear message instead of a cryptic shape error
        # (or silent broadcasting) inside forward().
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        head_dim = dim // num_heads
        if head_dim % 2 != 0:
            raise ValueError(
                f"head_dim ({head_dim}) must be even for rotary embeddings"
            )

        self.num_heads = num_heads
        self.head_dim = head_dim
        self.q_proj = nn.Linear(dim, dim)
        # Keys/values are projected to a single head shared by all queries.
        self.k_proj = nn.Linear(dim, self.head_dim)
        self.v_proj = nn.Linear(dim, self.head_dim)
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x, rotary_sin, rotary_cos):
        """Apply multi-query attention to ``x``.

        Args:
            x: Input of shape ``(B, T, C)`` with ``C == dim``.
            rotary_sin: Sine table broadcastable to ``(B, heads, T, head_dim // 2)``.
            rotary_cos: Cosine table with the same shape as ``rotary_sin``.

        Returns:
            Tensor of shape ``(B, T, C)``.
        """
        B, T, C = x.shape
        # (B, T, C) -> (B, num_heads, T, head_dim)
        q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        # (B, T, head_dim) -> (B, 1, T, head_dim); the size-1 head axis
        # broadcasts against every query head.
        k = self.k_proj(x).unsqueeze(1)  # shared keys
        v = self.v_proj(x).unsqueeze(1)  # shared values

        # Position information is injected into queries and keys only.
        q = apply_rotary_pos_emb(q, rotary_sin, rotary_cos)
        k = apply_rotary_pos_emb(k, rotary_sin, rotary_cos)

        # Scaled dot-product scores: (B, num_heads, T, T)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = F.softmax(att, dim=-1)
        # Merge the heads back into the channel axis: (B, T, C)
        out = (att @ v).transpose(1, 2).reshape(B, T, C)
        return self.out_proj(out)

# Transformer block
class TransformerBlock(nn.Module):
    """One transformer layer: attention and an MLP, each with a residual path.

    LayerNorm is applied to the input of each sub-layer (before attention and
    before the MLP), and the sub-layer output is added back to its input.

    Args:
        dim: Model width.
        num_heads: Number of query heads handed to ``MultiQueryAttention``.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``
            (truncated to an integer).
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)

        self.ln1 = nn.LayerNorm(dim)
        self.attn = MultiQueryAttention(dim, num_heads)
        self.ln2 = nn.LayerNorm(dim)
        # Expand -> GELU -> project back to the model width.
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x, rotary_sin, rotary_cos):
        """Run the attention and MLP sub-layers; output has the shape of ``x``."""
        attn_out = self.attn(self.ln1(x), rotary_sin, rotary_cos)
        x = x + attn_out
        mlp_out = self.mlp(self.ln2(x))
        return x + mlp_out

# Full model
class ToyDeepSeekV3(nn.Module):
    """Small transformer language model: embedding, blocks, norm, vocab head.

    Args:
        vocab_size: Number of token ids; also the width of the output logits.
        dim: Model width.
        depth: Number of ``TransformerBlock`` layers.
        num_heads: Number of query heads in every block.
        max_seq_len: Longest sequence ``forward`` accepts.
    """

    def __init__(self, vocab_size, dim, depth, num_heads, max_seq_len):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, dim)
        self.blocks = nn.ModuleList([
            TransformerBlock(dim, num_heads) for _ in range(depth)
        ])
        self.ln_f = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, vocab_size, bias=False)
        self.max_seq_len = max_seq_len
        self.dim = dim
        # Kept here so forward() does not have to reach into blocks[0],
        # which does not exist when depth == 0.
        self.num_heads = num_heads

    def forward(self, idx):
        """Compute next-token logits.

        Args:
            idx: Integer token ids of shape ``(B, T)`` with ``T <= max_seq_len``.

        Returns:
            Logits of shape ``(B, T, vocab_size)``.

        Raises:
            ValueError: If ``T`` exceeds ``max_seq_len``.
        """
        B, T = idx.shape
        if T > self.max_seq_len:
            # Without this check the mismatch only surfaces as a broadcast
            # error deep inside attention (or not at all if max_seq_len == 1).
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len {self.max_seq_len}"
            )

        x = self.token_emb(idx)

        # Only the first T positions are needed; row p of the table depends
        # only on p, so this equals slicing a max_seq_len-sized table.
        head_dim = self.dim // self.num_heads
        rotary_sin, rotary_cos = get_rotary_emb(T, head_dim, x.device)
        # Add leading batch and head axes: (1, 1, T, head_dim // 2).
        rotary_sin = rotary_sin.unsqueeze(0).unsqueeze(0)
        rotary_cos = rotary_cos.unsqueeze(0).unsqueeze(0)

        for block in self.blocks:
            x = block(x, rotary_sin, rotary_cos)
        x = self.ln_f(x)
        return self.head(x)

# Example use: build a small model and run one forward pass on random tokens.
model = ToyDeepSeekV3(
    vocab_size=1000,
    dim=256,
    depth=4,
    num_heads=4,
    max_seq_len=128,
)
# Batch of 2 sequences, 128 token ids each, drawn from [0, 1000).
dummy_input = torch.randint(low=0, high=1000, size=(2, 128))
out = model(dummy_input)  # logits: (2, 128, 1000)
print(out.shape)


torch.Size([2, 128, 1000])
