# In [1]:
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F

# In [4]:
@dataclass
class GPTconfig:
    """Size hyperparameters consumed by the model classes in this file.

    Field order is part of the positional constructor signature.
    """

    # Number of rows in the position-embedding table.
    block_size: int = 1024
    # Number of rows in the token-embedding table and width of the LM head output.
    vocab_size: int = 50257
    # How many Block modules are stacked inside GPT.
    n_layer: int = 12
    # Attention heads per CausalSelfAttention; must divide n_embd.
    n_head: int = 12
    # Embedding / channel width shared by every layer.
    n_embd: int = 768

# In [3]:
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention using one fused query/key/value projection."""

    def __init__(self, config):
        """Build the projections from ``config.n_embd`` and ``config.n_head``.

        ``config.n_embd`` must be a multiple of ``config.n_head`` so the
        channels can be divided evenly between the heads.
        """
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # A single linear layer produces queries, keys and values for all heads.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # Mixes the concatenated head outputs back into the embedding width.
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        # Marker attribute on the output projection; nothing in this class reads it.
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        """Apply causally masked attention to ``x`` of shape (B, T, C).

        B is the batch size, T the sequence length and C the embedding width
        (C = n_head * head size). The result has the same (B, T, C) shape.
        """
        batch, seq_len, channels = x.size()
        head_size = channels // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, nh, T, hs): heads act as an extra batch dimension.
            return t.view(batch, seq_len, self.n_head, head_size).transpose(1, 2)

        q, k, v = (to_heads(part) for part in self.c_attn(x).split(self.n_embd, dim=2))
        # Fused attention kernel; is_causal=True hides future positions.
        attended = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # (B, nh, T, hs) -> (B, T, C): lay the head outputs side by side again.
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, channels)
        return self.c_proj(merged)
    
class MLP(nn.Module):
    """Feed-forward sub-layer: widen to 4 * n_embd, apply tanh-GELU, project back."""

    def __init__(self, config):
        """Create the two linear layers sized from ``config.n_embd``."""
        # nn.Module has to be initialised before any submodule is assigned,
        # otherwise the assignment raises AttributeError.
        super().__init__()
        # Each attribute must be the module itself (no trailing comma): a
        # 1-tuple would be neither registered as a submodule nor callable.
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)

    def forward(self, x):
        """Return ``c_proj(gelu(c_fc(x)))``; the last dimension stays ``n_embd``."""
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x

class Block(nn.Module):
    """Transformer block: LayerNorm -> attention and LayerNorm -> MLP, each added back residually."""

    def __init__(self, config):
        """Create the two LayerNorms, the attention layer and the MLP from ``config``."""
        super().__init__()
        # Each attribute must be the module itself (no trailing comma): a
        # 1-tuple would not be registered as a submodule and calling it in
        # forward() would raise "'tuple' object is not callable".
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates (shape unchanged)."""
        # Normalisation is applied to the sub-layer input only; the residual
        # path carries the un-normalised ``x``.
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

class GPT(nn.Module):
    """Top-level model container: embeddings, a stack of ``Block`` layers, final LayerNorm and LM head."""

    def __init__(self, config):
        """Build every layer from ``config``.

        Reads ``vocab_size``, ``block_size``, ``n_layer`` and ``n_embd`` from
        ``config`` and keeps the object itself on ``self.config``.
        """
        super().__init__()
        self.config = config

        # The builtin ``dict`` collects the named layers; ModuleDict then
        # registers each one as a submodule. Both embedding widths come from
        # ``config.n_embd``.
        # NOTE(review): the position table is keyed ``wpt``; Hugging Face GPT-2
        # checkpoints call it ``wpe`` -- confirm before loading pretrained weights.
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpt=nn.Embedding(config.block_size, config.n_embd),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd),
        ))
        # Maps the final hidden state to one logit per vocabulary entry (no bias).
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)