# Ref:
- https://github.com/TimS-ml/nanoGPT
- https://youtu.be/kCc8FmEb1nY

In [1]:
import os
from boring_llm_base.constants import PROJECT_HOME_DIR
import sys; sys.path.append(str(PROJECT_HOME_DIR)); os.chdir(PROJECT_HOME_DIR)
import math
import random
import tqdm
import gzip
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

from torch import Tensor
from typing import Iterator, List, Optional, Tuple, Union
from jaxtyping import Float, Bool

from einops import rearrange, repeat, reduce

from boring_utils.utils import (
    cprint, 
    tprint, 
    get_device
)

# Config

In [2]:
# from boring_nn.attention.config import AttentionConfig
# cfg = AttentionConfig()
# cprint(cfg)

# ---- batching ----
batch_size = 4  # independent sequences processed in parallel per step
block_size = 8  # context window: time steps per sequence

# ---- model shape ----
n_embed = 36    # channels (embedding dim)
n_head = 6
assert n_embed % n_head == 0   # every head gets n_embed // n_head channels
n_layer = 6
dropout = 0.2
t_enc, t_dec = 10, block_size  # encoder / decoder sequence lengths

# ---- optimisation & evaluation ----
max_iters = 100
learning_rate = 3e-4
eval_interval = 100  # evaluate every eval_interval steps (and on the final step)
eval_iters = 100     # batches averaged for each loss estimate

# ---- runtime ----
device = get_device()
# vocab_size = len(set(text))  -- vocab_size is read from meta.pkl in the data cell
cprint(device)

[93m<module> -> device:[0m device(type='mps')


# Data Loader

In [3]:
data_dir = os.path.join(os.getenv('DATA_DIR', './data/'), 'enwik8')

# # NOTE: only read enwik8 first 10M bytes
# with gzip.open(os.path.join(data_dir, 'enwik8.gz')) as file:
#     text = file.read(int(10e6)).decode('utf-8')

# meta.pkl carries the vocabulary produced by the data-preparation step.
meta_path = os.path.join(data_dir, 'meta.pkl')
if not os.path.exists(meta_path):
    raise FileNotFoundError(f"Meta file {meta_path} not found")
# SECURITY: pickle.load can execute arbitrary code -- only load meta files you created.
with open(meta_path, 'rb') as f:
    meta = pickle.load(f)
vocab_size = meta['vocab_size']
stoi = meta['stoi']  # symbol -> id, used by encode()
itos = meta['itos']  # id -> symbol, used by decode()


def encode(s):
    """Map each symbol of ``s`` to its integer id via ``stoi``."""
    return [stoi[c] for c in s]


def decode(l):
    """Map a sequence of integer ids back to a string via ``itos``."""
    return ''.join(itos[i] for i in l)


train_bin_path = os.path.join(data_dir, 'train.bin')
val_bin_path = os.path.join(data_dir, 'val.bin')

# The .bin files hold raw uint16 token ids; widen them to int64 (torch.long),
# the dtype used for token indices / cross-entropy targets.
train_data = torch.from_numpy(np.fromfile(train_bin_path, dtype=np.uint16).astype(np.int64))
val_data = torch.from_numpy(np.fromfile(val_bin_path, dtype=np.uint16).astype(np.int64))

class TextSamplerDataset(Dataset):
    """Random fixed-length (x, y) windows drawn from a token tensor.

    ``__getitem__`` ignores ``index`` and samples a fresh random offset on
    every call, so the DataLoader's sampler only decides how many samples make
    up an "epoch" (``__len__``), not which windows are seen.
    """

    def __init__(self, data, block_size, device=None):
        """
        Args:
            data: token ids, sliced along its first axis (presumably a 1-D
                int64 tensor -- TODO confirm for all callers).
            block_size: number of tokens in each input/target sequence.
            device: where each (x, y) pair is moved. ``None`` keeps the
                original behaviour of using the notebook-level ``device``.

        Raises:
            ValueError: if ``data`` is too short to contain one window of
                ``block_size + 1`` tokens.
        """
        self.data = data
        self.block_size = int(block_size)
        self.device = device
        if len(self.data) < self.block_size + 1:
            raise ValueError(
                f"need at least block_size + 1 = {self.block_size + 1} tokens, "
                f"got {len(self.data)}"
            )

    def __getitem__(self, index):
        # One window of block_size + 1 tokens gives the input x and its
        # one-step-shifted target y. Valid start offsets are
        # 0 .. len(data) - block_size - 1 inclusive; torch.randint's upper bound
        # is exclusive, so it must be len(data) - block_size.
        ix = int(torch.randint(len(self.data) - self.block_size, (1,)))
        full_seq = self.data[ix:ix + self.block_size + 1]
        x, y = full_seq[:-1], full_seq[1:]
        target = device if self.device is None else self.device
        return x.to(target), y.to(target)

    def __len__(self):
        return len(self.data) // self.block_size


# Both splits use the same random-window sampler. shuffle=False is harmless:
# TextSamplerDataset.__getitem__ ignores its index and draws a random offset.
train_dataset, val_dataset = (
    TextSamplerDataset(split, block_size) for split in (train_data, val_data)
)
train_loader, val_loader = (
    DataLoader(ds, batch_size=batch_size, shuffle=False)
    for ds in (train_dataset, val_dataset)
)

# Model

In [4]:
class BigramLanguageModel(nn.Module):
    """Bigram language model: next-token logits are looked up from the current token.

    The embedding table is ``(vocab_size, vocab_size)``, so row ``i`` is
    directly the vector of (unnormalised) logits for the token that follows
    token ``i``; no context beyond the current token is used.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # embedding_dim must equal vocab_size: each row *is* a logits vector.
        # e.g. token 24 reads off row 24 of the table.
        embedding_dim = vocab_size
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

    def forward(self, idx, targets=None):
        """Return ``(logits, loss)``.

        Args:
            idx: integer token ids of shape ``(B, T)``.
            targets: optional next-token ids with the same number of elements
                as ``idx``.

        Returns:
            logits: ``(B, T, vocab_size)`` when ``targets`` is None; otherwise
                flattened to ``(B*T, vocab_size)`` as fed to the loss.
            loss: mean cross-entropy over all positions, or None when
                ``targets`` is None.
        """
        logits = self.embedding(idx)  # (B, T, vocab_size)
        if targets is None:
            return logits, None

        B, T, C = logits.shape
        logits = logits.view(B * T, C)  # cross_entropy expects (N, C)
        targets = targets.view(-1)      # (N,)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

# Bigram baseline over the dataset vocabulary, placed on the selected device
# (nn.Module.to moves parameters in place and returns the module itself).
model = BigramLanguageModel(vocab_size).to(device)
model  # echo the module as this cell's output

BigramLanguageModel(
  (embedding): Embedding(2102, 2102)
)

In [6]:
# AdamW over every model parameter at the configured learning rate.
optimizer = optim.AdamW(params=model.parameters(), lr=learning_rate)

# Training

In [7]:
def cycle(loader):
    """Yield items from ``loader`` forever, re-iterating it after each full pass.

    ``loader`` must be re-iterable (e.g. a DataLoader or a list), because every
    pass starts a fresh ``for`` loop over it.

    Raises:
        ValueError: if a complete pass yields nothing. An empty iterable, or a
            one-shot iterator that has been exhausted, would otherwise make
            this generator spin forever without yielding.
    """
    while True:
        produced = False
        for item in loader:
            produced = True
            yield item
        if not produced:
            raise ValueError("cycle() received an empty iterable; refusing to loop forever")

# Endless batch streams: `cycle` starts the loader again each time it is exhausted.
train_iter, val_iter = cycle(train_loader), cycle(val_loader)

def _mean_loss(model, batches, num_batches):
    """Average ``model(x, y)[1]`` over the next ``num_batches`` (x, y) pairs from ``batches``."""
    # range() comes first in zip so that zip stops without pulling one extra
    # batch from the (endless) iterator once num_batches have been consumed.
    losses = [model(x, y)[1].item() for _, (x, y) in zip(range(num_batches), batches)]
    return np.mean(losses)


def train(
        model: nn.Module = model,
        train_iter: Iterator = train_iter,
        val_iter: Iterator = val_iter,
        eval_iters: int = eval_iters,
        max_iters: int = max_iters,
        eval_interval: int = eval_interval,
        optimizer: torch.optim.Optimizer = optimizer,
    ):
    """Train ``model`` for ``max_iters`` steps, printing train/val loss periodically.

    Args:
        model: module whose ``forward(x, y)`` returns ``(logits, loss)``.
        train_iter: endless iterator of ``(x, y)`` training batches
            (e.g. ``cycle(train_loader)``).
        val_iter: endless iterator of ``(x, y)`` validation batches.
        eval_iters: number of batches averaged for each loss estimate.
        max_iters: number of optimisation steps.
        eval_interval: evaluate every ``eval_interval`` steps and on the last step.
        optimizer: optimiser stepping ``model``'s parameters; defaults to the
            notebook-level ``optimizer`` so existing calls behave as before.
    """
    for step in range(max_iters):
        # Eval logic
        if step % eval_interval == 0 or step == max_iters - 1:
            model.eval()
            try:
                with torch.no_grad():
                    val_loss = _mean_loss(model, val_iter, eval_iters)
                    train_loss = _mean_loss(model, train_iter, eval_iters)
                print(f"step {step}: train loss {train_loss:.4f}, val loss {val_loss:.4f}")
            finally:
                # Restore training mode even if evaluation raised.
                model.train()

        # Training logic
        x, y = next(train_iter)  # replace get_batch
        _, loss = model(x, y)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

# Dev Zone

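From nanoGPT (reference only: the snippet reads the globals `n_embd`, `block_size` and `dropout`, whereas this notebook's config names the embedding width `n_embed`):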

```python
class Head(nn.Module):
    """ one head of self-attention

    Causal (masked) scaled dot-product attention for a single head.
    NOTE: reads the module-level globals `n_embd`, `block_size` and `dropout`
    (nanoGPT names). This notebook's config spells the embedding width
    `n_embed`, so the snippet would need that rename to run here.
    """

    def __init__(self, head_size):
        super().__init__()
        # key/query/value each project n_embd channels down to head_size
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # lower-triangular ones used as the causal mask; registered as a buffer,
        # not a parameter, so it is not trained
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # input of size (batch, time-step, channels)
        # output of size (batch, time-step, head size)
        B,T,C = x.shape
        k = self.key(x)   # (B,T,hs)
        q = self.query(x) # (B,T,hs)
        # compute attention scores ("affinities")
        # scaled by 1/sqrt(head_size) (k.shape[-1] == hs)
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        # positions above the diagonal (future tokens) get -inf, i.e. zero weight after softmax
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,hs)
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel

    Runs `num_heads` independent `Head`s on the same input, concatenates their
    outputs along the channel axis and projects back to `n_embd` channels.
    NOTE: uses the nanoGPT globals `n_embd` and `dropout`.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # concatenated width is head_size * num_heads; map it back to n_embd
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # (B, T, num_heads * head_size) after concatenating the per-head outputs
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))  # (B, T, n_embd)
        return out
```

## QKV

In [8]:
x = torch.randn(batch_size, block_size, n_embed)

def qkv_normal(
    x: Float[Tensor, "batch seq_len embedding_dim"],
    num_heads: int = n_head
):
    """Check that three common ways of building multi-head Q/K/V agree.

    Every projection uses the same random ``(dim, dim)`` weight, so all three
    variants must produce numerically identical Q, K and V of shape
    ``(batch, num_heads, seq_len, head_dim)``:

    1. separate Q projection + fused KV projection (x-transformers style),
    2. one fused QKV projection split three ways (lit-gpt style),
    3. independent Q, K, V projections (nanoGPT style).

    Raises:
        ValueError: if the embedding dim is not divisible by ``num_heads``.
        AssertionError: if the implementations disagree in shape or value.
    """
    batch, seq_len, dim = x.shape
    if dim % num_heads != 0:
        raise ValueError(f"embedding dim {dim} is not divisible by num_heads={num_heads}")
    head_dim = dim // num_heads

    # Fixed weights, for testing `assert torch.allclose`
    fixed_weight = torch.randn(dim, dim)

    def linear_with_weight(out_features, weight):
        """Return a bias-free ``nn.Linear(dim, out_features)`` holding a copy of ``weight``."""
        layer = nn.Linear(dim, out_features, bias=False)
        with torch.no_grad():
            # copy_ writes into the existing (out_features, dim) parameter rather
            # than replacing it, so a wrongly sized layer raises instead of
            # silently adopting a different shape.
            layer.weight.copy_(weight)
        return layer

    # Separated QKV (x-transformer style)
    def impl1(x):
        to_q = linear_with_weight(dim, fixed_weight)
        to_kv = linear_with_weight(dim * 2, torch.cat([fixed_weight, fixed_weight]))

        q = to_q(x)
        k, v = to_kv(x).chunk(2, dim=-1)

        q, k, v = map(lambda t: rearrange(t, 'batch seq_len (num_heads head_dim) -> batch num_heads seq_len head_dim', num_heads = num_heads), (q, k, v))
        return q, k, v

    # Unified QKV (lit-gpt style)
    def impl2(x):
        to_qkv = linear_with_weight(dim * 3, torch.cat([fixed_weight] * 3))

        # Split the fused output into q, k, v of width `dim` each (the actual
        # feature width of x, not the global n_embed).
        q, k, v = to_qkv(x).split(dim, dim=-1)
        # Permute to get (batch, num_heads, seq_len, head_dim)
        q, k, v = map(
            lambda t: t.view(batch, -1, num_heads, head_dim).transpose(1, 2),
            (q, k, v)
        )
        return q, k, v

    # nanoGPT style: independent projections. Each one covers all heads at once
    # (dim = num_heads * head_dim) and is then split into heads.
    def impl3(x):
        to_q, to_k, to_v = (linear_with_weight(dim, fixed_weight) for _ in range(3))
        q, k, v = (
            proj(x).view(batch, seq_len, num_heads, head_dim).transpose(1, 2)
            for proj in (to_q, to_k, to_v)
        )
        return q, k, v

    # Compare results
    q1, k1, v1 = impl1(x)
    q2, k2, v2 = impl2(x)
    q3, k3, v3 = impl3(x)

    # Shapes first: torch.allclose broadcasts its inputs, so a shape mismatch
    # would otherwise raise an unrelated error (or broadcast silently) before
    # these diagnostics could run.
    tprint('shape', sep='*')
    try:
        assert q1.shape == q2.shape
        assert k1.shape == k2.shape
        assert v1.shape == v2.shape
        assert q1.shape == q3.shape
        assert k1.shape == k3.shape
        assert v1.shape == v3.shape
        print('passed')
    except AssertionError:
        cprint(q1.shape, k1.shape, v1.shape)
        cprint(q2.shape, k2.shape, v2.shape)
        cprint(q3.shape, k3.shape, v3.shape)
        raise

    tprint('allclose', sep='*')
    assert torch.allclose(q1, q2, rtol=1e-4)
    assert torch.allclose(k1, k2, rtol=1e-4)
    assert torch.allclose(v1, v2, rtol=1e-4)
    assert torch.allclose(q1, q3, rtol=1e-4)
    assert torch.allclose(k1, k3, rtol=1e-4)
    assert torch.allclose(v1, v3, rtol=1e-4)
    print('passed')


qkv_normal(x)

[33m
******************** qkv_normal -> allclose ********************[0m
passed
[33m
******************** qkv_normal -> shape ********************[0m
passed


### Grouped QKV

From lit-gpt:
```
to use multi-head attention (MHA), set this to `n_head` (default)
to use multi-query attention (MQA), set this to 1
to use grouped-query attention (GQA), set this to a value in between
Example with `n_head=4`
┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
│ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │
└───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
  │    │    │    │         │        │                 │
┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
│ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │
└───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
  │    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐
┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐
│ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │
└───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘
◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶
        MHA                    GQA                   MQA
  n_query_groups=4       n_query_groups=2      n_query_groups=1

credit https://arxiv.org/pdf/2305.13245.pdf
```

In [9]:
def qkv_grouped(
    x: Float[Tensor, "batch seq_len embedding_dim"],
    num_heads: int = n_head,  # Number of heads in Q
    n_query_groups: int = 2   # Number of heads in KV
):
    """Build grouped-query Q/K/V two ways and check that their shapes agree.

    Q always has ``num_heads`` heads. K and V are projected with only
    ``n_query_groups`` heads and then repeated so that every query head has a
    matching key/value head:

    - ``n_query_groups == num_heads`` -> MHA
    - ``n_query_groups == 1``         -> MQA
    - otherwise                       -> GQA

    The two implementations draw independent random weights. They also pair
    query heads with KV heads in different orders: x-transformers maps query
    head ``h`` to KV head ``h % n_query_groups``, lit-gpt maps it to
    ``h // q_per_kv``. So only shapes are compared, not values.

    Raises:
        ValueError: if the embedding dim is not divisible by ``num_heads``, or
            ``num_heads`` is not a positive multiple of ``n_query_groups``.
    """
    batch, seq_len, dim = x.shape
    if dim % num_heads != 0:
        raise ValueError(f"embedding dim {dim} is not divisible by num_heads={num_heads}")
    if n_query_groups < 1 or num_heads % n_query_groups != 0:
        raise ValueError(
            f"num_heads={num_heads} must be a multiple of n_query_groups={n_query_groups}"
        )
    head_dim = dim // num_heads
    q_per_kv = num_heads // n_query_groups  # query heads sharing one KV head

    # TODO: tie the weights of both implementations so values can be compared too
    # (requires reconciling the two head orderings described above).

    # 1. x-transformer style implementation (Separated QKV + GQA)
    # in x-transformers, n_query_groups is called kv_heads
    def impl1(x):
        to_q = nn.Linear(dim, dim, bias=False)  # == head_dim * num_heads
        to_kv = nn.Linear(dim, head_dim * n_query_groups * 2, bias=False)

        q = to_q(x)
        k, v = to_kv(x).chunk(2, dim=-1)

        # MHA, MQA and GQA share one code path: split out the KV heads, then
        # tile them q_per_kv times along the head axis. For MHA q_per_kv == 1,
        # so this reduces to a plain rearrange.
        # (x-transformers writes it as rearrange 'b n (h d) -> b h n d' followed
        # by repeat 'b h n d -> b (r h) n d'.)
        k, v = (
            repeat(
                t,
                'batch seq_len (kv_heads head_dim) -> batch (q_per_kv kv_heads) seq_len head_dim',
                kv_heads=n_query_groups,
                q_per_kv=q_per_kv,
            )
            for t in (k, v)
        )

        q = rearrange(q, 'batch seq_len (num_heads head_dim) -> batch num_heads seq_len head_dim', num_heads = num_heads)
        return q, k, v

    # 2. lit-gpt style implementation (Unified QKV + GQA)
    def impl2(x):
        total_qkv = q_per_kv + 2  # per group: q_per_kv queries + 1 key + 1 value
        qkv_dim = (num_heads + 2 * n_query_groups) * head_dim  # == total_qkv * n_query_groups * head_dim

        to_qkv = nn.Linear(dim, qkv_dim, bias=False)

        qkv = to_qkv(x).view(batch, seq_len, n_query_groups, total_qkv, head_dim)
        qkv = qkv.permute(0, 2, 3, 1, 4)  # (B, n_query_groups, total_qkv, T, hs)

        # Split Q, K, V inside every group
        q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)

        # Broadcast each group's single K/V head to its q_per_kv query heads
        # (a no-op for MHA, where q_per_kv == 1).
        k = k.expand(batch, n_query_groups, q_per_kv, seq_len, head_dim)
        v = v.expand(batch, n_query_groups, q_per_kv, seq_len, head_dim)

        # Reshape to final shape
        q = q.reshape(batch, -1, seq_len, head_dim)  # (B, nh_q, T, hs)
        k = k.reshape(batch, -1, seq_len, head_dim)  # (B, nh_q, T, hs)
        v = v.reshape(batch, -1, seq_len, head_dim)  # (B, nh_q, T, hs)
        return q, k, v

    # Validate implementations
    q1, k1, v1 = impl1(x)
    q2, k2, v2 = impl2(x)

    tprint('shape', sep='*')
    try:
        assert q1.shape == q2.shape
        assert k1.shape == k2.shape
        assert v1.shape == v2.shape
        print('passed')
    except AssertionError:
        cprint(q1.shape, k1.shape, v1.shape)
        cprint(q2.shape, k2.shape, v2.shape)
        raise


# Exercise every attention variant with n_head query heads:
#   MHA - one KV head per query head      (n_query_groups = n_head)
#   MQA - a single KV head shared by all  (n_query_groups = 1)
#   GQA - 2 KV groups, each shared by n_head // 2 query heads
x = torch.randn(batch_size, block_size, n_embed)

for label, groups in (('MHA', n_head), ('MQA', 1), ('GQA', 2)):
    tprint(label)
    qkv_grouped(x, num_heads=n_head, n_query_groups=groups)

[35m
[33m
******************** qkv_grouped -> shape ********************[0m
passed
[35m
[33m
******************** qkv_grouped -> shape ********************[0m
passed
[35m
[33m
******************** qkv_grouped -> shape ********************[0m
passed


## PE

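Sinusoidal (sin/cos) positional encoding from *Attention Is All You Need*, where $i$ indexes pairs of embedding channels and $d_{model}$ is the embedding dimension: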

$$
PE(\text{position}, 2i) = \sin\bigg( \frac{ \text{position} }{10000^\frac{2i}{d_{model}}} \bigg)
$$

$$
PE(\text{position}, 2i+1) = \cos\bigg( \frac{ \text{position} }{10000^\frac{2i}{d_{model}}} \bigg)
$$


In [10]:
def positional_encoding_normal(
    x: Float[Tensor, "batch seq_len embedding_dim"],
    num_heads: int = n_head
):
    """Add absolute position information to ``x`` two ways and compare output shapes.

    1. Fixed sinusoidal table (x-transformers style).
    2. Learned ``nn.Embedding`` over positions (nanoGPT style).

    ``num_heads`` is accepted for signature parity with the other dev-zone
    helpers. Absolute encodings act on the full embedding, so it is unused.
    """
    batch, seq_len, dim = x.shape

    # x-transformer style (using sin-cos)
    def impl1(x):
        # The table has to exist before its even/odd columns are filled in.
        pe = torch.zeros(seq_len, dim, dtype=x.dtype, device=x.device)
        position = torch.arange(seq_len, dtype=torch.float32, device=x.device).unsqueeze(1)  # (T, 1)
        # 10000^(-2i / dim) for each even channel index 2i
        div_term = torch.exp(
            torch.arange(0, dim, 2, dtype=torch.float32, device=x.device) * -(math.log(10000.0) / dim)
        )
        angles = position * div_term  # (T, ceil(dim / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd dim there is one fewer odd channel than even channels.
        pe[:, 1::2] = torch.cos(angles[:, : dim // 2])
        return x + pe.unsqueeze(0)

    # nanoGPT style (simple positional embedding)
    def impl2(x):
        position_embeddings = nn.Embedding(seq_len, dim).to(x.device)
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)  # (1, T)
        position_embeds = position_embeddings(positions)                 # (1, T, dim)
        return x + position_embeds

    # Compare results
    out1 = impl1(x)
    out2 = impl2(x)

    tprint('shape', sep='*')
    cprint(out1.shape, out2.shape)
    assert out1.shape == out2.shape
    print('passed')


# positional_encoding_normal(x)

### RoPE

Ref:
- https://huggingface.co/jinaai/xlm-roberta-flash-implementation/blob/main/rotary.py
- https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py

In [11]:
def positional_encoding_RoPE(
    x: Float[Tensor, "batch seq_len embedding_dim"],
    num_heads: int = n_head
):
    """Contrast additive sin/cos positional encoding with rotary position embedding (RoPE).

    1. x-transformers style: a fixed sinusoidal table is added to ``x``.
    2. lit-gpt style: ``x`` is split into ``num_heads`` heads of ``head_dim``
       features, and each head is rotated by position-dependent angles.

    The two are different encodings, so only output shapes are compared.

    Raises:
        ValueError: if the embedding dim is not divisible by ``num_heads``, or
            ``head_dim`` is odd (RoPE rotates features in pairs).
    """
    batch, seq_len, dim = x.shape
    if dim % num_heads != 0:
        raise ValueError(f"embedding dim {dim} is not divisible by num_heads={num_heads}")
    head_dim = dim // num_heads
    if head_dim % 2 != 0:
        raise ValueError(f"RoPE needs an even head_dim, got {head_dim}")

    def _apply_rope(t, cos, sin):
        """Rotate ``t`` (..., head_dim) by the angles in ``cos``/``sin`` (rotate-half form).

        Feature ``i`` is paired with feature ``i + head_dim // 2``.
        """
        half = t.size(-1) // 2
        t1, t2 = t[..., :half], t[..., half:]
        rotated = torch.cat((-t2, t1), dim=-1)
        return (t * cos + rotated * sin).type_as(t)

    # x-transformer style (using sin-cos)
    def impl1(x):
        pe = torch.zeros(seq_len, dim, dtype=x.dtype, device=x.device)
        position = torch.arange(seq_len, dtype=torch.float32, device=x.device).unsqueeze(1)  # (T, 1)
        div_term = torch.exp(
            torch.arange(0, dim, 2, dtype=torch.float32, device=x.device) * -(math.log(10000.0) / dim)
        )
        angles = position * div_term  # (T, ceil(dim / 2))
        pe[:, 0::2] = torch.sin(angles)
        pe[:, 1::2] = torch.cos(angles[:, : dim // 2])
        return x + pe.unsqueeze(0)

    # lit-gpt style (using RoPE)
    def impl2(x):
        inv_freq = 1.0 / (
            10000 ** (torch.arange(0, head_dim, 2, dtype=torch.float32, device=x.device) / head_dim)
        )  # (head_dim // 2,)
        # Float positions, so both einsum operands share a dtype.
        t = torch.arange(seq_len, dtype=torch.float32, device=x.device)
        freqs = torch.einsum('i,j->ij', t, inv_freq)  # (T, head_dim // 2)
        # Duplicated halves match the rotate-half pairing used in _apply_rope.
        emb = torch.cat((freqs, freqs), dim=-1)       # (T, head_dim)
        cos = emb.cos().view(1, seq_len, 1, head_dim)
        sin = emb.sin().view(1, seq_len, 1, head_dim)

        # Apply RoPE per head: (B, T, dim) -> (B, T, nh, hs) -> rotate -> (B, T, dim)
        x_heads = x.reshape(batch, seq_len, num_heads, head_dim)
        x_rope = _apply_rope(x_heads, cos, sin)
        return x_rope.reshape(batch, seq_len, dim)

    # Compare results
    out1 = impl1(x)
    out2 = impl2(x)

    tprint('shape', sep='*')
    cprint(out1.shape, out2.shape)
    assert out1.shape == out2.shape
    print('passed')


# positional_encoding_RoPE(x)