In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
import os
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import math
import requests

In [2]:
@dataclass
class Modelconfig():
    d_model = 256
    vocab_size = None
    n_head = 4
    d_compressed = 128
    head_dim = d_model // n_head  # = 32
    device = "cuda" if torch.cuda.is_available() else "cpu"
    eps = 1e-6
    batch_size = 16
    d_ff = 512
    n_shared = 2
    n_routed = 2
    top_k = 1
    n_mtp_depth = 4
    n_layers = 4
    dropout = 0.1
    lambda_mtp = 0.1
    seq_len = 32

In [3]:
class InputEmbedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.d_model = config.d_model
        self.vocab_size = config.vocab_size
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)

    def forward(self, x):
        return self.embedding(x)

In [4]:
def precompute_freqs_cis(dim, seq_len, theta=10000):
    freqs = 1.0 / theta ** (torch.arange(0, dim, 2).float() / dim)
    pos = torch.arange(seq_len)
    angles = torch.outer(pos, freqs)  
    return torch.polar(torch.ones_like(angles), angles)

def apply_rotary_embed(x, freqs_cis, device):
    batch, seq_len, d_model = x.shape
    x_complex = torch.view_as_complex(x.float().reshape(batch, seq_len, -1, 2))
    
    x_complex = x_complex.unsqueeze(2) 
    freqs_cis = freqs_cis[:seq_len]     
    freqs_cis = freqs_cis.unsqueeze(0).unsqueeze(2)
    x_rotated = x_complex * freqs_cis    
    x_rotated = x_rotated.squeeze(2)    
    x_out = torch.view_as_real(x_rotated)  
    x_out = x_out.reshape(batch, seq_len, d_model)
    return x_out.type_as(x).to(device)


In [5]:
class RMSNorm(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.eps = config.eps 
        self.weights = nn.Parameter(torch.ones(config.d_model))

    def forward(self, x):
        mean = torch.mean(x ** 2, dim=-1, keepdim=True)
        rms = torch.sqrt(mean + self.eps)
        return (x / rms) * self.weights

In [6]:
class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.d_model = config.d_model      
        self.n_head = config.n_head           
        self.d_compressed = config.d_compressed 
        self.head_dim = config.head_dim      

        self.register_buffer("cache_k", torch.zeros(0, 0, 0, 0))
        self.register_buffer("cache_v", torch.zeros(0, 0, 0, 0))
        
        self.q_c = nn.Linear(self.d_model, self.d_compressed, bias=False)
        self.q_r = nn.Linear(self.d_compressed, self.d_model, bias=False)
        self.w_q = nn.Linear(self.d_model, self.d_model, bias=False)
        self.q_u = nn.Linear(self.d_compressed, self.d_model, bias=False)

        self.k_c = nn.Linear(self.d_model, self.d_compressed, bias=False)
        self.k_u = nn.Linear(self.d_compressed, self.d_model, bias=False)
        self.k_r = nn.Linear(self.d_compressed, self.d_model, bias=False)
        self.w_k = nn.Linear(self.d_model, self.d_model, bias=False)

        self.v_c = nn.Linear(self.d_model, self.d_compressed, bias=False)
        self.v_u = nn.Linear(self.d_compressed, self.d_model, bias=False)
        self.w_v = nn.Linear(self.d_model, self.d_model, bias=False)

        self.w_o = nn.Linear(self.d_model, self.d_model, bias=False)
        
    def reset_cache(self):
        self.cache_k = torch.zeros(0, 0, 0, 0, device=self.w_q.weight.device)
        self.cache_v = torch.zeros(0, 0, 0, 0, device=self.w_q.weight.device)

    def forward(self, x, freqs_cis, start_pos):
        batch_size, seq_len, _ = x.shape
        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)
        q = q.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)

        q_flat = q.reshape(batch_size, seq_len, -1)
        k_flat = k.reshape(batch_size, seq_len, -1)
        v_flat = v.reshape(batch_size, seq_len, -1)

        compressed_q = self.q_c(q_flat)
        up_q = self.q_u(compressed_q)
        temp_q = self.q_r(compressed_q).view(batch_size, seq_len, self.n_head, self.head_dim)
        rope_q = apply_rotary_embed(temp_q.reshape(batch_size, seq_len, -1), freqs_cis, device=x.device)
        q_out = up_q + rope_q 

        compressed_k = self.k_c(k_flat)
        up_k = self.k_u(compressed_k)
        temp_k = self.k_r(compressed_k).view(batch_size, seq_len, self.n_head, self.head_dim)
        rope_k = apply_rotary_embed(temp_k.reshape(batch_size, seq_len, -1), freqs_cis, device=x.device)
        k_out = up_k + rope_k  
        
        compressed_v = self.v_c(v_flat)
        v_out = self.v_u(compressed_v) 

        if self.cache_k.numel() == 0 or self.cache_k.shape[0] != batch_size:
            self.cache_k = k_out.clone()
            self.cache_v = v_out.clone()
        else:
            self.cache_k = torch.cat([self.cache_k, k_out], dim=1)
            self.cache_v = torch.cat([self.cache_v, v_out], dim=1)
        k_cat = self.cache_k[:, start_pos:start_pos+seq_len]
        v_cat = self.cache_v[:, start_pos:start_pos+seq_len]
        k_cat = k_cat.view(batch_size, seq_len, self.n_head, self.head_dim).detach()
        v_cat = v_cat.view(batch_size, seq_len, self.n_head, self.head_dim).detach()

        k_cat = k_cat.reshape(batch_size, seq_len, -1)
        v_cat = v_cat.reshape(batch_size, seq_len, -1)


        scores = (q_out @ k_cat.transpose(-1, -2)) / math.sqrt(self.head_dim)
        if self.training or start_pos == 0:
            mask = torch.triu(torch.ones(seq_len, seq_len, device=scores.device), diagonal=1).bool()
            scores = scores.masked_fill(mask, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        o = attn_weights @ v_cat
        o = self.w_o(o)
        return o

In [7]:
class Expert(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.fc1 = nn.Linear(config.d_model, config.d_ff)
        self.fc2 = nn.Linear(config.d_ff, config.d_model)
        self.act = nn.GELU()

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

class DeepSeekMOE(nn.Module):
    def __init__(self, config, gamma=0.01, alpha=1e-5):
        super().__init__()
        self.d_model = config.d_model
        self.n_shared = config.n_shared
        self.n_routed = config.n_routed
        self.top_k = config.top_k
        self.gamma = gamma
        self.alpha = alpha

        self.shared_experts = nn.ModuleList([Expert(config) for _ in range(config.n_shared)])
        self.routed_experts = nn.ModuleList([Expert(config) for _ in range(config.n_routed)])
        self.expert_centroids = nn.Parameter(torch.randn(config.n_routed, self.d_model))
        self.register_buffer('bias', torch.zeros(config.n_routed))

    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        shared_output = sum(expert(x) for expert in self.shared_experts)
        scores = torch.sigmoid(x @ self.expert_centroids.T)
        adjusted_scores = scores + self.bias[None, None, :]
        topk_scores, topk_indices = torch.topk(adjusted_scores, k=self.top_k, dim=-1)
        mask = torch.zeros_like(scores).scatter_(-1, topk_indices, 1.0)
        g_prime = scores * mask
        g = g_prime / (g_prime.sum(dim=-1, keepdim=True) + 1e-6)
        experts_output = torch.stack([expert(x) for expert in self.routed_experts], dim=2)
        routed_output = (experts_output * g.unsqueeze(-1)).sum(dim=2)
        output = x + shared_output + routed_output

        if self.training:
            expert_counts = mask.sum(dim=(0, 1))
            total_tokens = batch_size * seq_len
            expected = (total_tokens * self.top_k) / self.n_routed
            delta = torch.where(expert_counts > expected, -self.gamma, self.gamma)
            with torch.no_grad():
                self.bias.add_(delta)
            s_prime = F.softmax(scores, dim=-1)
            p_i = s_prime.mean(dim=(0, 1))
            f_i = (self.n_routed / (self.top_k * total_tokens)) * expert_counts
            loss = self.alpha * torch.sum(f_i * p_i)
            return output, loss
        return output

In [8]:
class ProjectionLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.d_model = config.d_model
        self.display = nn.Linear(self.d_model, config.vocab_size)
        
    def forward(self, x):
        return self.display(x)

In [9]:
class DeepSeekEncoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_layers = config.n_layers
        self.mha = MultiHeadAttention(config)
        self.moe = DeepSeekMOE(config)
        self.norm1 = RMSNorm(config)
        self.norm2 = RMSNorm(config)
        self.dropout1 = nn.Dropout(config.dropout)
        self.dropout2 = nn.Dropout(config.dropout)
        
    def forward(self, x, freqs_cis, start_pos):
        attn_out = self.mha(x, freqs_cis, start_pos)
        x = x + self.dropout1(attn_out)
        x = x + self.norm1(x)
        if self.training:
            moe_out, moe_loss = self.moe(x)
        else:
            moe_out = self.moe(x)
        x = x + self.dropout2(moe_out)
        x = self.norm2(x)
        if self.training:
            return x, moe_loss
        return x, torch.tensor(0.0, device=x.device)
        

In [10]:
class DeepSeekV3(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.d_model = config.d_model
        self.n_mtp_depth = config.n_mtp_depth
        self.lambda_mtp = config.lambda_mtp

        self.embedding = InputEmbedding(config)
        self.output_head = ProjectionLayer(config)

        self.encoder_layer = nn.ModuleList([DeepSeekEncoderLayer(config) for _ in range(config.n_layers)])
        self.mtp_projs = nn.ModuleList([nn.Linear(2 * self.d_model, self.d_model) for _ in range(self.n_mtp_depth)])
        self.mtp_trms = nn.ModuleList([DeepSeekEncoderLayer(config) for _ in range(self.n_mtp_depth)])
        self.rms_norm = nn.ModuleList([RMSNorm(config) for _ in range(self.n_mtp_depth)])

    def forward(self, input_ids, freqs_cis, start_pos):
        batch_size, seq_len = input_ids.shape
        original_ids = input_ids.clone() 
        x = self.embedding(input_ids)     
        
        balance_losses = []

        for layer in self.encoder_layer:
            if self.training:
                x, balance_loss = layer(x, freqs_cis, start_pos)
                balance_losses.append(balance_loss.detach())
            else:
                x, _ = layer(x, freqs_cis, start_pos)
        
        main_logits = self.output_head(x)   

        if self.training:
            with torch.no_grad():
                h_prev = x
            mtp_losses = []
            for depth in range(self.n_mtp_depth):
                if h_prev.size(1) < depth + 2:
                    break
                future_emb = self.embedding(original_ids[:, depth+1:])
                h_trimmed = h_prev[:, :-(depth+1), :]
                h_norm = self.rms_norm[depth](h_trimmed)
                f_norm = self.rms_norm[depth](future_emb)
                combined = torch.cat([h_norm, f_norm], dim=-1)
                projected = self.mtp_projs[depth](combined)
                mtp_output, _ = self.mtp_trms[depth](projected, freqs_cis, start_pos)
                mtp_logits = self.output_head(mtp_output)
                targets = original_ids[:, depth+1:]
                loss = F.cross_entropy(mtp_logits.view(-1, mtp_logits.size(-1)), 
                                       targets.reshape(-1))
                mtp_losses.append(loss)
            total_mtp_loss = self.lambda_mtp * sum(mtp_losses) + sum(balance_losses)
        else:
            total_mtp_loss = torch.tensor(0.0, device=x.device)

        return main_logits, total_mtp_loss

In [11]:
import os
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
import random
import numpy as np

seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

DATA_PATH = "tinyshakespeare.txt"
if not os.path.exists(DATA_PATH):
    DATA_URL = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    print("Downloading Tiny Shakespeare...")
    r = requests.get(DATA_URL)
    with open(DATA_PATH, "w", encoding="utf-8") as f:
        f.write(r.text)
    print("Download complete.")

with open(DATA_PATH, "r", encoding="utf-8") as f:
    lines = f.read().splitlines()

text = " <eos> ".join(lines)
tokens = text.split() 
text = " ".join(tokens)

special_tokens = ["<eos>", "<pos>"]

text_tokens = text.split()
vocab = special_tokens + sorted(list(set(text_tokens) - set(special_tokens)))
print(f"Vocabulary size: {len(vocab)}")

stoi = {token: i for i, token in enumerate(vocab)}
itos = {i: token for i, token in enumerate(vocab)}

class ShakespeareDataset(Dataset):
    def __init__(self, text, block_size, stoi):
        self.data = [stoi[token] for token in text.split() if token in stoi]
        self.block_size = block_size
    def __len__(self):
        return len(self.data) - self.block_size
    def __getitem__(self, idx):
        return torch.tensor(self.data[idx:idx+self.block_size], dtype=torch.long)

def generate_text(model, start_text, length, temperature=1.0, device='cuda'):
    model.eval()
    for layer in model.encoder_layer:
        layer.mha.reset_cache()
    
    context = torch.tensor([stoi[token] for token in start_text.split() if token in stoi],
                           dtype=torch.long).unsqueeze(0).to(device)
    with torch.no_grad():
        for _ in range(length):
            if context.size(1) > config.seq_len:
                context = context[:, -config.seq_len:]
            seq_len = context.size(1)
            freqs_cis = precompute_freqs_cis(config.d_model, seq_len).to(device)
            logits, _ = model(context, freqs_cis, start_pos=0)
            logits = logits[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            context = torch.cat([context, next_token], dim=1)
    
    return " ".join([itos[int(idx)] for idx in context.squeeze().tolist()])

def main():
    block_size = config.seq_len 
    batch_size = config.batch_size
    dataset = ShakespeareDataset(text, block_size, stoi)
    
    dataset_len = len(dataset)
    val_len = int(0.1 * dataset_len)
    train_len = dataset_len - val_len
    train_dataset, val_dataset = random_split(dataset, [train_len, val_len])
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    
    model = DeepSeekV3(config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    optimizer = optim.Adam(model.parameters(), lr=2e-4, weight_decay=1e-5)
    num_epochs = 10
    
    for epoch in range(num_epochs):
        model.train()
        for layer in model.encoder_layer:
            layer.mha.reset_cache()
        total_train_loss = 0.0
        train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1} [Training]", unit="batch")
        for i, batch in enumerate(train_bar):
            batch = batch.to(device)
            optimizer.zero_grad()
            seq_len = config.seq_len
            freqs_cis = precompute_freqs_cis(config.d_model, seq_len).to(device)
            start_pos = 0
            main_logits, total_mtp_loss = model(batch, freqs_cis, start_pos)
            loss_main = F.cross_entropy(main_logits.view(-1, config.vocab_size),
                                        batch.view(-1))
            loss = loss_main + total_mtp_loss
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()
            train_bar.set_postfix(loss=f"{loss.item():.4f}")
            if (i + 1) % 50 == 0:
                sample = generate_text(model, "Care for", 50, temperature=1.5, device=device)
                train_bar.set_postfix(loss=f"{loss.item():.4f}", sample=sample[:50] + "...")
        avg_train_loss = total_train_loss / len(train_loader)
        print(f"\nEpoch {epoch+1}, Average Training Loss: {avg_train_loss:.4f}")
        
        model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch+1} [Validation]", unit="batch"):
                batch = batch.to(device)
                seq_len = config.seq_len
                freqs_cis = precompute_freqs_cis(config.d_model, seq_len).to(device)
                start_pos = 0
                main_logits, total_mtp_loss = model(batch, freqs_cis, start_pos)
                loss_main = F.cross_entropy(main_logits.view(-1, config.vocab_size),
                                            batch.view(-1))
                loss = loss_main + total_mtp_loss
                total_val_loss += loss.item()
        avg_val_loss = total_val_loss / len(val_loader)
        print(f"Epoch {epoch+1}, Average Validation Loss: {avg_val_loss:.4f}")
        
        generated_text = generate_text(model, "Care for", 300, temperature=1.5, device=device)
        print("\nGenerated Text:")
        print(generated_text)
        torch.cuda.empty_cache()

if __name__ == "__main__":
    config = Modelconfig()
    config.vocab_size = len(vocab)
    main()


Vocabulary size: 25672


Epoch 1 [Training]: 100%|█| 13648/13648 [06:02<00:00, 37.68batch/s, loss=0.0049]



Epoch 1, Average Training Loss: 0.2994


Epoch 1 [Validation]: 100%|██████████████| 1517/1517 [01:02<00:00, 24.08batch/s]


Epoch 1, Average Validation Loss: 0.0098

Generated Text:
bowsprit, bowsprit, bowsprit, o'er-run first. first. first. first. didst, graves, graves, executed benched images, Deposed Clare. Clare. Clare. read'st, read'st, His His His dream, guard, boy's prize prize prize prize prize Lucentio rapture


Epoch 2 [Training]: 100%|█| 13648/13648 [06:15<00:00, 36.38batch/s, loss=0.0012]



Epoch 2, Average Training Loss: 0.0156


Epoch 2 [Validation]: 100%|██████████████| 1517/1517 [01:05<00:00, 23.26batch/s]


Epoch 2, Average Validation Loss: 0.0066

Generated Text:
vow vow vow vow vow vow vow vow vow vow vow vow vow vow vow vow vow vow vow vow vow vow vow vow vow vow vow vow linen linen linen linen linen


Epoch 3 [Training]: 100%|█| 13648/13648 [06:10<00:00, 36.88batch/s, loss=0.0023]



Epoch 3, Average Training Loss: 0.0137


Epoch 3 [Validation]: 100%|██████████████| 1517/1517 [01:06<00:00, 22.97batch/s]


Epoch 3, Average Validation Loss: 0.0068

Generated Text:
Bohemia: Bohemia: Bohemia: Bohemia: womb, womb, womb, womb, womb, womb, womb, what? what? what? what? what? what? what? what? what? what? what? what? what? what? what? what? what? what? what? Ravenspurgh; Ravenspurgh; motion


Epoch 4 [Training]: 100%|█| 13648/13648 [06:16<00:00, 36.29batch/s, loss=0.0052]



Epoch 4, Average Training Loss: 0.0134


Epoch 4 [Validation]: 100%|██████████████| 1517/1517 [01:05<00:00, 23.05batch/s]


Epoch 4, Average Validation Loss: 0.0069

Generated Text:
Whiles Whiles Whiles Whiles Whiles Whiles join'd. join'd. god, god, god, god, princess,--goddess!--O, usurers; usurers; usurers; delicates, delicates, kept! sighs; sighs; constable? constable? constable? constable? desert! desert! desert! desert! shrieks 'cum 'cum 'cum


Epoch 5 [Training]: 100%|█| 13648/13648 [06:18<00:00, 36.07batch/s, loss=0.0008]



Epoch 5, Average Training Loss: 0.0138


Epoch 5 [Validation]: 100%|██████████████| 1517/1517 [01:08<00:00, 22.28batch/s]


Epoch 5, Average Validation Loss: 0.0073

Generated Text:
French French French French earnestly grief: grief: allies allies allies allies what what what what what what what what what what what what what what what what what what what what what what


Epoch 6 [Training]:   9%|▏ | 1239/13648 [00:35<05:55, 34.88batch/s, loss=0.0038]


KeyboardInterrupt: 