<a href="https://colab.research.google.com/gist/JordanLazzaro/81cf023d5d5478a5958cf885c8891504/transformersinanutshell.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Transformers in a Nutshell
An educational but usable example of a (character-level) GPT-2 transformer language model.

Feel free to play around with the dataset and hyperparameters; the current ones were chosen for both ease of training and understandability.

In [None]:
!pip install -q wget wandb pytorch-lightning

import os
import sys
import wget
from tqdm import tqdm

# for logging metrics to wandb
import wandb
wandb.login()

# for dataset
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader

# for model
import math
import torch
from torch import nn
import torch.nn.functional as F
from torchmetrics.functional import accuracy
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

In [2]:
class GPT2Config:
    """
    'gpt2-mini' config from minGPT.

    Class attributes hold the defaults; keyword arguments passed to the
    constructor override them on the instance (unknown names are added too).
    ``fc_hidden_dim`` defaults to ``4 * emb_size`` and follows an overridden
    ``emb_size`` unless ``fc_hidden_dim`` is also given explicitly.
    """
    # data
    default_data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    
    # model
    vocab_size = None  # unknown until the dataset is loaded; set before building the model
    max_seq_len = 128
    emb_size = 192
    num_blocks = 6
    num_heads = 6
    fc_hidden_dim = 4 * emb_size
    
    # regularization
    attn_dropout_p = 0.1
    res_dropout_p = 0.1
    emb_dropout_p = 0.1
    
    # training
    max_learning_rate = 2.5e-4
    batch_size = 512
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def __init__(self, **kwargs):
        """ any extra config args, stored as instance attributes """
        for k, v in kwargs.items():
            setattr(self, k, v)

        # The class-level fc_hidden_dim was computed from the *default*
        # emb_size, so re-derive it when only emb_size is overridden.
        if 'emb_size' in kwargs and 'fc_hidden_dim' not in kwargs:
            self.fc_hidden_dim = 4 * self.emb_size

In [3]:
class CharDataset(Dataset):
    """
    A toy character-level dataset for charGPT, modified from the minGPT repo.

    Item ``idx`` is a pair ``(x, y)`` of LongTensors of length
    ``config.max_seq_len``: ``x`` encodes ``data[idx : idx + max_seq_len]`` and
    ``y`` is the same window shifted one character right (the next-character
    target for every position of ``x``).
    """

    def __init__(self, config, data=None):
        """
        Args:
            config: object providing ``max_seq_len`` and, when ``data`` is
                None, ``default_data_url``.
            data: the raw training text. If None, it is downloaded from
                ``config.default_data_url`` and read as UTF-8.
        """
        self.config = config
        if data is None:
            filename = wget.download(config.default_data_url)
            # close the file deterministically and decode explicitly instead of
            # relying on the platform's default encoding
            with open(filename, 'r', encoding='utf-8') as f:
                data = f.read()

        # the vocabulary is every distinct character, in sorted order
        chars = sorted(set(data))
        data_size, vocab_size = len(data), len(chars)
        print('data has %d characters, %d unique.' % (data_size, vocab_size))

        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for i, ch in enumerate(chars)}
        self.vocab_size = vocab_size
        self.data = data

    def __len__(self):
        # one sample per start index that leaves room for max_seq_len + 1
        # characters; clamp at 0 so text shorter than a window gives an empty
        # dataset rather than a negative length (which len() rejects)
        return max(0, len(self.data) - self.config.max_seq_len)

    def __getitem__(self, idx):
        # grab a chunk of (max_seq_len + 1) characters from the data
        chunk = self.data[idx:idx + self.config.max_seq_len + 1]
        # encode every character to an integer
        dix = [self.stoi[s] for s in chunk]
        # inputs are all but the last char, targets all but the first
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)

        return x, y

In [4]:
class CausalMultiHeadAttention(nn.Module):
    """
    Masked (causal) multi-head self-attention.

    Position i may only attend to positions j <= i. The input ``x`` of size
    (batch, seq_len, emb_size) is mapped to an output of the same size, with
    seq_len limited to ``config.max_seq_len`` by the size of the mask buffer.
    """

    def __init__(self, config):
        super().__init__()
        assert config.emb_size % config.num_heads == 0

        d_model = config.emb_size
        # query / key / value projections for all heads at once, followed by the
        # output projection back into the residual stream (all bias-free)
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.res_proj = nn.Linear(d_model, d_model, bias=False)

        self.attn_dropout = nn.Dropout(config.attn_dropout_p)
        self.res_dropout = nn.Dropout(config.res_dropout_p)

        # lower-triangular matrix of ones: entry (i, j) is 1 exactly when j <= i
        causal = torch.ones(config.max_seq_len, config.max_seq_len).tril()
        self.register_buffer('mask', causal)

        self.num_heads = config.num_heads

    def _split_heads(self, t, batch_size, seq_len, head_dim):
        """Reshape (b_s, s_l, e_s) into per-head layout (b_s, n_h, s_l, h_d)."""
        return t.reshape(batch_size, seq_len, self.num_heads, head_dim).transpose(1, 2)

    def forward(self, x):
        batch_size, seq_len, emb_size = x.size()
        head_dim = emb_size // self.num_heads

        q, k, v = (
            self._split_heads(proj(x), batch_size, seq_len, head_dim)
            for proj in (self.W_Q, self.W_K, self.W_V)
        )

        # scaled dot-product scores: (b_s, n_h, s_l, s_l)
        scale = 1.0 / math.sqrt(head_dim)
        scores = q @ k.transpose(-2, -1) * scale

        # hide every key that lies ahead of the query's position
        future = self.mask[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(future, float('-inf'))

        # each query row becomes a probability distribution over the keys
        weights = self.attn_dropout(F.softmax(scores, dim=-1))

        # weighted sum of values, then merge heads: (b_s, n_h, s_l, h_d) -> (b_s, s_l, e_s)
        merged = (weights @ v).transpose(1, 2).reshape(batch_size, seq_len, emb_size)

        return self.res_dropout(self.res_proj(merged))

In [5]:
class MLP(nn.Module):
    """
    Position-wise feed-forward sublayer.

    Expands each position from ``emb_size`` to ``fc_hidden_dim``, applies
    GELU, projects back to ``emb_size``, and finally applies dropout.
    """

    def __init__(self, config):
        super().__init__()
        d_model, d_hidden = config.emb_size, config.fc_hidden_dim
        self.hidden = nn.Linear(d_model, d_hidden)
        self.gelu = nn.GELU()
        # named res_proj so GPT2's init recognizes it as a residual projection
        self.res_proj = nn.Linear(d_hidden, d_model)
        self.res_dropout = nn.Dropout(config.res_dropout_p)

    def forward(self, x):
        return self.res_dropout(self.res_proj(self.gelu(self.hidden(x))))

In [6]:
class TransformerBlock(nn.Module):
    """
    One pre-LayerNorm decoder block.

    Each sublayer reads a LayerNorm-ed copy of the residual stream and adds
    its output back onto the un-normalized stream:

        x = x + attn(ln1(x))
        x = x + mlp(ln2(x))
    """

    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.emb_size)
        self.attn = CausalMultiHeadAttention(config)
        self.ln2 = nn.LayerNorm(config.emb_size)
        self.mlp = MLP(config)

    def forward(self, x):
        # Normalize only the sublayer *input*. Reassigning x = ln(x) first
        # would replace the residual stream itself with its normalized
        # version, so the skip connection would no longer carry x unchanged.
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))

        return x

In [7]:
class GPT2(nn.Module):
    """
    Character-level GPT-2 language model.

    ``forward`` maps a (batch, seq_len) tensor of token indices, with
    seq_len <= config.max_seq_len, to next-token logits of size
    (batch, seq_len, config.vocab_size).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # sized from the config (not a notebook-global dataset) so the model
        # can be built anywhere config.vocab_size has been set
        self.tok_emb = nn.Embedding(config.vocab_size, config.emb_size)
        self.pos_emb = nn.Embedding(config.max_seq_len, config.emb_size)
        self.emb_dropout = nn.Dropout(config.emb_dropout_p)

        self.blocks = nn.Sequential(*[TransformerBlock(config) for _ in range(config.num_blocks)])
        self.ln = nn.LayerNorm(config.emb_size)
        self.head = nn.Linear(config.emb_size, config.vocab_size, bias=False)

        # positions 0..max_seq_len-1 used to index pos_emb. A buffer (rather
        # than a frozen Parameter) keeps it out of parameters() and the
        # optimizer, still follows .to(device), and keeps the same
        # 'pos_idxs' state_dict key so existing checkpoints still load.
        self.register_buffer('pos_idxs', torch.arange(0, config.max_seq_len))

        self.apply(self._init_weights)
        self._scale_residual_projections()

    def forward(self, x):
        _, seq_len = x.size()
        if seq_len > self.config.max_seq_len:
            raise ValueError(
                f'sequence length {seq_len} exceeds max_seq_len {self.config.max_seq_len}'
            )

        tok_embs = self.tok_emb(x)
        pos_embs = self.pos_emb(self.pos_idxs[:seq_len])

        seq = self.emb_dropout(tok_embs + pos_embs)
        seq = self.blocks(seq)
        seq = self.ln(seq)
        out = self.head(seq)

        return out

    def _init_weights(self, module):
        """Per-module init used with self.apply: N(0, 0.02) Linear/Embedding
        weights, zero biases, and unit-scale / zero-shift LayerNorm."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def _scale_residual_projections(self):
        """Re-init every res_proj weight with std 0.02 / sqrt(2 * num_blocks).

        Runs once after self.apply, instead of scanning all parameters for
        every module that apply visits.
        """
        res_weights = [p for n, p in self.named_parameters() if n.endswith('res_proj.weight')]
        if not res_weights:
            return
        # two residual connections per block (i.e. attn and mlp)
        std = 0.02 / math.sqrt(2 * self.config.num_blocks)
        for param in res_weights:
            torch.nn.init.normal_(param, mean=0.0, std=std)

In [8]:
class GPT2LitModel(pl.LightningModule):
    """
    Lightning wrapper that trains ``model`` with next-token cross-entropy.

    The wrapped model is stored as ``self.model``, so state_dict keys of
    this wrapper are the model's keys prefixed with ``model.``.
    """

    def __init__(self, model, config):
        super().__init__()
        self.model = model
        self.config = config

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        logits = self.model(inputs)
        # collapse every leading dim so each position is scored as an
        # independent classification over the last (vocabulary) dim
        n_classes = logits.size(-1)
        loss = F.cross_entropy(logits.reshape(-1, n_classes), targets.reshape(-1))
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        # constant learning rate: no scheduler is attached
        return torch.optim.Adam(self.parameters(), lr=self.config.max_learning_rate)

In [None]:
# wiring everything up to start training
config = GPT2Config()
dataset = CharDataset(config)
# the vocabulary is only known once the text has been loaded
config.vocab_size = dataset.vocab_size

train_loader = DataLoader(dataset, num_workers=4, batch_size=config.batch_size, shuffle=True)

model = GPT2(config)
lit_model = GPT2LitModel(model, config)

# use the GPU when config detected one, otherwise fall back to the CPU
# instead of failing on a runtime without CUDA
accelerator = 'gpu' if config.device == 'cuda' else 'cpu'

wandb_logger = WandbLogger()
trainer = pl.Trainer(logger=wandb_logger, accelerator=accelerator, devices=1, max_epochs=10)

# trainer without wandb logging
# trainer = pl.Trainer(accelerator=accelerator, devices=1, max_epochs=10)

In [None]:
# run training; metrics go to the logger the trainer was built with
trainer.fit(model=lit_model, train_dataloaders=train_loader)

In [None]:
# save our trained model so we can use it later
from google.colab import drive
drive.mount('/content/gdrive')
!ls /content/gdrive/My\ Drive

model_save_name = 'shakespeareGPT.pt'
path = f'/content/gdrive/My Drive/{model_save_name}'
torch.save(lit_model.state_dict(), path)

In [None]:
# simple helper function to prompt model and get readable result
@torch.no_grad()
def get_predictions(model, prompt, max_seq_len=128, vocab=None):
    """
    Extend ``prompt`` by sampling one character at a time from ``model``.

    Args:
        model: module mapping a (1, seq_len) LongTensor of character indices
            to logits whose last dimension is the vocabulary.
        prompt: non-empty starting text; every character must be in the vocab.
        max_seq_len: total length (prompt included) of the returned string.
            The model is fed at most ``max_seq_len - 1`` characters, so this
            should not exceed the model's context length + 1.
        vocab: object with ``stoi`` / ``itos`` lookup dicts (e.g. a
            CharDataset). Defaults to the notebook-global ``dataset``.

    Returns:
        The prompt followed by the sampled characters. A prompt already
        ``max_seq_len`` characters or longer is returned unchanged.

    Raises:
        ValueError: if ``prompt`` is empty.
        KeyError: if ``prompt`` contains a character missing from the vocab.
    """
    if not prompt:
        raise ValueError('prompt must contain at least one character')
    if vocab is None:
        vocab = dataset

    # sample with dropout disabled, restoring the caller's mode afterwards
    was_training = model.training
    model.eval()
    try:
        # build the input on the same device as the model's weights
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
        tokens = torch.tensor([[vocab.stoi[ch] for ch in prompt]], dtype=torch.long, device=device)

        while tokens.size(1) < max_seq_len:
            # only the logits at the last position predict the next character
            logits = model(tokens)[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            next_idx = torch.multinomial(probs, num_samples=1)
            tokens = torch.cat((tokens, next_idx), dim=1)
    finally:
        model.train(was_training)

    return ''.join(vocab.itos[int(i)] for i in tokens[0].tolist())

In [None]:
# loading saved model to use for inference
# from google.colab import drive
drive.mount('/content/gdrive')
!ls /content/gdrive/My\ Drive

model_save_name = 'shakespeareGPT.pt'
path = F"/content/gdrive/My Drive/{model_save_name}"
lit_model.load_state_dict(torch.load(path))

In [None]:
# put your prompt here!
prompt = 'Who art thou?'
preds_str = get_predictions(model=lit_model, prompt=prompt)
print(preds_str)