# Ref:
- https://github.com/TimS-ml/nanoGPT
- https://youtu.be/kCc8FmEb1nY

In [1]:
import os
from boring_llm_base.constants import PROJECT_HOME_DIR
import sys; sys.path.append(str(PROJECT_HOME_DIR)); os.chdir(PROJECT_HOME_DIR)
import math
import random
import tqdm
import gzip
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

from torch import Tensor
from typing import Iterator, List, Optional, Tuple, Union
from jaxtyping import Float, Bool

from einops import rearrange, repeat, reduce

from boring_utils.utils import (
    cprint, 
    tprint, 
    get_device
)

# Config

In [2]:
# Alternative: pull these settings from a config object instead of literals.
# from boring_nn.attention.config import AttentionConfig
# cfg = AttentionConfig()
# cprint(cfg)

# ---- batch / sequence geometry ----
batch_size = 4   # independent sequences processed in parallel per step
block_size = 8   # context window: time steps per sequence
n_embed = 36     # embedding width (channels)

# ---- model shape ----
t_enc = 10           # encoder sequence length
t_dec = block_size   # decoder sequence length
n_head = 6
assert n_embed % n_head == 0  # every head must get an equal channel slice
n_layer = 6
dropout = 0.2

# ---- optimisation / evaluation schedule ----
max_iters = 100
eval_interval = 100
learning_rate = 3e-4
eval_iters = 100

device = get_device()
# vocab_size is read from meta.pkl in the data cell
# (previously: vocab_size = len(set(text)))
cprint(device)

<module> -> device:
device(type='mps')


# Data Loader

In [3]:
data_dir = os.getenv('DATA_DIR', './data/')
data_dir = os.path.join(data_dir, 'enwik8')

# # NOTE: only read enwik8 first 10M bytes
# with gzip.open(os.path.join(data_dir, 'enwik8.gz')) as file:
#     text = file.read(int(10e6)).decode('utf-8')

# Vocabulary metadata: vocab_size plus the stoi / itos lookup tables.
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# load a meta.pkl you produced yourself or otherwise trust.
meta_path = os.path.join(data_dir, 'meta.pkl')
if not os.path.exists(meta_path):
    raise FileNotFoundError(f"Meta file {meta_path} not found")
with open(meta_path, 'rb') as f:
    meta = pickle.load(f)
vocab_size = meta['vocab_size']
stoi = meta['stoi']  # character -> token id (used by encode)
itos = meta['itos']  # token id -> str piece (joined by decode)


def encode(s):
    """Convert string ``s`` to a list of token ids via the ``stoi`` table."""
    return [stoi[c] for c in s]


def decode(l):
    """Convert a sequence of token ids ``l`` back to a string via ``itos``."""
    return ''.join([itos[i] for i in l])


train_bin_path = os.path.join(data_dir, 'train.bin')
val_bin_path = os.path.join(data_dir, 'val.bin')

# The .bin splits are read as flat arrays of uint16 token ids, so every id
# stored in them must be < 65536.
train_data = np.fromfile(train_bin_path, dtype=np.uint16)
val_data = np.fromfile(val_bin_path, dtype=np.uint16)

# Widen to int64 (torch.long is just an alias for torch.int64), the integer
# dtype F.cross_entropy expects for class-index targets.
train_data = torch.from_numpy(train_data.astype(np.int64))
val_data = torch.from_numpy(val_data.astype(np.int64))

class TextSamplerDataset(Dataset):
    """Dataset of random ``(x, y)`` next-token windows drawn from a 1-D token tensor.

    ``__getitem__`` ignores ``index`` and draws a fresh random window from
    torch's global RNG on every call. The DataLoader's ``shuffle`` flag
    therefore has no effect. ``__len__`` only sets how many samples make up
    one "epoch".

    Args:
        data: 1-D tensor of token ids.
        block_size: window length; ``x`` and ``y`` each hold this many tokens.
        device: device each sample is moved to. ``None`` (the default) falls
            back to the module-level ``device`` global, looked up per sample.

    Raises:
        ValueError: if ``data`` has no more than ``block_size`` tokens; a
            window needs ``block_size + 1`` (inputs plus the shifted target).
    """

    def __init__(self, data, block_size, device=None):
        self.data = data
        self.block_size = int(block_size)
        self.device = device
        if len(self.data) <= self.block_size:
            raise ValueError(
                f"need more than block_size={self.block_size} tokens, "
                f"got {len(self.data)}"
            )

    def __getitem__(self, index):
        # randint's upper bound is exclusive. Start offsets therefore range
        # over 0 .. len(data) - block_size - 1, so the last window ends exactly
        # at the final token.
        ix = int(torch.randint(len(self.data) - self.block_size, (1,)))
        full_seq = self.data[ix:ix + self.block_size + 1]
        x, y = full_seq[:-1], full_seq[1:]  # y is x shifted left by one token
        target = self.device if self.device is not None else device
        return x.to(target), y.to(target)

    def __len__(self):
        return len(self.data) // self.block_size


# Wrap each split in a random-window dataset, then batch it. shuffle stays off
# because the dataset already samples window positions at random.
train_dataset, val_dataset = (
    TextSamplerDataset(split, block_size) for split in (train_data, val_data)
)
train_loader, val_loader = (
    DataLoader(ds, batch_size=batch_size, shuffle=False)
    for ds in (train_dataset, val_dataset)
)

# Model

In [4]:
class BigramLanguageModel(nn.Module):
    """Bigram language model: next-token logits depend only on the current token.

    One ``vocab_size x vocab_size`` embedding table serves as a logit lookup.
    Token ``i`` reads off row ``i``, which holds the unnormalised scores for
    the token that follows it.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # The embedding output is used directly as next-token logits, so its
        # width must equal vocab_size. A narrower embedding (e.g. 128) would
        # need an extra projection back to vocab_size.
        embedding_dim = vocab_size
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

    def forward(self, idx, targets=None):
        """Return ``(logits, loss)`` for a batch of token ids.

        Args:
            idx: integer token ids of shape (batch_size, block_size).
            targets: optional next-token ids with the same shape as ``idx``.

        Returns:
            Without ``targets``: logits of shape (B, T, vocab_size) and ``None``.
            With ``targets``: logits flattened to (B * T, vocab_size) and the
            mean cross-entropy loss.
        """
        logits = self.embedding(idx)  # (B, T, C) with C == vocab_size

        # Compare with `is`: `==` on a tensor is an elementwise operation.
        if targets is None:
            return logits, None

        B, T, C = logits.shape
        logits = logits.view(B * T, C)  # cross_entropy wants (N, C) logits
        loss = F.cross_entropy(logits, targets.view(-1))  # and (N,) targets
        return logits, loss

# Build the bigram model; nn.Module.to moves it in place and returns it.
model = BigramLanguageModel(vocab_size).to(device)
model  # last expression of the cell: displays the module summary

BigramLanguageModel(
  (embedding): Embedding(2102, 2102)
)

In [5]:
optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)  # AdamW over every model parameter

# Training

In [6]:
def cycle(loader):
    """Yield items from ``loader`` forever.

    When an iteration pass is exhausted, ``loader`` is re-iterated from the
    start. Items are not cached between passes.
    """
    pass_iter = iter(loader)
    while True:
        try:
            item = next(pass_iter)
        except StopIteration:
            pass_iter = iter(loader)  # pass finished: start a fresh one
            continue
        yield item

# Endless batch streams over the (finite) train / val loaders.
train_iter, val_iter = cycle(train_loader), cycle(val_loader)

def _estimate_loss(
        model: nn.Module,
        data_iter: Iterator[Tuple[Tensor, Tensor]],
        num_batches: int,
    ) -> float:
    """Return the mean loss of ``model`` over the next ``num_batches`` batches of ``data_iter``.

    ``range`` is zipped first so that exactly ``num_batches`` batches are pulled
    from the (endless) iterator. Returns NaN when ``num_batches`` is 0.
    """
    losses = [model(x, y)[1].item() for _, (x, y) in zip(range(num_batches), data_iter)]
    return float(np.mean(losses)) if losses else float('nan')


def train(
        model: nn.Module = model,
        train_iter: Iterator[Tuple[Tensor, Tensor]] = train_iter,
        val_iter: Iterator[Tuple[Tensor, Tensor]] = val_iter,
        eval_iters: int = eval_iters,
        max_iters: int = max_iters,
        eval_interval: int = eval_interval,
        optimizer: Optional[optim.Optimizer] = None,
    ):
    """Run ``max_iters`` optimisation steps, printing loss estimates periodically.

    Losses are estimated at step 0, every ``eval_interval`` steps and at the
    last step. Each estimate runs in eval mode without gradients and averages
    ``eval_iters`` batches from ``val_iter`` and from ``train_iter``.

    NOTE: the parameter defaults are the notebook globals captured when this
    cell was executed, not when ``train`` is called.

    Args:
        model: module whose ``forward(x, y)`` returns ``(logits, loss)``.
        train_iter: endless iterator of ``(x, y)`` training batches.
        val_iter: endless iterator of ``(x, y)`` validation batches.
        eval_iters: number of batches averaged per loss estimate.
        max_iters: number of optimisation steps.
        eval_interval: steps between loss estimates.
        optimizer: optimizer over ``model``'s parameters. ``None`` falls back
            to the module-level ``optimizer``, looked up at call time.
    """
    if optimizer is None:
        optimizer = globals()['optimizer']

    for step in range(max_iters):
        # Eval logic
        if step % eval_interval == 0 or step == max_iters - 1:
            model.eval()
            with torch.no_grad():
                val_loss = _estimate_loss(model, val_iter, eval_iters)
                train_loss = _estimate_loss(model, train_iter, eval_iters)
            print(f"step {step}: train loss {train_loss:.4f}, val loss {val_loss:.4f}")
            model.train()

        # Training logic
        x, y = next(train_iter)  # replaces nanoGPT's get_batch
        _, loss = model(x, y)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

# Dev Zone

From nanoGPT:

```python
class Head(nn.Module):
    """ one head of self-attention

    Causal scaled dot-product attention for a single head.
    NOTE(review): this quoted nanoGPT snippet reads the globals ``n_embd``,
    ``block_size`` and ``dropout``. This notebook's config names the
    embedding width ``n_embed``, so ``n_embd`` must be defined before use.
    """

    def __init__(self, head_size):
        super().__init__()
        # three bias-free projections from the embedding width to head_size
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # lower-triangular ones matrix used as the causal mask; a buffer is
        # saved and moved with the module but is not a trainable parameter
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # input of size (batch, time-step, channels)
        # output of size (batch, time-step, head size)
        # T must not exceed block_size, since the mask is sliced as tril[:T, :T]
        B,T,C = x.shape
        k = self.key(x)   # (B,T,hs)
        q = self.query(x) # (B,T,hs)
        # compute attention scores ("affinities"), scaled by 1/sqrt(head size)
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        # future positions (tril == 0) get -inf so softmax gives them weight 0
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,hs)
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel

    NOTE(review): quoted from nanoGPT; reads the globals ``n_embd`` and
    ``dropout`` (this notebook's config defines ``n_embed``, not ``n_embd``).
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # maps the concatenated heads (num_heads * head_size) back to n_embd
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # concatenate per-head outputs along the last (channel) dimension
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out
```

## Mask

In [11]:
def mask_normal(
    batch_size: int,
    seq_len: int,
    mask_type: str = "causal",  # "causal", "padding", or "bidirectional"
    device: Optional[torch.device] = None
) -> Bool[Tensor, "batch seq_len seq_len"]:
    """Build an attention mask in three reference styles and cross-check them.

    Convention of the returned mask: ``True`` means "query i may attend to key j".

    Args:
        batch_size: number of (seq_len, seq_len) masks stacked along dim 0.
        seq_len: query / key length.
        mask_type: ``"causal"`` (key j visible iff j <= i), ``"padding"``
            (random keep/drop per position, for testing only) or
            ``"bidirectional"`` (every position visible).
        device: device on which the masks are created.

    Returns:
        The x-transformers-style mask (impl1), a bool tensor of shape
        (batch_size, seq_len, seq_len).

    Raises:
        ValueError: for an unknown ``mask_type``.
    """
    if mask_type not in ("causal", "padding", "bidirectional"):
        raise ValueError(f"unknown mask_type: {mask_type!r}")

    # x-transformer style (support multiple mask types): build a "blocked"
    # mask (True = may NOT attend) and invert it at the end.
    def impl1():
        if mask_type == "causal":
            # Strict upper triangle = future keys, which are blocked.
            blocked = torch.ones((seq_len, seq_len), device=device).triu(1).bool()
            blocked = blocked.unsqueeze(0).expand(batch_size, -1, -1)
        elif mask_type == "padding":
            # Random keep vector for testing. A pair (i, j) is allowed only when
            # both positions are kept, and is blocked otherwise.
            keep = torch.rand(batch_size, seq_len, device=device) > 0.2
            blocked = ~(keep.unsqueeze(1) & keep.unsqueeze(2))
        else:  # bidirectional: nothing is blocked
            blocked = torch.zeros(
                (batch_size, seq_len, seq_len), dtype=torch.bool, device=device
            )
        return ~blocked  # Note: here we take the inverse, so True means "allow attend"

    # lit-gpt style (focus on causal mask + cache optimization)
    def impl2():
        if mask_type != "causal":
            raise NotImplementedError("lit-gpt mainly supports causal mask")
        # Lower-triangular bool mask: True on and below the diagonal.
        ones = torch.ones((seq_len, seq_len), dtype=torch.bool, device=device)
        mask = torch.tril(ones).unsqueeze(0)
        return mask.expand(batch_size, -1, -1)

    # nanoGPT style (float mask; non-causal types fall back to all-visible)
    def impl3():
        if mask_type == "causal":
            mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
            return mask.view(1, seq_len, seq_len).expand(batch_size, -1, -1)
        return torch.ones((batch_size, seq_len, seq_len), device=device)

    mask1 = impl1()
    mask3 = impl3()
    # impl2 only implements the causal mask. Calling it unconditionally made
    # every non-causal mask_type raise NotImplementedError.
    mask2 = impl2() if mask_type == "causal" else None

    tprint('shape', sep='*')
    cprint(mask1.shape)
    assert mask1.shape == mask3.shape
    assert mask2 is None or mask1.shape == mask2.shape
    print('passed')

    # impl3 has no padding variant (and padding is random), so values are
    # compared only for the mask types each participating impl implements.
    if mask_type == "causal":
        tprint('allclose', sep='*')
        assert torch.allclose(mask1.float(), mask2.float())
        assert torch.allclose(mask1.float(), mask3.float())
        print('passed')
    elif mask_type == "bidirectional":
        tprint('allclose', sep='*')
        assert torch.allclose(mask1.float(), mask3.float())
        print('passed')

    return mask1


# mask_normal(batch_size, seq_len)

## Post Attention

In [12]:
def attention_post_normal(
    attn_weights: Float[Tensor, "batch num_heads seq_len seq_len"],
    dropout_p: float = 0.1
) -> Float[Tensor, "batch num_heads seq_len seq_len"]:
    """Post-process attention weights in three styles and check the shapes agree.

    Each implementation draws its own dropout mask (``F.dropout`` defaults to
    ``training=True``), so only shapes, not values, are comparable.

    Args:
        attn_weights: attention weights of shape (batch, num_heads, seq_len, seq_len).
        dropout_p: dropout probability applied to the weights.

    Returns:
        The x-transformers-style output (identity talking-heads mix + dropout).
    """
    batch, num_heads, seq_len, _ = attn_weights.shape

    # x-transformer style (support talking heads etc.)
    def impl1(attn):
        # Simulate talking heads: a bias-free linear mix across the head axis,
        # initialised to the identity so it leaves the weights unchanged.
        # The layer is created on attn's device/dtype. A default CPU/float32
        # layer (or assigning a CPU tensor to weight.data) raises a device or
        # dtype mismatch for MPS/CUDA or half-precision inputs.
        talking_heads = nn.Linear(
            num_heads, num_heads, bias=False, device=attn.device, dtype=attn.dtype
        )
        with torch.no_grad():
            talking_heads.weight.copy_(
                torch.eye(num_heads, device=attn.device, dtype=attn.dtype)
            )

        attn = rearrange(attn, 'b h i j -> b i j h')  # heads last for nn.Linear
        attn = talking_heads(attn)
        attn = rearrange(attn, 'b i j h -> b h i j')

        attn = F.dropout(attn, p=dropout_p)
        return attn

    # lit-gpt style
    def impl2(attn):
        return F.dropout(attn, p=dropout_p)

    # nanoGPT style
    def impl3(attn):
        return F.dropout(attn, p=dropout_p)

    out1 = impl1(attn_weights)
    out2 = impl2(attn_weights)
    out3 = impl3(attn_weights)

    tprint('shape', sep='*')
    cprint(out1.shape)
    assert out1.shape == out2.shape == out3.shape
    print('passed')

    return out1


# attention_post_normal(attn_weights)

## KV Cache

In [13]:
def kv_cache_normal(
    k: Float[Tensor, "batch num_heads seq_len head_dim"],
    v: Float[Tensor, "batch num_heads seq_len head_dim"],
    past_key_values: Optional[Tuple[Tensor, Tensor]] = None,
    use_cache: bool = True
) -> Tuple[Tensor, Tensor]:
    """Update a KV cache in three styles and cross-check the implementations.

    impl1 / impl3 grow the cache by concatenating the new keys/values after
    the past ones (time axis = dim 2). impl2 keeps a fixed-capacity rolling
    window whose length equals the past cache's length.

    Args:
        k: new keys, shape (batch, num_heads, seq_len, head_dim).
        v: new values, same shape as ``k``.
        past_key_values: optional ``(past_k, past_v)`` from a previous step.
        use_cache: whether to also return the updated ``(k, v)`` as the cache.

    Returns:
        impl1's ``((k, v), cache)``, where ``cache`` is ``(k, v)`` when
        ``use_cache`` is true and ``None`` otherwise.
    """
    seq_len = k.shape[2]

    # The impls receive k / v as arguments. Previously they rebound the
    # enclosing function's k / v, which made those names local to each nested
    # function and raised UnboundLocalError on every call.

    # x-transformer style
    def impl1(k, v):
        if past_key_values is not None:
            past_k, past_v = past_key_values
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        if use_cache:
            return (k, v), (k, v)
        return (k, v), None

    # lit-gpt style: fixed-capacity rolling window the size of the past cache
    def impl2(k, v):
        if past_key_values is not None:
            past_k, past_v = past_key_values
            capacity = past_k.shape[2]
            if seq_len >= capacity:
                # The new chunk alone fills the window: keep its last `capacity` steps.
                k = k[:, :, seq_len - capacity:]
                v = v[:, :, seq_len - capacity:]
            else:
                keep = capacity - seq_len  # old steps that survive the shift
                k_cache = torch.empty_like(past_k)
                v_cache = torch.empty_like(past_v)
                k_cache[:, :, :keep] = past_k[:, :, seq_len:]
                v_cache[:, :, :keep] = past_v[:, :, seq_len:]
                k_cache[:, :, keep:] = k
                v_cache[:, :, keep:] = v
                k, v = k_cache, v_cache

        if use_cache:
            return (k, v), (k, v)
        return (k, v), None

    # nanoGPT style
    def impl3(k, v):
        if past_key_values is not None:
            past_k, past_v = past_key_values
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        if use_cache:
            return (k, v), (k, v)
        return (k, v), None

    (k1, v1), cache1 = impl1(k, v)
    (k2, v2), cache2 = impl2(k, v)
    (k3, v3), cache3 = impl3(k, v)

    def tail(full, window):
        # The last `window.shape[2]` time steps of `full`.
        return full[:, :, full.shape[2] - window.shape[2]:]

    tprint('shape', sep='*')
    cprint(k1.shape, v1.shape)
    assert k1.shape == k3.shape and v1.shape == v3.shape
    # The rolling window must be the most recent slice of the growing cache.
    assert k2.shape == tail(k1, k2).shape and v2.shape == tail(v1, v2).shape
    print('passed')

    tprint('allclose', sep='*')
    assert torch.allclose(k1, k3, rtol=1e-4)
    assert torch.allclose(v1, v3, rtol=1e-4)
    assert torch.allclose(k2, tail(k1, k2), rtol=1e-4)
    assert torch.allclose(v2, tail(v1, v2), rtol=1e-4)
    if use_cache:
        assert all(torch.allclose(c1, c3, rtol=1e-4)
                   for c1, c3 in zip(cache1, cache3))
    print('passed')

    return (k1, v1), cache1


# kv_cache_normal(k, v)

# Put Everything Together

In [14]:
class ModularAttention(nn.Module):
    """Configurable modular attention implementation.

    Multi-head scaled dot-product attention with a selectable Q/K/V projection
    layout and score scaling. There is no output projection: ``forward``
    returns the concatenated head outputs together with the attention weights.

    Args:
        dim: input width ``C``; must be divisible by ``num_heads``.
        num_heads: number of heads; each head uses ``dim // num_heads`` channels.
        projection_type: ``"unified"`` -> one fused ``dim -> 3*dim`` linear;
            ``"separated"`` -> a Q linear plus a fused ``dim -> 2*dim`` KV
            linear; any other value -> three individual Q/K/V linears.
        qkv_bias: whether the projection linears have a bias term.
        attn_dropout: dropout probability applied to the attention weights.
        scaling_type: ``"learned"`` -> a trainable scale initialised to
            ``1/sqrt(head_dim)``; any other value -> that constant.

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        projection_type: str = "unified",
        qkv_bias: bool = False,
        attn_dropout: float = 0.0,
        scaling_type: str = "default"  # or "learned" or "fixed"
    ):
        super().__init__()
        # Fail early. With an indivisible dim, the per-head view in
        # _project_qkv can silently mix tokens and channels, or fail later
        # inside forward with a confusing reshape error.
        if dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        # QKV projection configuration
        if projection_type == "unified":
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        elif projection_type == "separated":
            self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
            self.to_kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        else:  # individual Q / K / V projections
            self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
            self.to_k = nn.Linear(dim, dim, bias=qkv_bias)
            self.to_v = nn.Linear(dim, dim, bias=qkv_bias)

        # Scaling configuration
        self.scaling_type = scaling_type
        if scaling_type == "learned":
            self.scale = nn.Parameter(torch.ones(1) / np.sqrt(self.head_dim))
        else:
            self.scale = 1.0 / np.sqrt(self.head_dim)

        self.attn_dropout = nn.Dropout(attn_dropout)

    def _project_qkv(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project ``x`` (B, N, dim) to per-head Q, K, V of shape (B, heads, N, head_dim).

        The layers created in ``__init__`` determine which projection layout is used.
        """
        if hasattr(self, 'qkv'):
            q, k, v = self.qkv(x).chunk(3, dim=-1)
        elif hasattr(self, 'to_kv'):
            q = self.to_q(x)
            k, v = self.to_kv(x).chunk(2, dim=-1)
        else:
            q = self.to_q(x)
            k = self.to_k(x)
            v = self.to_v(x)

        # (B, N, dim) -> (B, N, heads, head_dim) -> (B, heads, N, head_dim)
        return tuple(
            t.view(t.shape[0], -1, self.num_heads, self.head_dim).transpose(1, 2)
            for t in (q, k, v)
        )

    def _apply_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute scaled dot-product attention on per-head tensors.

        Args:
            q, k, v: tensors of shape (B, heads, N, head_dim).
            mask: optional mask broadcastable to (B, heads, N, N). Entries equal
                to 0 are excluded (their scores are set to -inf). A query row
                whose entries are all masked yields NaN weights.

        Returns:
            ``(output, attn_weights)`` with shapes (B, heads, N, head_dim) and
            (B, heads, N, N). The weights are returned after dropout.
        """
        # self.scale is either a float or a 1-element Parameter; both broadcast.
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        if mask is not None:
            attn_weights = attn_weights.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)

        output = torch.matmul(attn_weights, v)
        return output, attn_weights

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention over ``x`` of shape (B, N, dim).

        Returns:
            ``(out, weights)``: ``out`` is (B, N, dim) with the heads
            concatenated (no output projection); ``weights`` is (B, heads, N, N).
        """
        B, N, C = x.shape
        q, k, v = self._project_qkv(x)
        out, weights = self._apply_attention(q, k, v, mask)
        # (B, heads, N, head_dim) -> (B, N, heads, head_dim) -> (B, N, dim)
        out = out.transpose(1, 2).contiguous().view(B, N, C)
        return out, weights