# Ref:
- https://github.com/TimS-ml/nanoGPT
- https://youtu.be/kCc8FmEb1nY

In [1]:
import os
from boring_llm_base.constants import PROJECT_HOME_DIR
import sys; sys.path.append(str(PROJECT_HOME_DIR)); os.chdir(PROJECT_HOME_DIR)
import math
import random
import tqdm
import gzip
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

from torch import Tensor
from typing import Optional, Tuple, Union, List, Iterator
from jaxtyping import Float, Bool

from einops import rearrange, repeat, reduce

from boring_utils.utils import (
    cprint, 
    tprint, 
    get_device
)

# Config

In [2]:
# Hyper-parameters for this notebook, kept as plain globals.
# (Alternative, not used: `AttentionConfig` from boring_nn.attention.config,
#  e.g. `cfg = AttentionConfig(); cprint(cfg)`.)

# --- sequence / model shape ---------------------------------------------
batch_size = 4   # independent sequences processed in parallel per step
block_size = 8   # time steps per sequence (context window / seq length)
n_embed = 36     # channels per token (embedding dim)
n_head = 6
assert n_embed % n_head == 0
n_layer = 6
dropout = 0.2
t_enc, t_dec = 10, block_size  # encoder / decoder sequence lengths

# --- optimisation / evaluation schedule ---------------------------------
max_iters, eval_interval, eval_iters = 100, 100, 100
learning_rate = 3e-4

# --- runtime ------------------------------------------------------------
device = get_device()
# vocab_size is not set here; the data-loading cell reads it from meta.pkl.
cprint(device)

[93m<module> -> device:[0m
device(type='mps')


# Data Loader

In [3]:
data_dir = os.getenv('DATA_DIR', './data/')
data_dir = os.path.join(data_dir, 'enwik8')

# # NOTE: alternative raw source -- only read enwik8 first 10M bytes
# with gzip.open(os.path.join(data_dir, 'enwik8.gz')) as file:
#     text = file.read(int(10e6)).decode('utf-8')

# Vocabulary metadata: vocab size plus char<->id lookup tables.
meta_path = os.path.join(data_dir, 'meta.pkl')
if not os.path.exists(meta_path):
    raise FileNotFoundError(f"Meta file {meta_path} not found")
# NOTE(security): pickle.load can execute arbitrary code -- only point
# DATA_DIR at a meta.pkl you generated yourself.
with open(meta_path, 'rb') as f:
    meta = pickle.load(f)
vocab_size = meta['vocab_size']
stoi = meta['stoi']
itos = meta['itos']


def encode(s):
    """Map a string to a list of integer token ids via ``stoi``."""
    return [stoi[c] for c in s]


def decode(ids):
    """Map a sequence of integer token ids back to a string via ``itos``."""
    return ''.join(itos[i] for i in ids)


def _load_token_bin(path):
    """Read a flat binary file of uint16 token ids into a 1-D int64 tensor."""
    tokens = np.fromfile(path, dtype=np.uint16)
    # int64 (torch.long) so the ids can index nn.Embedding and serve as
    # cross-entropy targets.
    return torch.from_numpy(tokens.astype(np.int64))


train_bin_path = os.path.join(data_dir, 'train.bin')
val_bin_path = os.path.join(data_dir, 'val.bin')

train_data = _load_token_bin(train_bin_path)
val_data = _load_token_bin(val_bin_path)

class TextSamplerDataset(Dataset):
    """Map-style dataset that returns a random ``(x, y)`` token window per access.

    The requested index is ignored: every ``__getitem__`` call draws a fresh
    random start position, so a ``DataLoader`` over this dataset acts as a
    random ``get_batch``. ``y`` is ``x`` shifted one token to the right.

    Args:
        data: 1-D tensor of token ids.
        block_size: number of tokens in each input window.
        device: device the returned tensors are moved to. ``None`` (default)
            falls back to the notebook-level ``device`` global at access time.

    Raises:
        ValueError: if ``data`` has fewer than ``block_size + 1`` tokens.
    """

    def __init__(self, data, block_size, device=None):
        self.data = data
        self.block_size = int(block_size)
        self.device = device
        if len(self.data) < self.block_size + 1:
            raise ValueError(
                f"need at least block_size + 1 = {self.block_size + 1} tokens, "
                f"got {len(self.data)}"
            )

    def __getitem__(self, index):
        # ``index`` is intentionally unused -- sampling is random.
        # Valid starts are 0 .. len(data) - block_size - 1 inclusive, so that the
        # (block_size + 1)-token slice may end exactly at len(data).
        # torch.randint's upper bound is exclusive.
        start = int(torch.randint(len(self.data) - self.block_size, (1,)))
        full_seq = self.data[start:start + self.block_size + 1]
        # ``device`` below is the module-level global (only used when no
        # explicit device was given); transfer once, then slice.
        target_device = self.device if self.device is not None else device
        full_seq = full_seq.to(target_device)
        return full_seq[:-1], full_seq[1:]

    def __len__(self):
        # Nominal epoch length: number of non-overlapping windows in the data.
        return len(self.data) // self.block_size


# Wrap each split in a random-window sampler, then batch it with a DataLoader.
# shuffle stays off: every __getitem__ call already draws a random window.
train_dataset, val_dataset = (
    TextSamplerDataset(split, block_size) for split in (train_data, val_data)
)
train_loader, val_loader = (
    DataLoader(ds, batch_size=batch_size, shuffle=False)
    for ds in (train_dataset, val_dataset)
)

# Model

In [4]:
class BigramLanguageModel(nn.Module):
    """Bigram LM: the embedding row of a token *is* its next-token logit vector.

    Args:
        vocab_size: number of distinct token ids. It is also the embedding
            width, so the looked-up row can be read directly as logits.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # The width must equal vocab_size: row i of the table holds the scores
        # for every possible token following token i (e.g. id 24 reads row 24).
        embedding_dim = vocab_size
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

    def forward(self, idx, targets=None):
        """Return ``(logits, loss)``.

        Args:
            idx: (batch_size, block_size) tensor of token ids.
            targets: optional (batch_size, block_size) tensor of next-token ids.

        Returns:
            Without targets: logits of shape (B, T, vocab_size) and ``None``.
            With targets: logits flattened to (B*T, vocab_size) and the scalar
            mean cross-entropy loss.
        """
        logits = self.embedding(idx)  # (B, T, C)

        # Identity check: ``targets == None`` would go through Tensor.__eq__.
        if targets is None:
            return logits, None

        B, T, C = logits.shape
        logits = logits.view(B * T, C)   # cross_entropy wants (N, C)
        targets = targets.view(-1)       # ... and (N,)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

# Instantiate the bigram baseline (vocab_size comes from meta.pkl) and move it
# to `device`; the return value of .to() is not reassigned.
model = BigramLanguageModel(vocab_size)
model.to(device)

BigramLanguageModel(
  (embedding): Embedding(2102, 2102)
)

In [5]:
# AdamW over all parameters of the bigram model at the configured learning rate.
optimizer = optim.AdamW(params=model.parameters(), lr=learning_rate)

# Training

In [6]:
def cycle(loader):
    """Yield items from ``loader`` forever, restarting it after each full pass.

    Raises:
        ValueError: if a complete pass over ``loader`` yields nothing. Without
            this check the generator would loop forever without yielding.
    """
    while True:
        produced = False
        for data in loader:
            produced = True
            yield data
        if not produced:
            raise ValueError("cycle() got an empty loader; nothing to repeat")

# Endless batch streams: the training loop pulls batches with next() instead
# of iterating the DataLoaders epoch by epoch.
train_iter, val_iter = cycle(train_loader), cycle(val_loader)

@torch.no_grad()
def _estimate_loss(model, data_iter, eval_iters):
    """Return the mean loss of ``model`` over the next ``eval_iters`` batches of ``data_iter``."""
    losses = []
    # range() comes first in zip so no extra batch is pulled from data_iter.
    for _, (x, y) in zip(range(eval_iters), data_iter):
        _, loss = model(x, y)
        losses.append(loss.item())
    return np.mean(losses)


def train(
        model: nn.Module = model,
        train_iter: Iterator[Tuple[Tensor, Tensor]] = train_iter,
        val_iter: Iterator[Tuple[Tensor, Tensor]] = val_iter,
        eval_iters: int = eval_iters,
        max_iters: int = max_iters,
        eval_interval: int = eval_interval,
        opt: Optional[optim.Optimizer] = None,
    ):
    """Run ``max_iters`` optimisation steps on ``model``, printing losses periodically.

    Losses are estimated at step 0, every ``eval_interval`` steps, and at the
    final step (before that step's update). Each estimate averages
    ``eval_iters`` batches, validation first, then training.

    Args:
        model: module whose ``forward(x, y)`` returns ``(logits, loss)``.
        train_iter: endless iterator of ``(x, y)`` training batches.
        val_iter: endless iterator of ``(x, y)`` validation batches.
        eval_iters: batches averaged per loss estimate.
        max_iters: number of optimisation steps.
        eval_interval: steps between loss reports.
        opt: optimizer to step. ``None`` (default) uses the notebook-level
            ``optimizer``, looked up at call time as before. Pass one explicitly
            when training a model other than the global one.
    """
    if opt is None:
        opt = optimizer  # module-level global

    for step in range(max_iters):
        # Eval logic
        if step % eval_interval == 0 or step == max_iters - 1:
            model.eval()
            val_loss = _estimate_loss(model, val_iter, eval_iters)
            train_loss = _estimate_loss(model, train_iter, eval_iters)
            print(f"step {step}: train loss {train_loss:.4f}, val loss {val_loss:.4f}")
            model.train()

        # Training logic (replaces nanoGPT's get_batch)
        x, y = next(train_iter)
        _, loss = model(x, y)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()

# Dev Zone

From nanoGPT:

```python
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # input of size (batch, time-step, channels)
        # output of size (batch, time-step, head size)
        B,T,C = x.shape
        k = self.key(x)   # (B,T,hs)
        q = self.query(x) # (B,T,hs)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,hs)
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out
```

## QKV

In [7]:
# Random activations of shape (batch, seq_len, embedding_dim) for the QKV demo.
x = torch.randn((batch_size, block_size, n_embed))

def qkv_normal(
    x: Float[Tensor, "batch seq_len embedding_dim"],
    num_heads: int = n_head
):
    """Check that three QKV-projection layouts give identical per-head Q, K, V.

    Every implementation is loaded with the same random (dim, dim) weight for
    each of Q, K and V, so all of them must agree up to float round-off. Each
    produces q, k, v of shape (batch, num_heads, seq_len, head_dim).

    Args:
        x: (batch, seq_len, embedding_dim) input activations.
        num_heads: number of heads; must divide ``embedding_dim``.

    Raises:
        ValueError: if ``embedding_dim`` is not divisible by ``num_heads``.
        AssertionError: if the implementations disagree in shape or value.
    """
    batch, seq_len, dim = x.shape
    if dim % num_heads:
        raise ValueError(f"embedding dim {dim} is not divisible by num_heads={num_heads}")
    head_dim = dim // num_heads

    # One shared weight (on x's device/dtype) so every layout must agree.
    fixed_weight = torch.randn(dim, dim, device=x.device, dtype=x.dtype)

    # Separated QKV (x-transformer style): Linear for Q, one fused Linear for K and V.
    def impl1(x):
        to_q = nn.Linear(dim, dim, bias=False)
        to_kv = nn.Linear(dim, dim * 2, bias=False)
        to_q.weight.data = fixed_weight.clone()
        to_kv.weight.data = torch.cat([fixed_weight, fixed_weight])

        q = to_q(x)
        k, v = to_kv(x).chunk(2, dim=-1)

        q, k, v = map(lambda t: rearrange(t, 'batch seq_len (num_heads head_dim) -> batch num_heads seq_len head_dim', num_heads=num_heads), (q, k, v))
        return q, k, v

    # Unified QKV (lit-gpt style): a single Linear to 3 * dim, split afterwards.
    def impl2(x):
        to_qkv = nn.Linear(dim, dim * 3, bias=False)
        to_qkv.weight.data = torch.cat([fixed_weight] * 3)

        # Split by the input's own width (not a global) along the feature axis.
        q, k, v = to_qkv(x).split(dim, dim=-1)
        # (batch, seq_len, dim) -> (batch, num_heads, seq_len, head_dim)
        q, k, v = map(
            lambda t: t.view(batch, -1, num_heads, head_dim).transpose(1, 2),
            (q, k, v)
        )
        return q, k, v

    # nanoGPT style: separate Q, K and V layers. nanoGPT's Head uses one
    # Linear(n_embd, head_size) per head; here the heads are fused into one
    # Linear(dim, dim) per projection so the layer shape matches the shared weight.
    def impl3(x):
        to_q = nn.Linear(dim, dim, bias=False)
        to_k = nn.Linear(dim, dim, bias=False)
        to_v = nn.Linear(dim, dim, bias=False)
        for layer in (to_q, to_k, to_v):
            layer.weight.data = fixed_weight.clone()

        k = to_k(x)
        q = to_q(x)
        v = to_v(x)

        q = q.view(batch, seq_len, num_heads, head_dim).transpose(1, 2)
        k = k.view(batch, seq_len, num_heads, head_dim).transpose(1, 2)
        v = v.view(batch, seq_len, num_heads, head_dim).transpose(1, 2)
        return q, k, v

    q1, k1, v1 = impl1(x)
    q2, k2, v2 = impl2(x)
    q3, k3, v3 = impl3(x)

    # Shapes first: allclose on mismatched shapes would broadcast or fail with
    # a far less helpful message.
    tprint('shape', sep='*')
    try:
        assert q1.shape == q2.shape
        assert k1.shape == k2.shape
        assert v1.shape == v2.shape
        assert q1.shape == q3.shape
        assert k1.shape == k3.shape
        assert v1.shape == v3.shape
        print('passed')
    except AssertionError:
        cprint(q1.shape, k1.shape, v1.shape)
        cprint(q2.shape, k2.shape, v2.shape)
        cprint(q3.shape, k3.shape, v3.shape)
        raise

    # atol guards against near-zero entries where rtol alone is too strict.
    tprint('allclose', sep='*')
    assert torch.allclose(q1, q2, rtol=1e-4, atol=1e-6)
    assert torch.allclose(k1, k2, rtol=1e-4, atol=1e-6)
    assert torch.allclose(v1, v2, rtol=1e-4, atol=1e-6)
    assert torch.allclose(q1, q3, rtol=1e-4, atol=1e-6)
    assert torch.allclose(k1, k3, rtol=1e-4, atol=1e-6)
    assert torch.allclose(v1, v3, rtol=1e-4, atol=1e-6)
    print('passed')


# Run the QKV layout comparison on the random activations, with the configured head count.
qkv_normal(x, num_heads=n_head)

[35m
******************** qkv_normal -> allclose ********************[0m
passed
[35m
******************** qkv_normal -> shape ********************[0m
passed


### Grouped QKV

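From lit-gpt (documentation of its `n_query_groups` setting, i.e. the number of key/value heads; "this" below refers to that setting):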
```
to use multi-head attention (MHA), set this to `n_head` (default)
to use multi-query attention (MQA), set this to 1
to use grouped-query attention (GQA), set this to a value in between
Example with `n_head=4`
┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
│ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │
└───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
  │    │    │    │         │        │                 │
┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
│ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │
└───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
  │    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐
┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐
│ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │
└───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘
◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶
        MHA                    GQA                   MQA
  n_query_groups=4       n_query_groups=2      n_query_groups=1

credit https://arxiv.org/pdf/2305.13245.pdf
```

In [8]:
def qkv_grouped(
    x: Float[Tensor, "batch seq_len embedding_dim"],
    num_heads: int = n_head,
    n_query_groups: int = 2
):
    """Compare two Q/K/V projection layouts that support MHA, GQA and MQA.

    ``n_query_groups`` is the number of distinct key/value heads:
    ``num_heads`` -> MHA, ``1`` -> MQA, anything in between -> GQA. Both
    implementations expand K and V back to ``num_heads`` heads, so every
    returned tensor has shape (batch, num_heads, seq_len, head_dim).

    Only shapes are compared. The implementations use independently
    initialised weights and also pair query heads with K/V heads in a
    different order (see impl1/impl2). A value comparison would need matching
    weights AND matching head order.
    TODO: share weights and reorder heads to also assert allclose.

    Raises:
        ValueError: if ``embedding_dim`` is not divisible by ``num_heads``, or
            ``num_heads`` is not divisible by ``n_query_groups``.
        AssertionError: if the shapes disagree.
    """
    batch, seq_len, dim = x.shape
    if dim % num_heads:
        raise ValueError(f"embedding dim {dim} is not divisible by num_heads={num_heads}")
    if n_query_groups < 1 or num_heads % n_query_groups:
        raise ValueError(
            f"num_heads={num_heads} must be a positive multiple of n_query_groups={n_query_groups}"
        )
    head_dim = dim // num_heads
    q_per_kv = num_heads // n_query_groups  # query heads sharing each K/V head

    # 1. x-transformer style (separated Q / fused KV projection + GQA)
    def impl1(x):
        kv_heads = n_query_groups
        to_q = nn.Linear(dim, dim, bias=False)
        to_kv = nn.Linear(dim, head_dim * kv_heads * 2, bias=False)

        q = to_q(x)
        k, v = to_kv(x).chunk(2, dim=-1)

        if kv_heads < num_heads:  # MQA (kv_heads == 1) and GQA
            # Split off the kv heads, then tile them q_per_kv times along the
            # head axis. '(r h)' tiles as [h0, h1, ..., h0, h1, ...], so query
            # head i uses kv head i % kv_heads.
            k, v = tuple(
                repeat(t, 'b n (h d) -> b (r h) n d', h=kv_heads, r=q_per_kv)
                for t in (k, v)
            )
        else:  # MHA
            k, v = tuple(rearrange(t, 'b n (h d) -> b h n d', h=num_heads) for t in (k, v))

        q = rearrange(q, 'b n (h d) -> b h n d', h=num_heads)
        return q, k, v

    # 2. lit-gpt style (unified QKV projection + GQA)
    def impl2(x):
        total_qkv = q_per_kv + 2  # per group: q_per_kv queries + 1 key + 1 value
        qkv_dim = (num_heads + 2 * n_query_groups) * head_dim

        to_qkv = nn.Linear(dim, qkv_dim, bias=False)

        qkv = to_qkv(x)
        qkv = qkv.view(batch, seq_len, n_query_groups, total_qkv, head_dim)
        qkv = qkv.permute(0, 2, 3, 1, 4)  # (B, n_query_groups, total_qkv, T, hs)

        q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)

        # Share each group's single K/V across its q_per_kv queries (a view,
        # no copy). Query head i = g * q_per_kv + j uses kv head i // q_per_kv.
        if n_query_groups != num_heads:
            k = k.expand(batch, n_query_groups, q_per_kv, seq_len, head_dim)
            v = v.expand(batch, n_query_groups, q_per_kv, seq_len, head_dim)

        q = q.reshape(batch, -1, seq_len, head_dim)  # (B, nh_q, T, hs)
        k = k.reshape(batch, -1, seq_len, head_dim)  # (B, nh_k, T, hs)
        v = v.reshape(batch, -1, seq_len, head_dim)  # (B, nh_v, T, hs)
        return q, k, v

    q1, k1, v1 = impl1(x)
    q2, k2, v2 = impl2(x)

    expected = (batch, num_heads, seq_len, head_dim)
    tprint('shape', sep='*')
    try:
        assert q1.shape == q2.shape
        assert k1.shape == k2.shape
        assert v1.shape == v2.shape
        assert tuple(q1.shape) == tuple(k1.shape) == tuple(v1.shape) == expected
        print('passed')
    except AssertionError:
        cprint(q1.shape, k1.shape, v1.shape)
        cprint(q2.shape, k2.shape, v2.shape)
        raise


# Test different GQA configurations (all use the configured n_head query heads)
x = torch.randn(batch_size, block_size, n_embed)

# Standard multi-head attention: one K/V head per query head
tprint('MHA')
qkv_grouped(x, num_heads=n_head, n_query_groups=n_head)

# MQA: a single K/V head shared by all query heads
tprint('MQA')
qkv_grouped(x, num_heads=n_head, n_query_groups=1)

# GQA: n_head query heads (6 with the config above) split into 2 groups,
# each group sharing one K/V head
tprint('GQA')
qkv_grouped(x, num_heads=n_head, n_query_groups=2)

[35m
[35m
******************** qkv_grouped -> shape ********************[0m
passed
[35m
[35m
******************** qkv_grouped -> shape ********************[0m
passed
[35m
[35m
******************** qkv_grouped -> shape ********************[0m
passed


## PE

In [9]:
def positional_encoding_normal(
    x: Float[Tensor, "batch seq_len embedding_dim"],
    num_heads: int = n_head
):
    """Apply three positional-encoding styles to ``x`` and sanity-check them.

    The styles encode position differently, so only shapes are compared
    across them:

    * impl1 adds a fixed sinusoidal table;
    * impl2 rotates each head with RoPE;
    * impl3 adds a (here fixed-initialised) position-embedding table.

    RoPE is additionally checked to be the identity at position 0 and to
    preserve vector norms.

    Args:
        x: (batch, seq_len, embedding_dim) activations.
        num_heads: heads used to split ``x`` for RoPE. ``embedding_dim /
            num_heads`` must be an even integer.

    Returns:
        The sinusoidal result ``x + pe`` from impl1, same shape as ``x``.

    Raises:
        ValueError: if the per-head dimension is not an even integer.
    """
    batch, seq_len, dim = x.shape
    if dim % num_heads or (dim // num_heads) % 2:
        raise ValueError(
            f"embedding dim {dim} / num_heads {num_heads} must be an even integer for RoPE"
        )
    head_dim = dim // num_heads

    # Fixed table for the learned-embedding style (impl3), on x's device/dtype.
    fixed_weight = torch.randn(seq_len, dim, device=x.device, dtype=x.dtype)

    # x-transformer style (sin/cos table added to x)
    def impl1(x):
        position = torch.arange(seq_len, device=x.device, dtype=x.dtype).unsqueeze(1)  # (T, 1)
        div_term = torch.exp(
            torch.arange(0, dim, 2, device=x.device, dtype=x.dtype) * -(math.log(10000.0) / dim)
        )
        # Every column is written below (dim is even), so start from zeros.
        pe = torch.zeros(seq_len, dim, device=x.device, dtype=x.dtype)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return x + pe.unsqueeze(0)

    def apply_rope(t, cos, sin):
        # Rotate-half RoPE: dimension i is paired with i + head_dim // 2 and the
        # pair is rotated by the angle stored in cos/sin.
        half = t.size(-1) // 2
        t1, t2 = t[..., :half], t[..., half:]
        rotated = torch.cat((-t2, t1), dim=-1)
        return (t * cos + rotated * sin).type_as(t)

    # lit-gpt style (RoPE applied per head)
    def impl2(x):
        inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, device=x.device, dtype=torch.float32) / head_dim))
        # float positions: einsum needs both operands in the same dtype
        t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
        freqs = torch.einsum('i,j->ij', t, inv_freq)  # (T, head_dim / 2)
        emb = torch.cat((freqs, freqs), dim=-1)       # (T, head_dim)
        cos = emb.cos().view(1, seq_len, 1, head_dim)
        sin = emb.sin().view(1, seq_len, 1, head_dim)

        x_heads = x.reshape(batch, seq_len, num_heads, head_dim)
        return apply_rope(x_heads, cos, sin).reshape(batch, seq_len, dim)

    # nanoGPT style (position-embedding table added to x)
    def impl3(x):
        position_embeddings = nn.Embedding(seq_len, dim)
        position_embeddings.weight.data = fixed_weight.clone()

        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
        return x + position_embeddings(positions)

    out1 = impl1(x)
    out2 = impl2(x)
    out3 = impl3(x)

    tprint('shape', sep='*')
    cprint(out1.shape, out2.shape, out3.shape)
    assert out1.shape == out2.shape == out3.shape
    print('passed')

    # Angle is 0 at position 0 (identity), and a rotation never changes the norm.
    tprint('rope', sep='*')
    assert torch.allclose(out2[:, 0], x[:, 0], atol=1e-6)
    assert torch.allclose(out2.norm(dim=-1), x.norm(dim=-1), rtol=1e-4, atol=1e-5)
    print('passed')

    return out1


# positional_encoding_normal(x)

## Attention

In [10]:
def attention_normal(
    q: Float[Tensor, "batch num_heads seq_len head_dim"],
    k: Float[Tensor, "batch num_heads seq_len head_dim"],
    v: Float[Tensor, "batch num_heads seq_len head_dim"],
    mask: Optional[Bool[Tensor, "batch seq_len seq_len"]] = None
):
    """Compute scaled dot-product attention three ways and check they agree.

    Args:
        q, k, v: (batch, num_heads, seq_len, head_dim) tensors.
        mask: optional boolean mask, True = may attend. A (batch, seq_len,
            seq_len) mask is broadcast over heads; 2-D/4-D masks are used as
            given. With no mask every implementation does full (non-causal)
            attention.

    Returns:
        ``(out, attn)`` from the einsum implementation. ``out`` is
        (batch, num_heads, seq_len, head_dim) and ``attn`` is
        (batch, num_heads, seq_len, seq_len).
    """
    batch, num_heads, seq_len, head_dim = q.shape
    scale = 1.0 / math.sqrt(head_dim)

    # (B, T, T) -> (B, 1, T, T): broadcast over heads instead of lining the
    # batch axis up with the head axis.
    attn_mask = mask
    if attn_mask is not None and attn_mask.dim() == 3:
        attn_mask = attn_mask.unsqueeze(1)

    # x-transformer style (einsum)
    def impl1(q, k, v):
        sim = torch.einsum('bhid,bhjd->bhij', q, k) * scale
        if attn_mask is not None:
            sim = sim.masked_fill(~attn_mask, float('-inf'))
        attn = F.softmax(sim, dim=-1)
        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        return out, attn

    # lit-gpt style (torch.nn.functional SDPA when available)
    def impl2(q, k, v):
        if hasattr(F, 'scaled_dot_product_attention'):
            # SDPA takes (..., seq_len, head_dim) as-is -- no transpose. With a
            # boolean attn_mask, True means "take part in attention", the same
            # convention as impl1/impl3.
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                dropout_p=0.0,
                is_causal=False,
            )
            return out, None  # SDPA does not expose the attention weights

        attn = torch.matmul(q, k.transpose(-2, -1)) * scale
        if attn_mask is not None:
            attn = attn.masked_fill(~attn_mask, float('-inf'))
        attn = F.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)
        return out, attn

    # nanoGPT style (@ operator)
    def impl3(q, k, v):
        att = (q @ k.transpose(-2, -1)) * scale
        if attn_mask is not None:
            att = att.masked_fill(~attn_mask, float('-inf'))
        att = F.softmax(att, dim=-1)
        out = att @ v
        return out, att

    out1, attn1 = impl1(q, k, v)
    out2, attn2 = impl2(q, k, v)
    out3, attn3 = impl3(q, k, v)

    tprint('shape', sep='*')
    cprint(out1.shape)
    assert out1.shape == out2.shape == out3.shape
    if attn2 is not None:  # weights only exist on the non-SDPA path
        assert attn1.shape == attn2.shape == attn3.shape
    print('passed')

    # Outputs are always comparable; weights only where impl2 returns them.
    tprint('allclose', sep='*')
    assert torch.allclose(out1, out2, rtol=1e-4, atol=1e-5)
    assert torch.allclose(out1, out3, rtol=1e-4, atol=1e-5)
    assert torch.allclose(attn1, attn3, rtol=1e-4, atol=1e-6)
    if attn2 is not None:
        assert torch.allclose(attn1, attn2, rtol=1e-4, atol=1e-6)
    print('passed')

    return out1, attn1


# attention_normal(q, k, v)

## Mask

In [11]:
def mask_normal(
    batch_size: int,
    seq_len: int,
    mask_type: str = "causal",  # "causal", "padding", or "bidirectional"
    device: Optional[torch.device] = None
) -> Bool[Tensor, "batch seq_len seq_len"]:
    """Build an attention mask (True = may attend) and cross-check implementations.

    mask_type:
        ``"causal"``: position i may attend to positions j <= i.
        ``"padding"``: ~80% of positions are randomly marked as real tokens;
            (i, j) may attend only if both are real. Random, for testing only.
        any other value (``"bidirectional"``): every position may attend everywhere.

    impl1 supports every type. impl2 (causal only) is run only for causal
    masks. impl3 has no padding support and is value-compared only for the
    causal and bidirectional types.

    Returns:
        impl1's boolean mask of shape (batch_size, seq_len, seq_len).
    """

    # x-transformer style (supports multiple mask types)
    def impl1():
        # Build the *blocked* positions first, then invert once at the end.
        if mask_type == "causal":
            # strictly above the diagonal = future positions
            blocked = torch.ones((seq_len, seq_len), device=device).triu(1).bool()
            blocked = blocked.unsqueeze(0).expand(batch_size, -1, -1)
        elif mask_type == "padding":
            is_token = torch.rand(batch_size, seq_len, device=device) > 0.2
            blocked = ~(is_token.unsqueeze(1) & is_token.unsqueeze(2))
        else:  # bidirectional: nothing blocked
            blocked = torch.zeros((batch_size, seq_len, seq_len), dtype=torch.bool, device=device)
        return ~blocked  # True = may attend

    # lit-gpt style (causal only)
    def impl2():
        if mask_type != "causal":
            raise NotImplementedError("lit-gpt mainly supports causal mask")
        ones = torch.ones((seq_len, seq_len), dtype=torch.bool, device=device)
        mask = torch.tril(ones).unsqueeze(0)
        return mask.expand(batch_size, -1, -1)

    # nanoGPT style (float mask: causal tril, otherwise all ones)
    def impl3():
        if mask_type == "causal":
            mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
            return mask.view(1, seq_len, seq_len).expand(batch_size, -1, -1)
        return torch.ones((batch_size, seq_len, seq_len), device=device)

    mask1 = impl1()
    mask2 = impl2() if mask_type == "causal" else None
    mask3 = impl3()

    tprint('shape', sep='*')
    cprint(mask1.shape)
    assert mask1.shape == mask3.shape
    if mask2 is not None:
        assert mask1.shape == mask2.shape
    print('passed')

    if mask_type != "padding":  # impl3 has no padding support
        tprint('allclose', sep='*')
        if mask2 is not None:
            assert torch.equal(mask1, mask2)
        assert torch.equal(mask1, mask3.bool())
        print('passed')

    return mask1


# mask_normal(batch_size, seq_len)

## Post Attention

In [12]:
def attention_post_normal(
    attn_weights: Float[Tensor, "batch num_heads seq_len seq_len"],
    dropout_p: float = 0.1
) -> Float[Tensor, "batch num_heads seq_len seq_len"]:
    """Post-process attention probabilities in three styles and check output shapes.

    impl1 mixes heads with an identity-initialised "talking heads" linear
    layer before dropout; impl2 and impl3 only apply dropout. ``F.dropout`` is
    called with its default ``training=True``, so the outputs are random and
    only their shapes are compared.

    Returns:
        The talking-heads result (impl1), same shape as ``attn_weights``.
    """
    _, num_heads, _, _ = attn_weights.shape

    # x-transformer style: talking heads = a Linear over the head axis
    def impl1(attn):
        mix = nn.Linear(num_heads, num_heads, bias=False)
        mix.weight.data = torch.eye(num_heads)  # start as the identity mapping
        heads_last = rearrange(attn, 'b h i j -> b i j h')
        mixed = rearrange(mix(heads_last), 'b i j h -> b h i j')
        return F.dropout(mixed, p=dropout_p)

    # lit-gpt and nanoGPT styles: dropout only
    def plain_dropout(attn):
        return F.dropout(attn, p=dropout_p)

    impl2 = impl3 = plain_dropout

    out1 = impl1(attn_weights)
    out2 = impl2(attn_weights)
    out3 = impl3(attn_weights)

    tprint('shape', sep='*')
    cprint(out1.shape)
    assert out1.shape == out2.shape == out3.shape
    print('passed')

    return out1


# attention_post_normal(attn_weights)

## KV Cache

In [13]:
def kv_cache_normal(
    k: Float[Tensor, "batch num_heads seq_len head_dim"],
    v: Float[Tensor, "batch num_heads seq_len head_dim"],
    past_key_values: Optional[Tuple[Tensor, Tensor]] = None,
    use_cache: bool = True
) -> Tuple[Tuple[Tensor, Tensor], Optional[Tuple[Tensor, Tensor]]]:
    """Append new keys/values to an optional cache three ways and check they agree.

    Args:
        k, v: new keys/values, (batch, num_heads, seq_len, head_dim).
        past_key_values: optional ``(past_k, past_v)`` in the same layout; the
            new entries are appended after them along the sequence axis (dim 2).
        use_cache: when True the updated ``(k, v)`` are also returned as the cache.

    Returns:
        ``((k_all, v_all), cache)`` from impl1, where ``cache`` is
        ``(k_all, v_all)`` if ``use_cache`` else ``None``.
    """
    _, _, seq_len, _ = k.shape

    # The impls take the new k/v as arguments. Rebinding the enclosing ``k``/``v``
    # inside a nested function would make them local there, and the first
    # read would raise UnboundLocalError.

    def pack(k_all, v_all):
        return (k_all, v_all), ((k_all, v_all) if use_cache else None)

    # x-transformer style: concatenate along the sequence axis
    def impl1(k_new, v_new):
        if past_key_values is not None:
            past_k, past_v = past_key_values
            k_new = torch.cat([past_k, k_new], dim=2)
            v_new = torch.cat([past_v, v_new], dim=2)
        return pack(k_new, v_new)

    # lit-gpt style: write into a preallocated buffer at explicit positions
    def impl2(k_new, v_new):
        if past_key_values is None:
            return pack(k_new, v_new)
        past_k, past_v = past_key_values
        past_len = past_k.shape[2]
        total = past_len + seq_len
        k_buf = past_k.new_empty((*past_k.shape[:2], total, *past_k.shape[3:]))
        v_buf = past_v.new_empty((*past_v.shape[:2], total, *past_v.shape[3:]))
        k_buf[:, :, :past_len] = past_k
        v_buf[:, :, :past_len] = past_v
        input_pos = torch.arange(past_len, total, device=k_buf.device)
        k_buf.index_copy_(2, input_pos, k_new)
        v_buf.index_copy_(2, input_pos, v_new)
        return pack(k_buf, v_buf)

    # nanoGPT style: plain concatenation as well
    def impl3(k_new, v_new):
        if past_key_values is not None:
            past_k, past_v = past_key_values
            k_new = torch.cat((past_k, k_new), dim=2)
            v_new = torch.cat((past_v, v_new), dim=2)
        return pack(k_new, v_new)

    (k1, v1), cache1 = impl1(k, v)
    (k2, v2), cache2 = impl2(k, v)
    (k3, v3), cache3 = impl3(k, v)

    tprint('shape', sep='*')
    cprint(k1.shape, v1.shape)
    assert k1.shape == k2.shape == k3.shape
    assert v1.shape == v2.shape == v3.shape
    print('passed')

    tprint('allclose', sep='*')
    for other_k, other_v in ((k2, v2), (k3, v3)):
        assert torch.allclose(k1, other_k, rtol=1e-4)
        assert torch.allclose(v1, other_v, rtol=1e-4)
    if use_cache:
        for other_cache in (cache2, cache3):
            assert all(torch.allclose(c1, c, rtol=1e-4)
                       for c1, c in zip(cache1, other_cache))
    print('passed')

    return (k1, v1), cache1


# kv_cache_normal(k, v)

# Put Everything Together

In [14]:
class ModularAttention(nn.Module):
    """Configurable multi-head self-attention (scores, softmax, dropout; no output projection).

    Args:
        dim: model width ``C``; must be divisible by ``num_heads``.
        num_heads: number of attention heads ``H``.
        projection_type: the Q/K/V projection layout.
            ``"unified"``: one ``Linear(dim, 3*dim)``.
            ``"separated"``: ``Linear(dim, dim)`` for Q plus ``Linear(dim, 2*dim)``
            for K/V.
            Any other value: three individual ``Linear(dim, dim)`` layers.
        qkv_bias: whether the projection layers have a bias.
        attn_dropout: dropout probability applied to the attention weights.
        scaling_type: ``"learned"`` makes the ``1/sqrt(head_dim)`` score scale a
            trainable ``nn.Parameter``; any other value keeps it a fixed float.

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        projection_type: str = "unified",
        qkv_bias: bool = False,
        attn_dropout: float = 0.0,
        scaling_type: str = "default"  # or "learned" or "fixed"
    ):
        super().__init__()
        if dim % num_heads:
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        # QKV projection configuration
        if projection_type == "unified":
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        elif projection_type == "separated":
            self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
            self.to_kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        else:  # individual (also the fallback for unrecognised values)
            self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
            self.to_k = nn.Linear(dim, dim, bias=qkv_bias)
            self.to_v = nn.Linear(dim, dim, bias=qkv_bias)

        # Scaling configuration
        self.scaling_type = scaling_type
        if scaling_type == "learned":
            self.scale = nn.Parameter(torch.ones(1) / math.sqrt(self.head_dim))
        else:
            self.scale = 1.0 / math.sqrt(self.head_dim)

        self.attn_dropout = nn.Dropout(attn_dropout)

    def _project_qkv(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project ``x`` (B, N, C) to q, k, v, each of shape (B, H, N, head_dim)."""
        if hasattr(self, 'qkv'):
            q, k, v = self.qkv(x).chunk(3, dim=-1)
        elif hasattr(self, 'to_kv'):
            q = self.to_q(x)
            k, v = self.to_kv(x).chunk(2, dim=-1)
        else:
            q = self.to_q(x)
            k = self.to_k(x)
            v = self.to_v(x)

        batch = x.shape[0]
        # (B, N, C) -> (B, N, H, head_dim) -> (B, H, N, head_dim)
        return tuple(
            t.view(batch, -1, self.num_heads, self.head_dim).transpose(1, 2)
            for t in (q, k, v)
        )

    def _apply_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(output, attn_weights)`` for per-head q, k, v.

        ``mask`` uses "nonzero/True = may attend". A 3-D (B, N, N) mask is
        broadcast over heads; 2-D (N, N) and 4-D masks are used as given.
        """
        # A plain float and an nn.Parameter both work as a multiplier.
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        if mask is not None:
            if mask.dim() == 3:
                # (B, N, N) -> (B, 1, N, N): line up with batch, not with heads
                mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = self.attn_dropout(F.softmax(scores, dim=-1))
        output = torch.matmul(attn_weights, v)
        return output, attn_weights

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(out, attn_weights)``: out is (B, N, C), weights (B, H, N, N)."""
        B, N, C = x.shape

        q, k, v = self._project_qkv(x)
        out, weights = self._apply_attention(q, k, v, mask)

        # (B, H, N, head_dim) -> (B, N, H, head_dim) -> (B, N, C)
        out = out.transpose(1, 2).contiguous().view(B, N, C)
        return out, weights