# Ref:
- https://github.com/TimS-ml/nanoGPT
- https://youtu.be/kCc8FmEb1nY

In [1]:
import os
from boring_llm_base.constants import PROJECT_HOME_DIR
import sys; sys.path.append(str(PROJECT_HOME_DIR)); os.chdir(PROJECT_HOME_DIR)
import math
import random
import tqdm
import gzip
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

from torch import Tensor
from typing import Optional, Tuple, Union, List
from jaxtyping import Float, Bool

from einops import rearrange, repeat, reduce

from boring_utils.utils import (
    cprint, 
    tprint, 
    get_device
)

# Config

In [2]:
# NOTE: boring_nn.attention.config.AttentionConfig could eventually replace
# the hand-written hyperparameters below, e.g.
#   cfg = AttentionConfig(); cprint(cfg)

# --- batch / sequence geometry ---
batch_size = 4   # number of independent sequences processed in parallel
block_size = 8   # context window: time steps per sequence
n_embed = 36     # embedding width (channels)

# encoder / decoder sequence lengths; the decoder reuses the context window
t_enc = 10
t_dec = block_size

# --- transformer shape ---
n_head = 6
n_layer = 6
dropout = 0.2
# the embedding width must divide evenly across the attention heads
assert n_embed % n_head == 0

# --- optimisation / evaluation schedule ---
learning_rate = 3e-4
max_iters = 100
eval_interval = 100
eval_iters = 100

# vocab_size is not set here; it is read from the dataset's meta.pkl later on
device = get_device()
cprint(device)

[93m<module> -> device:[0m
device(type='mps')


# Data Loader

In [3]:
# Root data directory comes from the DATA_DIR environment variable (default
# './data/'); the enwik8 artefacts live in its 'enwik8' sub-directory.
data_dir = os.getenv('DATA_DIR', './data/')
data_dir = os.path.join(data_dir, 'enwik8')

# # NOTE: only read enwik8 first 10M bytes
# with gzip.open(os.path.join(data_dir, 'enwik8.gz')) as file:
#     text = file.read(int(10e6)).decode('utf-8')

# meta.pkl holds the vocabulary: its size plus the two lookup tables used by
# encode()/decode() below.
meta_path = os.path.join(data_dir, 'meta.pkl')
vocab_size = None
if not os.path.exists(meta_path):
    raise FileNotFoundError(f"Meta file {meta_path} not found")

# NOTE(security): unpickling can execute arbitrary code. Only load a meta.pkl
# that was produced locally by a trusted preprocessing script.
with open(meta_path, 'rb') as f:
    meta = pickle.load(f)
vocab_size = meta['vocab_size']
stoi = meta['stoi']
itos = meta['itos']


def encode(s):
    """Return the list of ``stoi`` ids for every element of ``s``, in order."""
    return [stoi[c] for c in s]


def decode(l):
    """Return the string formed by joining ``itos[i]`` for every id in ``l``."""
    return ''.join([itos[i] for i in l])


train_bin_path = os.path.join(data_dir, 'train.bin')
val_bin_path = os.path.join(data_dir, 'val.bin')

# train_tensor = torch.tensor(encode(data), dtype=torch.long) # convert to tensor

# The .bin files are read as flat arrays of uint16 token ids.
train_data = np.fromfile(train_bin_path, dtype=np.uint16)
val_data = np.fromfile(val_bin_path, dtype=np.uint16)

# Widen to int64 before wrapping as tensors (torch.long is just an alias for
# torch.int64). astype() makes a copy, so the tensors do not alias the
# uint16 arrays read above.
train_data = torch.from_numpy(train_data.astype(np.int64))
val_data = torch.from_numpy(val_data.astype(np.int64))

class TextSamplerDataset(Dataset):
    """Dataset that serves random fixed-length windows of a 1-D token tensor.

    Each item is an ``(x, y)`` pair of ``block_size`` tokens where ``y`` is
    ``x`` shifted one position to the right (next-token targets).  The
    ``index`` given to ``__getitem__`` is ignored: a fresh random offset is
    drawn on every access, so ``__len__`` only fixes the nominal epoch length.

    Args:
        data: indexable token sequence supporting ``len()``, slicing and
            ``.to(device)`` on the slices (a torch tensor in this notebook).
        block_size: number of tokens in each ``x`` / ``y``; coerced to ``int``.

    Raises:
        ValueError: if ``data`` holds fewer than ``block_size + 1`` tokens,
            i.e. not even one (input, target) window fits.
    """

    def __init__(self, data, block_size):
        self.data = data
        self.block_size = int(block_size)
        # One window needs block_size inputs plus one extra token so the
        # targets can be shifted by one; fail here rather than on first access.
        if len(self.data) < self.block_size + 1:
            raise ValueError(
                f"data has {len(self.data)} tokens but at least "
                f"{self.block_size + 1} are required for block_size="
                f"{self.block_size}"
            )

    def __getitem__(self, index):
        # Valid window starts are 0 .. len - block_size - 1 inclusive.
        # torch.randint's upper bound is exclusive, so pass len - block_size;
        # the previous bound (one smaller) never sampled the final window and
        # raised when exactly one window fit.
        num_starts = len(self.data) - self.block_size
        start = int(torch.randint(num_starts, (1,)))
        full_seq = self.data[start:start + self.block_size + 1]
        x = full_seq[:-1]
        y = full_seq[1:]
        # `device` is the module-level device selected in the config cell.
        x, y = x.to(device), y.to(device)
        return x, y

    def __len__(self):
        # Nominal length: how many non-overlapping blocks fit in the data.
        return len(self.data) // self.block_size


# One random-window dataset per split, both using the shared context length.
train_dataset, val_dataset = (
    TextSamplerDataset(split, block_size) for split in (train_data, val_data)
)
# shuffle stays off: TextSamplerDataset already draws a random window on
# every access, so the sampler order does not matter.
train_loader, val_loader = (
    DataLoader(dataset, batch_size=batch_size, shuffle=False)
    for dataset in (train_dataset, val_dataset)
)

# Model

In [4]:
class BigramLanguageModel(nn.Module):
    """Bigram language model: one embedding table mapping a token to logits.

    The embedding width equals ``vocab_size``, so looking up token ``i``
    returns row ``i`` of the table, which is used directly as the logits for
    the token that follows ``i``.

    Args:
        vocab_size: number of distinct tokens (table is vocab_size x vocab_size).
    """

    def __init__(self, vocab_size):
        super().__init__()
        # Width is tied to the vocabulary so each row is a full logit vector
        # (a smaller width such as 128 would need an extra output projection).
        embedding_dim = vocab_size
        # Example: token id 24 reads off row 24 of the embedding table.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: integer token ids of shape (batch_size, block_size).
            targets: optional integer ids with the same number of elements
                as ``idx``; when omitted no loss is computed.

        Returns:
            ``(logits, loss)``.  Without targets ``logits`` has shape
            (B, T, C) and ``loss`` is ``None``.  With targets ``logits`` is
            flattened to (B*T, C) and ``loss`` is the scalar cross-entropy.
        """
        logits = self.embedding(idx)  # B, T, C: (batch_size, block_size, embedding_dim)

        # Identity check: `== None` on a tensor goes through Tensor.__eq__.
        if targets is None:
            return logits, None

        B, T, C = logits.shape
        # F.cross_entropy wants (N, C) logits against (N,) class indices.
        logits = logits.view(B * T, C)
        loss = F.cross_entropy(logits, targets.reshape(-1))
        return logits, loss

# Build the bigram model sized to the vocabulary read from meta.pkl.
model = BigramLanguageModel(vocab_size)
# Move the parameters to the selected device. Left as a bare expression so
# the notebook cell echoes the module repr.
model.to(device)

BigramLanguageModel(
  (embedding): Embedding(2102, 2102)
)

In [5]:
# AdamW over every model parameter, using learning_rate from the config cell.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Training

In [6]:
def cycle(loader):
    """Yield the items of ``loader`` endlessly, re-iterating it whenever it runs out."""
    while True:
        yield from loader

# Endless batch streams over both loaders, consumed with next() / zip().
train_iter, val_iter = cycle(train_loader), cycle(val_loader)

def _mean_loss(model, batches, num_batches):
    """Return the mean loss ``model`` reports over ``num_batches`` batches."""
    # range() is the first zip argument on purpose: zip stops as soon as the
    # range is exhausted, so no extra batch is pulled from `batches`.
    losses = [
        model(x, y)[1].item()
        for _, (x, y) in zip(range(num_batches), batches)
    ]
    return np.mean(losses)


def train(
        model: nn.Module = model,
        train_iter: DataLoader = train_iter,
        val_iter: DataLoader = val_iter,
        eval_iters: int = eval_iters,
        max_iters: int = max_iters,
        eval_interval: int = eval_interval,
    ):
    """Run the training loop with periodic train/val loss reporting.

    The defaults are bound to the module-level objects at definition time.
    Parameter updates go through the module-level ``optimizer``.

    Args:
        model: module called as ``model(x, y)`` and returning ``(logits, loss)``.
        train_iter: iterator of ``(x, y)`` training batches, consumed with
            ``next()``; it must not run out (see ``cycle``).
        val_iter: iterator of ``(x, y)`` validation batches.
        eval_iters: number of batches averaged per split at each evaluation.
        max_iters: total number of optimisation steps.
        eval_interval: evaluate every this many steps (and on the last step).

    Note:
        Each evaluation also draws ``eval_iters`` batches from ``train_iter``;
        those batches are only scored, not trained on.
    """
    # `step` rather than `iter`, so the builtin iter() is not shadowed.
    for step in range(max_iters):
        # Eval logic: runs before the update of the same step.
        if step % eval_interval == 0 or step == max_iters - 1:
            model.eval()
            with torch.no_grad():
                val_loss = _mean_loss(model, val_iter, eval_iters)
                train_loss = _mean_loss(model, train_iter, eval_iters)
                print(f"step {step}: train loss {train_loss:.4f}, val loss {val_loss:.4f}")
            model.train()

        # Training logic
        x, y = next(train_iter)  # replace get_batch
        _, loss = model(x, y)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

# Dev Zone

From nanoGPT:

```python
class Head(nn.Module):
    """ one head of self-attention

    Reference snippet quoted from nanoGPT.  It reads the globals ``n_embd``,
    ``block_size`` and ``dropout``.
    NOTE: ``n_embd`` is nanoGPT's spelling; this notebook's config cell
    defines ``n_embed``, so the snippet is not runnable here as-is.
    """

    def __init__(self, head_size):
        super().__init__()
        # Bias-free linear projections from the embedding width to head_size.
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular matrix of ones, registered as a buffer (not a
        # parameter); used in forward() to build the causal mask.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # input of size (batch, time-step, channels)
        # output of size (batch, time-step, head size)
        B,T,C = x.shape
        k = self.key(x)   # (B,T,hs)
        q = self.query(x) # (B,T,hs)
        # compute attention scores ("affinities"), scaled by head_size ** -0.5
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        # entries where tril == 0 (above the diagonal) become -inf, so softmax
        # assigns them zero weight; tril is cropped to the current length T
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        # dropout is applied to the attention weights themselves
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,hs)
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel

    Reference snippet quoted from nanoGPT; reads the globals ``n_embd`` and
    ``dropout``.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        # num_heads independent Head modules, each producing head_size features
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # maps the concatenated head outputs (head_size * num_heads) to n_embd
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # run every head on the same input and concatenate along the last dim
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        # output projection followed by dropout
        out = self.dropout(self.proj(out))
        return out
```

## Attention

In [10]:
def attention_normal(
    q: Float[Tensor, "batch num_heads seq_len head_dim"],
    k: Float[Tensor, "batch num_heads seq_len head_dim"],
    v: Float[Tensor, "batch num_heads seq_len head_dim"],
    mask: Optional[Bool[Tensor, "batch seq_len seq_len"]] = None
):
    """Compute scaled dot-product attention three ways and cross-check them.

    The three implementations (einsum, ``F.scaled_dot_product_attention``
    with a matmul fallback, and plain ``@``) all compute non-causal
    ``softmax(q @ k^T / sqrt(head_dim)) @ v``.  Their outputs are asserted
    to have equal shapes and to agree numerically.

    Args:
        q, k, v: tensors laid out as (batch, num_heads, seq_len, head_dim).
        mask: optional boolean mask where ``True`` means "may attend".  A
            (batch, seq_len, seq_len) mask is broadcast over the head axis;
            masks that already broadcast against (batch, heads, seq, seq)
            are used unchanged.  Every query row should keep at least one
            ``True`` entry, otherwise softmax over all ``-inf`` yields NaN
            and the comparisons fail.

    Returns:
        ``(out, attn)`` from the einsum implementation: ``out`` is
        (batch, num_heads, seq_len, head_dim) and ``attn`` holds the
        attention weights, (batch, num_heads, seq_len, seq_len).

    Raises:
        AssertionError: if the implementations disagree in shape or value.
    """
    batch, num_heads, seq_len, head_dim = q.shape
    scale = 1.0 / math.sqrt(head_dim)

    # Scores are (b, h, i, j). A (b, i, j) mask needs a singleton head axis,
    # otherwise broadcasting would line batch up against heads.
    if mask is not None and mask.dim() == 3:
        mask = mask.unsqueeze(1)

    # x-transformer style
    def impl1(q, k, v):
        sim = torch.einsum('bhid,bhjd->bhij', q, k) * scale

        if mask is not None:
            sim = sim.masked_fill(~mask, float('-inf'))

        attn = F.softmax(sim, dim=-1)
        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        return out, attn

    # lit-gpt style (use torch.nn.functional)
    def impl2(q, k, v):
        if hasattr(F, 'scaled_dot_product_attention'):
            # SDPA attends over the second-to-last axis and treats all
            # leading axes as batch, so (b, h, n, d) is passed as-is.
            # is_causal stays False to match the other two implementations.
            out = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=mask,
                dropout_p=0.0,
                is_causal=False,
            )
            return out, None  # SDPA does not expose the attention weights

        attn = torch.matmul(q, k.transpose(-2, -1)) * scale
        if mask is not None:
            attn = attn.masked_fill(~mask, float('-inf'))
        attn = F.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)
        return out, attn

    # nanoGPT style
    def impl3(q, k, v):
        att = (q @ k.transpose(-2, -1)) * scale
        if mask is not None:
            att = att.masked_fill(~mask, float('-inf'))
        att = F.softmax(att, dim=-1)
        out = att @ v
        return out, att

    out1, attn1 = impl1(q, k, v)
    out2, attn2 = impl2(q, k, v)
    out3, attn3 = impl3(q, k, v)

    tprint('shape', sep='*')
    cprint(out1.shape)
    assert out1.shape == out2.shape == out3.shape
    assert attn1.shape == attn3.shape
    if attn2 is not None:  # if not using SDPA
        assert attn1.shape == attn2.shape
    print('passed')

    tprint('allclose', sep='*')
    # A small atol keeps near-zero entries from failing on rounding noise
    # between the fused kernel and the explicit matmul/softmax paths.
    tol = dict(rtol=1e-4, atol=1e-5)
    assert torch.allclose(out1, out3, **tol)
    # out2 is compared unconditionally so the SDPA path is actually verified.
    assert torch.allclose(out1, out2, **tol)
    assert torch.allclose(attn1, attn3, **tol)
    if attn2 is not None:  # if not using SDPA
        assert torch.allclose(attn1, attn2, **tol)
    print('passed')

    return out1, attn1


# attention_normal(q, k, v)