In [9]:
from dataclasses import dataclass
import math

import torch
import torch.nn as nn
from torch.nn import functional as F

In [15]:
@dataclass
class GPTConfig:
    """Hyper-parameters for the GPT model.

    Declared as a dataclass so individual fields can be overridden by keyword,
    e.g. ``GPTConfig(n_layer=6, n_embd=384)``; ``GPTConfig()`` yields the
    defaults below.
    """

    seq: int = 1024        # maximum sequence length (size of the position table)
    vocab: int = 50_000    # number of token ids
    n_layer: int = 12      # number of transformer blocks
    n_head: int = 12       # attention heads per block
    n_embd: int = 768      # model width


In [16]:
class CausalSelfAttention(nn.Module):
    """Multi-head causal (masked) self-attention in the GPT-2 layout.

    Args:
        config: object exposing ``seq`` (maximum sequence length), ``n_embd``
            (model width) and ``n_head`` (number of heads).

    Raises:
        ValueError: if ``n_embd`` is not divisible by ``n_head``.
    """

    def __init__(self, config):
        super().__init__()
        if config.n_embd % config.n_head != 0:
            raise ValueError(
                f"n_embd ({config.n_embd}) must be divisible by n_head ({config.n_head})"
            )
        # Single fused projection producing Q, K and V side by side.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # Output projection mixing the concatenated heads back to n_embd.
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.seq = config.seq
        self.n_embd = config.n_embd
        self.n_head = config.n_head

        # Lower-triangular causal mask, shape [1, 1, seq, seq]; it broadcasts
        # over batch and heads and is sliced to the real length in forward().
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(self.seq, self.seq)).view(1, 1, self.seq, self.seq),
        )

    def forward(self, x):
        """Attend each position to itself and all earlier positions.

        Args:
            x: tensor of shape ``[batch, S, n_embd]`` with ``S <= seq``.

        Returns:
            Tensor of shape ``[batch, S, n_embd]``.

        Raises:
            ValueError: if ``S`` exceeds ``seq`` or the last dim is not ``n_embd``.
        """
        B, S, E = x.size()
        if S > self.seq:
            raise ValueError(f"sequence length {S} exceeds maximum {self.seq}")
        if E != self.n_embd:
            raise ValueError(f"expected last dim {self.n_embd}, got {E}")

        # [batch, S, 3*n_embd] -> three tensors of [batch, S, n_embd]
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        d_k = E // self.n_head
        # [batch, n_head, S, d_k]
        q = q.view(B, S, self.n_head, d_k).transpose(1, 2)
        k = k.view(B, S, self.n_head, d_k).transpose(1, 2)
        v = v.view(B, S, self.n_head, d_k).transpose(1, 2)

        # Scaled dot-product scores, [batch, n_head, S, S]
        att = (q @ k.transpose(-2, -1)) / math.sqrt(d_k)
        # Block attention to future positions using the top-left S x S of the mask.
        att = att.masked_fill(self.mask[:, :, :S, :S] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)

        y = att @ v  # [batch, n_head, S, d_k]
        # Re-merge heads: [batch, S, n_embd]
        y = y.transpose(1, 2).contiguous().view(B, S, E)
        return self.c_proj(y)


class MLP(nn.Module):
    """Position-wise feed-forward sub-layer: widen 4x, apply GELU, project back.

    Args:
        config: object exposing ``n_embd`` (model width).
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        self.c_fc = nn.Linear(width, hidden)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, width)

    def forward(self, x):
        """Transform the last dimension of ``x`` (size ``n_embd``)."""
        return self.c_proj(self.gelu(self.c_fc(x)))

class Block(nn.Module):
    """One pre-LayerNorm transformer layer: attention then MLP, each residual.

    Args:
        config: forwarded to ``CausalSelfAttention`` and ``MLP``; its ``n_embd``
            sizes the two layer norms.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)
        self.ln_1 = nn.LayerNorm(width)
        self.ln_2 = nn.LayerNorm(width)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        attended = x + self.attn(self.ln_1(x))
        return attended + self.mlp(self.ln_2(attended))


In [17]:
class GPT(nn.Module):
    """GPT-2 style decoder-only language model.

    Args:
        config: object exposing ``vocab``, ``seq``, ``n_layer``, ``n_head`` and
            ``n_embd`` (for example a ``GPTConfig``).
    """

    def __init__(self, config):
        super().__init__()
        self.vocab = config.vocab
        self.seq = config.seq

        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab, config.n_embd),
            wpe=nn.Embedding(config.seq, config.n_embd),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x, target=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            x: token ids of shape ``[batch, S]`` with ``S <= seq``.
            target: optional token ids of shape ``[batch, S]``; when omitted
                no loss is computed.

        Returns:
            ``(logits, loss)`` where ``logits`` is ``[batch, S, vocab]`` and
            ``loss`` is a scalar tensor, or ``None`` if ``target`` is ``None``.

        Raises:
            ValueError: if ``S`` exceeds ``seq``.
        """
        _, S = x.size()
        if S > self.seq:
            raise ValueError(f"sequence length {S} exceeds maximum {self.seq}")

        # Position embeddings are indexed by position 0..S-1, not by token id.
        pos = torch.arange(S, dtype=torch.long, device=x.device)
        tok_embd = self.transformer.wte(x)    # [batch, S, n_embd]
        pos_embd = self.transformer.wpe(pos)  # [S, n_embd], broadcast over batch
        h = tok_embd + pos_embd
        for block in self.transformer.h:
            h = block(h)
        h = self.transformer.ln_f(h)
        logits = self.lm_head(h)  # [batch, S, vocab]

        loss = None
        if target is not None:
            # CrossEntropyLoss expects [N, C] inputs and [N] targets.
            loss = self.criterion(logits.reshape(-1, self.vocab), target.reshape(-1))
        return logits, loss

