In [197]:
import torch
import os
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import math

In [198]:
# Pick the compute device once for the whole notebook: Apple's Metal backend ("mps")
# when PyTorch reports it as available, plain CPU otherwise.
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

### Hyperparameters

In [199]:
# Global hyperparameters for the notebook.
# NOTE(review): none of these constants is referenced by the cells visible in this part of
# the notebook; the per-constant descriptions below are inferred from the names only and
# should be confirmed against the cells that consume them.

# Presumably the maximum sequence length in tokens. The name is misspelled ("LENGHT");
# the spelling is kept so that any existing reference keeps resolving.
SEQ_LENGHT = 4096
# Presumably the tokenizer vocabulary size — TODO confirm against the tokenizer in use.
VOCAB_SIZE = 50304
# Presumably the embedding / residual-stream width.
EMBEDDING_DIM = 1024
# Presumably the number of attention heads per block.
NUM_HEADS = 16
# Presumably the number of transformer blocks.
NUM_BLOCKS = 16
# Presumably the number of sequences per batch.
BATCH_SIZE = 128
# Presumably the number of experts in each mixture-of-experts layer.
NUM_EXPERTS = 64
# Presumably the number of experts each token is routed to.
TOP_K_EXPERTS = 8

### Dataloader

In [200]:
def load_tokens(filename):
    """Load a token shard from a NumPy file and return it as a ``torch.long`` tensor.

    Args:
        filename: path (or file object) accepted by ``np.load``.

    Returns:
        A tensor of dtype ``torch.long`` holding the array's values after an
        intermediate cast to ``int32``.
    """
    shard = np.load(filename).astype(np.int32)
    return torch.tensor(shard, dtype=torch.long)

In [228]:
class DataLoader:
    """Sequential loader that turns token shards into next-token-prediction batches.

    Shards are the entries of ``data_root`` whose file name contains ``split``. They are
    visited in sorted file-name order and the loader wraps around to the first shard
    after the last one.

    Args:
        B: number of sequences per batch.
        T: number of tokens per sequence.
        split: ``'train'`` or ``'val'``; used as a substring filter on file names.
        data_root: directory that contains the shard files.

    Raises:
        FileNotFoundError: if no shard in ``data_root`` matches ``split``.
        ValueError: if a shard holds fewer than ``B * T + 1`` tokens (not enough
            for a single batch plus the shifted targets).
    """

    def __init__(self, B, T, split, data_root):
        self.B = B
        self.T = T
        assert split in {'train', 'val'}

        shards = sorted(s for s in os.listdir(data_root) if split in s)
        self.shards = [os.path.join(data_root, s) for s in shards]
        # Fail here with a clear message instead of printing a warning and then
        # crashing with an IndexError inside reset().
        if not self.shards:
            raise FileNotFoundError(f"no shards found for split {split} in {data_root}")
        print(f"[green]found {len(shards)} shards for split {split}[/green]")
        self.reset()

    def _load_shard(self, index):
        """Load shard ``index`` and verify it can supply at least one full batch."""
        tokens = load_tokens(self.shards[index])
        needed = self.B * self.T + 1
        if len(tokens) < needed:
            raise ValueError(
                f"shard {self.shards[index]} has {len(tokens)} tokens, "
                f"need at least {needed} for B={self.B}, T={self.T}"
            )
        return tokens

    def reset(self):
        """Rewind to the first token of the first shard."""
        self.current_shard = 0
        self.tokens = self._load_shard(self.current_shard)
        # Start at the beginning of the shard. (Starting at B * T silently skipped the
        # first batch of every shard.)
        self.current_position = 0

    def next_batch(self):
        """Return the next ``(inputs, targets)`` pair, each of shape ``(B, T)``.

        ``targets`` is ``inputs`` shifted left by one token. When the following batch
        would run past the end of the current shard, the loader advances to the next
        shard (wrapping around) and restarts at its first token.
        """
        B, T = self.B, self.T
        start = self.current_position
        # One extra token so that the targets can be shifted by one position.
        buf = self.tokens[start : start + B * T + 1]
        x = buf[:-1].view(B, T)  # inputs
        y = buf[1:].view(B, T)   # targets
        self.current_position += B * T
        if self.current_position + (B * T + 1) > len(self.tokens):
            self.current_shard = (self.current_shard + 1) % len(self.shards)
            self.tokens = self._load_shard(self.current_shard)
            self.current_position = 0
        return x, y

### Rotary Position Embedding

In [202]:
class RotaryPositionEmbedding(nn.Module):
    """Rotary position embedding (RoPE) over the last dimension of the input.

    The input is expected to have layout ``(..., seq_len, rot_dim)``, e.g.
    ``(batch, heads, seq_len, head_size)``. Channel pairs ``(x[2i], x[2i+1])`` are
    rotated by the angle ``position * 10000 ** (-2i / rot_dim)``. The output keeps the
    input shape but is laid out as ``[rotated even channels | rotated odd channels]``
    along the last dimension.

    Args:
        dim: rotary width used to precompute the cos/sin tables. If the tensor passed
            to ``forward`` has a different last dimension, the tables are rebuilt for
            that width on first use, so the rotation always matches the actual input.
        max_seq_lenght: largest sequence length the tables cover.
    """

    def __init__(self, dim, max_seq_lenght):
        super().__init__()
        self.max_seq_len = max_seq_lenght
        cos, sin = self._build_tables(dim, self.max_seq_len)
        # Non-persistent buffers: they follow model.to(device) but stay out of the
        # state_dict, so checkpoints keep the same keys as before.
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)

    @staticmethod
    def _build_tables(dim, max_seq_len):
        """Return cos/sin tables of shape ``(max_seq_len, dim // 2)``."""
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        pos = torch.arange(max_seq_len).float()
        freqs = torch.einsum("i,j->ij", pos, inv_freq)
        return torch.cos(freqs), torch.sin(freqs)

    def forward(self, x):
        """Rotate ``x`` of shape ``(..., seq_len, rot_dim)``.

        Raises:
            ValueError: if ``rot_dim`` is odd or ``seq_len`` exceeds ``max_seq_len``.
        """
        seq_len, rot_dim = x.size(-2), x.size(-1)
        if rot_dim % 2 != 0:
            raise ValueError(f"rotary dimension must be even, got {rot_dim}")
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}"
            )
        self.seq_len = seq_len

        # The tables must match the width of the tensor actually being rotated.
        # Rebuild (and cache) them if the module was constructed with another width.
        if self.cos.size(-1) != rot_dim // 2:
            cos, sin = self._build_tables(rot_dim, self.max_seq_len)
            self.cos, self.sin = cos.to(x.device), sin.to(x.device)

        # (seq_len, rot_dim // 2): broadcasts over every leading dimension of x, so each
        # batch element and head gets the same position-dependent rotation.
        cos = self.cos[:seq_len].to(x.device)
        sin = self.sin[:seq_len].to(x.device)

        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

### SwiGLU

In [203]:
class SwiGLU(nn.Module):
    """Gated feed-forward unit: ``linear2(swish(gate) * value)``.

    ``linear1`` produces ``2 * hidden_dimension`` features that are split in half along
    the last dimension: the first half is the value, the second half is the gate, which
    is passed through ``g * sigmoid(g)``.

    Args:
        input_dimension: size of the input (and output) feature dimension.
        hidden_dimension: size of each of the two halves produced by ``linear1``.
    """

    def __init__(self, input_dimension, hidden_dimension):
        super().__init__()
        # Joint projection for value and gate.
        self.linear1 = nn.Linear(input_dimension, hidden_dimension * 2, bias=True)
        # Projection of the gated hidden features back to the input width.
        self.linear2 = nn.Linear(hidden_dimension, input_dimension, bias=True)

    def forward(self, x):
        value, gate = torch.chunk(self.linear1(x), 2, dim=-1)
        activated_gate = gate * torch.sigmoid(gate)
        return self.linear2(activated_gate * value)

### RMSNorm

In [204]:
class RMSNorm(nn.Module):
    """Root-mean-square layer normalisation with a learnable gain and bias.

    Computes ``x / sqrt(mean(x ** 2, dim=-1) + eps) * g + b``.

    Args:
        input_shape: shape of the learnable gain ``g`` and bias ``b`` (broadcast
            against the normalised input).
        eps: constant added to the mean square for numerical stability.
    """

    def __init__(self, input_shape, eps=1e-6):
        super().__init__()
        self.g = nn.Parameter(torch.ones(input_shape))
        # The bias starts at zero so that a freshly built layer is a pure normalisation.
        # (Initialising it to ones shifted every normalised activation by +1.)
        self.b = nn.Parameter(torch.zeros(input_shape))
        self.eps = eps

    def forward(self, x):
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return (x / rms) * self.g + self.b

### Attention Mechanism

In [205]:
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with L2-normalised queries/keys and RoPE.

    Queries and keys are normalised to unit length per head. Queries are then scaled by
    a learnable per-head factor ``alpha * sqrt(head_size)``; the ``sqrt(head_size)``
    cancels the default ``1 / sqrt(head_size)`` scaling applied inside
    ``F.scaled_dot_product_attention``, so the attention logits are ``alpha`` times the
    cosine similarity of query and key.

    Args:
        n_embed: embedding width ``C``; must be divisible by ``n_head``.
        n_head: number of attention heads.
        max_seq_lenght: maximum sequence length handed to the rotary embeddings.
        eps: constant added to the q/k norms to avoid division by zero.

    Raises:
        ValueError: if ``n_embed`` is not divisible by ``n_head``.
    """

    def __init__(self, n_embed, n_head, max_seq_lenght, eps=1e-5):
        super().__init__()
        if n_embed % n_head != 0:
            raise ValueError(
                f"n_embed ({n_embed}) must be divisible by n_head ({n_head})"
            )
        self.n_embd = n_embed
        self.n_head = n_head
        self.max_seq_lenght = max_seq_lenght
        self.eps = eps
        # key, query, value projections for all heads, computed in one matmul
        self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd)
        # learnable per-head temperature applied to the normalised queries
        self.alpha = nn.Parameter(torch.ones(self.n_head))
        # output projection
        self.c_proj = nn.Linear(self.n_embd, self.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1
        # NOTE(review): the rotary tables are built with dim=n_embd although RoPE is
        # applied per head (head_size channels) — confirm RotaryPositionEmbedding
        # handles that width correctly.
        self.rope_q = RotaryPositionEmbedding(self.n_embd, self.max_seq_lenght)
        self.rope_k = RotaryPositionEmbedding(self.n_embd, self.max_seq_lenght)

    def forward(self, x):
        """Apply attention to ``x`` of shape ``(B, T, C)`` and return ``(B, T, C)``."""
        B, T, C = x.size()  # batch size, sequence length, embedding width (n_embd)
        head_size = C // self.n_head

        # Project once, then split into q, k, v and move the head axis forward:
        # (B, T, C) -> (B, n_head, T, head_size)
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        q = q.view(B, T, self.n_head, head_size).transpose(1, 2)
        k = k.view(B, T, self.n_head, head_size).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_size).transpose(1, 2)

        # Unit-normalise queries and keys per head.
        q_hat = q / (torch.norm(q, dim=-1, keepdim=True) + self.eps)
        k_hat = k / (torch.norm(k, dim=-1, keepdim=True) + self.eps)

        factor = (self.alpha * math.sqrt(head_size)).view(1, self.n_head, 1, 1)
        q_scaled = q_hat * factor

        q_scaled = self.rope_q(q_scaled)
        # Rotate the KEYS here. Feeding q_scaled into rope_k replaced the keys with a
        # re-rotated copy of the queries and discarded the key projection entirely.
        k_hat = self.rope_k(k_hat)
        y = F.scaled_dot_product_attention(q_scaled, k_hat, v, is_causal=True, dropout_p=0.0)

        # Re-assemble all head outputs side by side: (B, n_head, T, hs) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)

In [206]:
# Smoke test: run CausalSelfAttention on a random tensor and inspect the output shape.
# The input passed in is (batch, sequence length, embedding dim); inside the module the
# heads are split out as (batch, number of heads, sequence length, head size).
# NOTE(review): `input` shadows the Python builtin of the same name for the rest of the
# kernel session; the name is left unchanged here in case other cells rely on it.

input = torch.randn([1, 128, 256])
# 256-wide embeddings, 4 heads, max sequence length taken from the input (128).
attention = CausalSelfAttention(256, 4, input.size(1))
output = attention(input)
output.shape

torch.Size([1, 128, 256])

### Expert

In [207]:
class Expert(nn.Module):
    """Single expert network: Linear -> SwiGLU -> Linear.

    The input is projected from ``embed_dim`` to ``4 * embed_dim``, passed through a
    ``SwiGLU(4 * embed_dim, 4 * embed_dim)`` module, and projected back to ``embed_dim``.

    Args:
        embed_dim: width of the expert's input and output.
    """

    def __init__(self, embed_dim):
        super().__init__()
        hidden_dim = 4 * embed_dim
        # Sub-modules are created in the same order they are applied.
        up_proj = nn.Linear(embed_dim, hidden_dim)
        gated = SwiGLU(hidden_dim, hidden_dim)
        down_proj = nn.Linear(hidden_dim, embed_dim)
        self.expert = nn.Sequential(up_proj, gated, down_proj)

    def forward(self, x):
        transformed = self.expert(x)
        return transformed

### Router

In [208]:
class Router(nn.Module):
    """Token router: a linear layer followed by a softmax over the experts.

    Args:
        num_experts: number of experts, i.e. the size of the output distribution.
        embed_dim: size of the last dimension of the input.
    """

    def __init__(self, num_experts, embed_dim):
        super().__init__()
        self.num_experts, self.embed_dim = num_experts, embed_dim

        gate = nn.Linear(self.embed_dim, self.num_experts)
        to_probs = nn.Softmax(dim=-1)
        self.router = nn.Sequential(gate, to_probs)

    def forward(self, x):
        """Return per-expert probabilities along the last dimension of the output."""
        routing_probs = self.router(x)
        return routing_probs

In [209]:
# x = torch.randn([1, 128, 256])
# router = Router(12, 128)
# router(x)

### Block

In [253]:
class Block(nn.Module):
    """Transformer block: pre-norm causal self-attention followed by a top-k MoE layer.

    Args:
        embed_dim: width of the residual stream.
        num_heads: number of attention heads.
        num_experts: number of experts in the mixture-of-experts layer.
        max_seq_lenght: maximum sequence length handed to the attention module.
        top_k: number of experts each token is routed to (default 8). Values larger
            than ``num_experts`` are clamped to ``num_experts``.

    Raises:
        ValueError: if ``top_k`` is smaller than 1.
    """

    def __init__(self, embed_dim, num_heads, num_experts, max_seq_lenght, top_k=8):
        super(Block, self).__init__()
        if top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_experts = num_experts
        self.max_seq_lenght = max_seq_lenght
        # torch.topk raises when k exceeds the number of candidates, so clamp.
        self.top_k = min(top_k, num_experts)

        self.RMSNorm = RMSNorm(self.embed_dim)
        self.MultiheadAttention = CausalSelfAttention(self.embed_dim, self.num_heads, self.max_seq_lenght)
        self.router = Router(self.num_experts, self.embed_dim)
        self.experts = nn.ModuleList([Expert(self.embed_dim) for _ in range(self.num_experts)])

    def forward(self, x):
        """Run the block on ``x`` of shape ``(B, T, embed_dim)``.

        Returns:
            A tuple ``(x, topk_indices, topk_probs)`` where ``x`` is the updated
            residual stream, ``topk_indices`` holds the selected expert ids with shape
            ``(B, T, top_k)`` and ``topk_probs`` the matching router probabilities,
            renormalised to sum to 1 over the selected experts.
        """
        x = x + self.MultiheadAttention(self.RMSNorm(x))

        routes = self.router(x)  # (B, T, num_experts)
        topk_probs, topk_indices = torch.topk(routes, k=self.top_k, dim=-1)
        topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)

        expert_output = torch.zeros_like(x)
        # One pass over the experts. topk returns distinct indices per token, so a
        # token selects a given expert in at most one of its top-k slots; each expert
        # therefore runs once on all of its tokens instead of once per slot.
        for expert_id, expert in enumerate(self.experts):
            slot_mask = topk_indices == expert_id       # (B, T, top_k)
            token_mask = slot_mask.any(dim=-1)          # (B, T)
            if not token_mask.any():
                continue
            # Routing weight of this expert per token (zero where it was not selected).
            weight = (topk_probs * slot_mask).sum(dim=-1)  # (B, T)
            expert_out = expert(x[token_mask])
            expert_output[token_mask] += expert_out * weight[token_mask].unsqueeze(-1)

        x = x + expert_output

        return x, topk_indices, topk_probs

### Final Model

In [251]:
class Model(nn.Module):
    """Language model: token embedding -> Block -> RMSNorm -> vocabulary projection.

    All sub-modules are created on the default device; move the whole model with
    ``model.to(device)`` so that every parameter ends up on the same device.

    Args:
        vocab_size: number of token ids (embedding rows and output logits).
        embed_dim: width of the residual stream.
        max_seq_lenght: maximum sequence length accepted by ``forward``.
        num_heads: number of attention heads in the block.
        num_experts: number of experts in the block.
    """

    def __init__(self, vocab_size, embed_dim, max_seq_lenght, num_heads, num_experts):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.max_seq_lenght = max_seq_lenght
        self.num_heads = num_heads
        self.num_experts = num_experts

        # No per-layer .to(device) here: moving only the embedding left the model's
        # parameters split across devices whenever the global device was not the CPU.
        self.embedding = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.embed_dim, dtype=torch.float32)

        self.blocks = Block(self.embed_dim, self.num_heads, self.num_experts, self.max_seq_lenght)
        self.rmsnorm = RMSNorm(self.embed_dim)
        self.output_linear = nn.Linear(self.embed_dim, self.vocab_size)

    def forward(self, x):
        """Map token ids ``x`` of shape ``(B, T)`` to logits.

        Returns:
            A tuple ``(logits, top8_indicies, top8_probs)``: ``logits`` has shape
            ``(B, T, vocab_size)``; the other two values are the expert indices and
            routing probabilities returned by the block.

        Raises:
            ValueError: if ``T`` exceeds ``max_seq_lenght``.
        """
        B, T = x.shape
        if T > self.max_seq_lenght:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_lenght {self.max_seq_lenght}"
            )
        x = self.embedding(x)
        # The block already returns the updated residual stream (its input plus the
        # attention and expert outputs). Adding it onto the embeddings again counted
        # the embeddings twice and mutated the embedding output in place.
        x, top8_indicies, top8_probs = self.blocks(x)
        output = self.rmsnorm(x)
        output = self.output_linear(output)

        return output, top8_indicies, top8_probs

## Loss functions

In [None]:
def load_balancing_loss(num_experts: int,
                        topk_probs: torch.Tensor,      # [B, T, K]
                        topk_indices: torch.Tensor,    # [B, T, K]
                        alpha: float = 0.01):
    """Auxiliary load-balancing loss ``alpha * num_experts * sum_i(f_i * P_i)``.

    ``f_i`` is the number of tokens that selected expert ``i`` in any of their K slots,
    divided by ``B * T``. ``P_i`` is the sum of ``topk_probs`` assigned to expert ``i``,
    divided by ``B * T``.

    Args:
        num_experts: total number of experts E.
        topk_probs: routing probabilities of the selected experts, shape [B, T, K].
        topk_indices: ids of the selected experts, shape [B, T, K].
        alpha: scale of the loss.

    Returns:
        A scalar tensor.
    """
    batch, seq_len, _ = topk_indices.shape
    n_tokens = batch * seq_len

    # assignment[b, t, k, e] is True when slot k of token (b, t) points at expert e
    expert_ids = torch.arange(num_experts, device=topk_indices.device)
    assignment = topk_indices.unsqueeze(-1) == expert_ids

    # f: fraction of tokens routed to each expert                      -> [E]
    routed_tokens = assignment.any(dim=2).sum((0, 1)).float()
    f = routed_tokens / n_tokens

    # P: mean routing probability mass that reaches each expert        -> [E]
    routed_mass = (topk_probs.unsqueeze(-1) * assignment.float()).sum((0, 1, 2))
    P = routed_mass / n_tokens

    return alpha * num_experts * (f * P).sum()

## Training loop

In [None]:
def train():
    