In [None]:
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

@dataclass
class GPT2Config:
    """Size hyperparameters for the GPT-2 model defined in this notebook.

    The defaults are the values used when the notebook instantiates
    ``GPT2Config()`` without arguments.

    Raises:
        ValueError: if ``n_head`` is not positive or does not evenly divide
            ``n_embd``. Without this check ``head_dim`` would silently round
            down and the per-head reshape in attention would fail later with
            an unrelated-looking shape error.
    """
    vocab_size: int = 50257   # number of rows in the token embedding table
    n_positions: int = 1024   # number of rows in the position embedding table
    n_embd: int = 768         # width of the residual stream
    n_layer: int = 12         # number of transformer blocks
    n_head: int = 12          # attention heads per block

    def __post_init__(self):
        # Fail early and clearly instead of letting a floored head_dim leak out.
        if self.n_head <= 0 or self.n_embd % self.n_head != 0:
            raise ValueError(
                f"n_embd ({self.n_embd}) must be divisible by a positive n_head ({self.n_head})"
            )

    @property
    def head_dim(self) -> int:
        """Per-head feature width, ``n_embd // n_head``."""
        return self.n_embd // self.n_head

# Instantiate the configuration with its default field values and echo the main sizes.
config = GPT2Config()
print(f"Config: {config.n_layer} layers, {config.n_head} heads, {config.n_embd} dim")

In [2]:
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with an optional key/value cache.

    ``c_attn`` produces queries, keys and values in a single projection;
    ``c_proj`` is the output projection applied after the heads are merged.
    """

    def __init__(self, config):
        """Build the projections from ``config.n_embd``, ``config.n_head`` and ``config.head_dim``."""
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.head_dim

        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=True)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=True)
        # Standard scaled dot-product factor, 1 / sqrt(head_dim).
        self.scale = 1.0 / math.sqrt(self.head_dim)

    def forward(self, x, kv_cache=None):
        """Attend over the cached tokens plus the new tokens in ``x``.

        Args:
            x: tensor of shape ``(B, T, n_embd)`` holding the new tokens.
            kv_cache: optional ``(k, v)`` pair from a previous call, each of
                shape ``(B, n_head, past_len, head_dim)``.

        Returns:
            ``(out, new_cache)`` where ``out`` has shape ``(B, T, n_embd)`` and
            ``new_cache`` is the ``(k, v)`` pair extended with the new tokens,
            each of shape ``(B, n_head, past_len + T, head_dim)``.
        """
        B, T, C = x.shape

        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=-1)

        # (B, T, n_embd) -> (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        if kv_cache is not None:
            # Prepend previously computed keys/values along the sequence axis.
            k_cache, v_cache = kv_cache
            k = torch.cat([k_cache, k], dim=2)
            v = torch.cat([v_cache, v], dim=2)

        new_cache = (k, v)

        seq_len = k.shape[2]
        past_len = seq_len - T
        # Query row i sits at absolute position past_len + i, so it may see
        # keys 0..past_len + i. The diagonal offset is what keeps the mask
        # correct when cached tokens are present; without it a single new
        # token (T == 1) would only be allowed to attend to key 0.
        mask = torch.ones(T, seq_len, dtype=torch.bool, device=x.device).tril(diagonal=past_len)

        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn = attn.masked_fill(~mask, float('-inf'))
        attn = F.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)

        # (B, n_head, T, head_dim) -> (B, T, n_embd)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        out = self.c_proj(out)
        return out, new_cache

In [3]:
class MLP(nn.Module):
    """Feed-forward sub-layer: widen to 4x ``n_embd``, apply tanh-approximated GELU, project back."""

    def __init__(self, config):
        """Create the expansion (``c_fc``) and contraction (``c_proj``) linear layers."""
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        self.c_fc = nn.Linear(width, hidden, bias=True)
        self.c_proj = nn.Linear(hidden, width, bias=True)

    def forward(self, x):
        """Return ``c_proj(gelu(c_fc(x)))``; the last dimension of ``x`` is preserved."""
        expanded = self.c_fc(x)
        activated = F.gelu(expanded, approximate='tanh')
        return self.c_proj(activated)

In [4]:
class Block(nn.Module):
    """One transformer layer: LayerNorm -> attention and LayerNorm -> MLP, each added back residually."""

    def __init__(self, config):
        """Create both layer norms, the attention module and the MLP for ``config.n_embd``."""
        super().__init__()
        width = config.n_embd
        self.ln_1 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(width)
        self.mlp = MLP(config)

    def forward(self, x, kv_cache=None):
        """Run the layer on ``x`` and return ``(output, new_cache)``.

        ``kv_cache`` is handed to the attention module unchanged, and the cache
        it returns is passed back to the caller.
        """
        normed_in = self.ln_1(x)
        attn_out, new_cache = self.attn(normed_in, kv_cache)
        after_attn = x + attn_out
        mlp_out = self.mlp(self.ln_2(after_attn))
        return after_attn + mlp_out, new_cache

In [None]:
class GPT2(nn.Module):
    """Decoder-only language model: token + position embeddings, a stack of blocks, final norm, LM head."""

    def __init__(self, config):
        """Build all sub-modules from ``config`` and tie the LM head to the token embedding."""
        super().__init__()
        self.config = config

        width = config.n_embd
        self.wte = nn.Embedding(config.vocab_size, width)
        self.wpe = nn.Embedding(config.n_positions, width)
        blocks = [Block(config) for _ in range(config.n_layer)]
        self.h = nn.ModuleList(blocks)
        self.ln_f = nn.LayerNorm(width)
        self.lm_head = nn.Linear(width, config.vocab_size, bias=False)
        # Output projection and token embedding share one Parameter object.
        self.lm_head.weight = self.wte.weight

    def forward(self, input_ids, kv_caches=None, start_pos=0):
        """Compute next-token logits for ``input_ids``.

        Args:
            input_ids: 2-D tensor of token ids, ``(B, T)``.
            kv_caches: optional per-layer caches from a previous call; ``None``
                means every layer starts without a cache.
            start_pos: position index assigned to the first token of
                ``input_ids``; later tokens get consecutive indices.

        Returns:
            ``(logits, new_caches)`` where ``new_caches`` holds one cache per
            layer, in layer order.
        """
        _, n_tokens = input_ids.shape

        positions = torch.arange(start_pos, start_pos + n_tokens, device=input_ids.device)
        hidden = self.wte(input_ids) + self.wpe(positions)

        layer_caches = kv_caches if kv_caches is not None else [None] * self.config.n_layer

        new_caches = []
        for idx in range(len(self.h)):
            hidden, layer_cache = self.h[idx](hidden, layer_caches[idx])
            new_caches.append(layer_cache)

        logits = self.lm_head(self.ln_f(hidden))
        return logits, new_caches

# Build a model from `config` (no pretrained weights are loaded here) and print its parameter count.
model = GPT2(config)
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
def load_hf_weights(model, model_name="gpt2"):
    """Copy pretrained Hugging Face GPT-2 weights into ``model``.

    Every entry of ``model.state_dict()`` is looked up in the Hugging Face
    checkpoint (``h.*``, ``wte``, ``wpe`` and ``ln_f`` entries live under the
    ``transformer.`` prefix there) and copied over when the shapes agree.

    Args:
        model: module whose state-dict keys follow this notebook's ``GPT2`` layout.
        model_name: checkpoint name passed to ``GPT2LMHeadModel.from_pretrained``.

    Returns:
        The same ``model`` instance, with the weights loaded in place.

    Warns:
        RuntimeWarning: if any parameter could not be loaded (absent from the
            checkpoint or shape mismatch). Such parameters keep their current
            values; previously they were skipped without any notice.
    """
    import warnings
    from transformers import GPT2LMHeadModel

    hf_sd = GPT2LMHeadModel.from_pretrained(model_name).state_dict()

    # Hugging Face stores these projections in Conv1D layers, whose weight is
    # laid out (in_features, out_features); nn.Linear expects the transpose.
    transpose_suffixes = (
        'attn.c_attn.weight', 'attn.c_proj.weight',
        'mlp.c_fc.weight', 'mlp.c_proj.weight',
    )
    prefixed_top_level = {'wte.weight', 'wpe.weight', 'ln_f.weight', 'ln_f.bias'}

    our_sd = model.state_dict()
    skipped = []

    for key in our_sd:
        if key.startswith('h.') or key in prefixed_top_level:
            hf_key = 'transformer.' + key
        else:
            # e.g. 'lm_head.weight', which is stored without the prefix.
            hf_key = key

        if hf_key not in hf_sd:
            skipped.append(f"{key} (no '{hf_key}' in checkpoint)")
            continue

        hf_tensor = hf_sd[hf_key]
        if key.endswith(transpose_suffixes) and hf_tensor.dim() == 2:
            hf_tensor = hf_tensor.T

        if our_sd[key].shape != hf_tensor.shape:
            skipped.append(
                f"{key} (shape {tuple(our_sd[key].shape)} != checkpoint {tuple(hf_tensor.shape)})"
            )
            continue

        our_sd[key] = hf_tensor

    model.load_state_dict(our_sd)

    if skipped:
        # Surface partial loads: a silently skipped tensor leaves that
        # parameter at whatever value it had before this call.
        warnings.warn(
            f"{len(skipped)} parameter(s) were not loaded from {model_name}: " + "; ".join(skipped),
            RuntimeWarning,
            stacklevel=2,
        )

    print(f"Loaded weights from {model_name}")
    return model

# Rebuild the model, fill it with the pretrained "gpt2" checkpoint, and switch to eval mode.
model = GPT2(config)
model = load_hf_weights(model, "gpt2")
model.eval()
print("Model ready!")

In [None]:
def generate(model, prompt, max_new_tokens=50, device='cpu'):
    """Greedily decode a continuation of ``prompt`` using the model's KV cache.

    The first step feeds the whole prompt; every later step feeds only the
    most recently generated token together with the caches returned by the
    previous step.

    Args:
        model: callable as ``model(input_ids, kv_caches=..., start_pos=...)``
            returning ``(logits, kv_caches)``.
        prompt: text to continue.
        max_new_tokens: upper bound on generated tokens; values ``<= 0``
            return the prompt unchanged without running the model.
        device: device the model and inputs are moved to.

    Returns:
        The decoded prompt plus generated tokens. Generation stops early once
        the tokenizer's EOS token is produced (the EOS token is kept in the
        output), including when it is the very first generated token.
    """
    from transformers import GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = model.to(device)
    model.eval()

    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

    with torch.no_grad():
        kv_caches = None
        step_input = input_ids  # full prompt on the first step, one token afterwards
        for _ in range(max_new_tokens):
            # Tokens already covered by the cache determine the position offset.
            start_pos = input_ids.shape[1] - step_input.shape[1]
            logits, kv_caches = model(step_input, kv_caches=kv_caches, start_pos=start_pos)
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token], dim=1)

            if next_token.item() == tokenizer.eos_token_id:
                break
            step_input = next_token

    return tokenizer.decode(input_ids[0])

# Generate up to 30 new tokens for a sample prompt with the loaded model and show the text.
output = generate(model, "The meaning of life is", max_new_tokens=30)
print(output)

In [None]:
# Smoke-test the KV cache: run 5 random tokens, then feed one more token together
# with the returned caches and confirm the cached key tensor grows along the
# sequence axis. Only shapes are printed here; the cached logits are not compared
# against an uncached forward pass.
input_ids = torch.randint(0, config.vocab_size, (1, 5))
logits, caches = model(input_ids)
print(f"First pass - Input: {input_ids.shape}")
print(f"Cache K shape: {caches[0][0].shape}")  # [1, 12, 5, 64]

# start_pos=5 because five positions are already held in the caches.
next_token = torch.randint(0, config.vocab_size, (1, 1))
logits, caches = model(next_token, kv_caches=caches, start_pos=5)
print(f"Cached pass - Input: {next_token.shape}")
print(f"Cache K shape: {caches[0][0].shape}")  # [1, 12, 6, 64] - grew by 1!