<a href="https://colab.research.google.com/github/JayThibs/transformers-from-scratch/blob/main/pytorch_lightning_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Transformers with PyTorch Lightning

References:

* [Simple PyTorch Transformer Example with Greedy Decoding](https://colab.research.google.com/drive/1swXWW5sOLW8zSZBaQBYcGQkQ_Bje_bmI) by Sergey Karayev from Full Stack Deep Learning
* [The Annotated Transformer ++](https://github.com/gordicaleksa/pytorch-original-transformer/blob/main/The%20Annotated%20Transformer%20%2B%2B.ipynb) by gordicaleksa / The AI Epiphany
* [Transformers from Scratch](https://e2eml.school/transformers.html) by End-to-End ML School
* [Notes on GPT-2 and BERT models](https://www.kaggle.com/residentmario/notes-on-gpt-2-and-bert-models) by Aleksey Bilogur
* [GPT-3: Language Models are Few-Shot Learners (Paper Explained)](https://www.youtube.com/watch?v=SY5PvZrJhLE) by Yannic Kilcher
* [Various Annotated Transformer PyTorch Papers](https://nn.labml.ai/transformers/index.html) by labml.ai

This notebook provides a simple, self-contained example of a Transformer model that:

* uses both the encoder and decoder parts
* performs greedy decoding at inference time
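To make "greedy decoding" concrete before any model exists, here is a minimal pure-Python sketch. The scoring function `toy_scores` is a hypothetical stand-in for a trained model, and the start/end tokens 0 and 1 mirror the convention used later in this notebook.

```python
from typing import Callable, List

START, END = 0, 1  # special tokens, same convention as later in this notebook


def greedy_decode(next_token_scores: Callable[[List[int]], List[float]],
                  max_len: int) -> List[int]:
    """At every step, append the single highest-scoring next token."""
    tokens = [START]
    while len(tokens) < max_len:
        scores = next_token_scores(tokens)
        best = max(range(len(scores)), key=scores.__getitem__)  # argmax
        tokens.append(best)
        if best == END:
            break
    return tokens


def toy_scores(prefix: List[int]) -> List[float]:
    """Hypothetical 'model' over 4 tokens: prefers 2, then 3, then END."""
    plan = [2, 3, END]
    target = plan[min(len(prefix) - 1, len(plan) - 1)]
    return [1.0 if t == target else 0.0 for t in range(4)]


decoded = greedy_decode(toy_scores, max_len=10)
print(decoded)  # [0, 2, 3, 1]
```

Greedy decoding never revisits an earlier choice, which keeps it simple but means it can miss a sequence that would have scored higher overall.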

For the first part of the notebook, we'll train on a simple synthetic example, and use PyTorch Lightning since it will greatly simplify the training loop.

When the first transformer paper came out (Attention Is All You Need), the authors used the transformer architecture for machine translation. This means that they needed both the encoder and decoder parts of the architecture to first encode the text, and then decode (generate) the translation.

After that paper, researchers realized that they could use the encoder and decoder separately in order to create models for approaching different tasks. This led to the emergence of BERT-like models (encoder / non-autoregressive) and GPT-like models (decoder / autoregressive).

In this notebook, however, we'll go over every part of the full encoder-decoder transformer.

Note: Autoregressive means that the model only takes into account the text (context) that came before the current prediction, and each new prediction becomes part of the context for the next one. Non-autoregressive models take in the entire surrounding context. A model like BERT is bidirectional (that's what the B stands for): when predicting a masked word, it uses the context on both sides of that word. As a result, GPT-like models are well suited to generating text, while BERT-like models are well suited to taking in an entire piece of text and classifying it.
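One way to picture the difference is through the attention mask. The sketch below is illustrative only (the sequence length of 4 and the all-zero scores are arbitrary assumptions): a bidirectional model lets every position attend everywhere, while an autoregressive model restricts position `i` to positions `j <= i`.

```python
import torch

seq_len = 4  # arbitrary toy length

# Bidirectional (BERT-like): every position may attend to every position.
bidirectional = torch.ones(seq_len, seq_len).bool()

# Causal (GPT-like): position i may attend only to positions j <= i.
causal = torch.tril(torch.ones(seq_len, seq_len)).bool()
print(causal.int())

# Effect on attention weights: disallowed positions get -inf before softmax,
# so they receive exactly zero weight.
scores = torch.zeros(seq_len, seq_len)  # uniform scores, for illustration
weights = torch.softmax(scores.masked_fill(~causal, float("-inf")), dim=-1)
print(weights)
```

With uniform scores, row `i` of `weights` spreads its attention evenly over the first `i + 1` positions and gives nothing to later ones.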

# Installations

In [None]:
!pip install pytorch_lightning spacy --quiet



# Imports

In [7]:
# Python native libs
import math
import copy
import os
import time
import enum
import argparse

# Visualization imports
import matplotlib.pyplot as plt
import seaborn


# Deep learning imports
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from torch.hub import download_url_to_file

# Data manipulation
import numpy as np
# from torchtext.data import Dataset, BucketIterator, Field, Example
from torchtext.data.utils import interleave_keys
from torchtext import datasets
# from torchtext.data import Example
import spacy

# BLEU
from nltk.translate.bleu_score import corpus_bleu

# Data

Since this notebook is focused on understanding the transformer architecture in code, we'll be generating simple input and output data for training a model.

Input: An array of values where each element is repeated twice, e.g. [1, 1, 5, 5, 3, 3]

Output: Same as input, but the duplicates are removed, e.g. [1, 5, 3]
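As a quick sanity check of this input/output relationship, the sketch below builds the input from the example target with `torch.repeat_interleave` and recovers the target by taking every second element. The variable names are illustrative only.

```python
import torch

y_example = torch.tensor([1, 5, 3])               # desired output
x_example = torch.repeat_interleave(y_example, 2)  # each element repeated twice
recovered = x_example[::2]                          # drop the duplicates again

print(x_example.tolist())   # [1, 1, 5, 5, 3, 3]
print(recovered.tolist())   # [1, 5, 3]
```

The task is trivial to solve with slicing, which is the point: it gives the model something easy to learn, so the focus stays on the architecture.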

In [None]:
N = 10000
S = 32 # target/output sequence length. The input will be twice as long.
C = 128 # number of "classes", including 0, the "start token", and 1, the "end token"

Y = (torch.rand((N * 10, S - 2)) * (C - 2)).long() + 2 # Only generate ints in the [2, C - 1] range, i.e. [2, 127]

# Make sure we only have unique rows
Y = torch.tensor(np.unique(Y, axis=0)[:N])
X = torch.repeat_interleave(Y, 2, dim=1)

# Add special 0 "start" and 1 "end" tokens to beginning and end
Y = torch.cat([torch.zeros((N, 1)), Y, torch.ones((N, 1))], dim=1).long()
X = torch.cat([torch.zeros((N, 1)), X, torch.ones((N, 1))], dim=1).long()

# Look at the data
print(X, X.shape)
print(Y, Y.shape)
print(Y.min(), Y.max())

tensor([[  0,   2,   2,  ...,  16,  16,   1],
        [  0,   2,   2,  ..., 102, 102,   1],
        [  0,   2,   2,  ...,  49,  49,   1],
        ...,
        [  0,  14,  14,  ...,  29,  29,   1],
        [  0,  14,  14,  ...,   8,   8,   1],
        [  0,  14,  14,  ...,  46,  46,   1]]) torch.Size([10000, 62])
tensor([[  0,   2,   2,  ...,  25,  16,   1],
        [  0,   2,   2,  ...,  56, 102,   1],
        [  0,   2,   2,  ..., 117,  49,   1],
        ...,
        [  0,  14,  70,  ...,  51,  29,   1],
        [  0,  14,  70,  ..., 114,   8,   1],
        [  0,  14,  70,  ..., 126,  46,   1]]) torch.Size([10000, 32])
tensor(0) tensor(127)


In [None]:
# Wrap data in the simplest possible way to enable PyTorch data fetching
# https://pytorch.org/docs/stable/data.html

BATCH_SIZE = 128
TRAIN_FRAC = 0.8

dataset = list(zip(X, Y)) # A list of (x, y) pairs has __getitem__ and __len__, which is all a map-style torch.utils.data.Dataset needs

# Split into train and val
num_train = int(N * TRAIN_FRAC)
num_val = N - num_train
data_train, data_val = torch.utils.data.random_split(dataset, (num_train, num_val))

dataloader_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE)
dataloader_val = torch.utils.data.DataLoader(data_val, batch_size=BATCH_SIZE)

# Sample a single (x, y) pair: iterating data_train yields individual examples, not batches
x, y = next(iter(data_train))
x, y

(tensor([  0,  10,  10,  96,  96, 122, 122,  81,  81,  64,  64,  75,  75,  26,
          26,  61,  61,  23,  23,  65,  65,  40,  40, 100, 100,  24,  24, 120,
         120,   7,   7, 110, 110, 107, 107, 118, 118, 126, 126,  28,  28,  91,
          91, 119, 119,  67,  67,  29,  29,  45,  45,  74,  74,  62,  62, 114,
         114,  28,  28,  28,  28,   1]),
 tensor([  0,  10,  96, 122,  81,  64,  75,  26,  61,  23,  65,  40, 100,  24,
         120,   7, 110, 107, 118, 126,  28,  91, 119,  67,  29,  45,  74,  62,
         114,  28,  28,   1]))
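The difference between indexing the dataset and iterating a `DataLoader` can be seen in a small sketch. The tensors `X_demo` and `Y_demo` are made-up stand-ins for `X` and `Y`; the default collate function stacks the individual examples along a new batch dimension.

```python
import torch
from torch.utils.data import DataLoader

# Made-up stand-ins for X and Y: 6 examples each
X_demo = torch.arange(12).reshape(6, 2)
Y_demo = torch.arange(6).reshape(6, 1)

pairs = list(zip(X_demo, Y_demo))   # a plain list works as a map-style dataset
xs, ys = pairs[0]                   # one example: 1-D tensors
print(xs.shape, ys.shape)           # torch.Size([2]) torch.Size([1])

loader = DataLoader(pairs, batch_size=4)
xb, yb = next(iter(loader))         # one batch: examples stacked on dim 0
print(xb.shape, yb.shape)           # torch.Size([4, 2]) torch.Size([4, 1])
```

A model that expects `(batch, sequence)` inputs therefore needs tensors taken from the loader rather than from the dataset directly.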

# Model

In [None]:
class PositionalEncoding(nn.Module):
    """
    Classic Attention-is-all-you-need positional encoding.
    From PyTorch docs.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) # gives us the ordered position of words
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


def generate_square_subsequence_mask(size: int):
    """Generate a triangular (size, size) mask. From PyTorch docs."""
    mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


class Transformer(nn.Module):
    """
    Classic Transformer that both encodes and decodes.

    Prediction-time inference is done greedily.

    NOTE: Start token is hard-coded to be 0, end token to be 1. If changing, update predict() accordingly.
    """

    def __init__(self, num_classes: int, max_output_length:int, dim: int = 128):
        super().__init__()

        # Parameters
        self.dim = dim
        self.max_output_length = max_output_length
        nhead = 4
        num_layers = 4
        dim_feedforward = dim

        # Encoder part
        self.embedding = nn.Embedding(num_classes, dim)
        self.pos_encoder = PositionalEncoding(d_model=self.dim)
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer = nn.TransformerEncoderLayer(d_model=self.dim, nhead=nhead, dim_feedforward=dim_feedforward),
            num_layers=num_layers
        )

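        # Decoder part
        self.y_mask = generate_square_subsequence_mask(self.max_output_length)
        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer=nn.TransformerDecoderLayer(d_model=self.dim, nhead=nhead, dim_feedforward=dim_feedforward),
            num_layers=num_layers
        )
        self.fc = nn.Linear(self.dim, num_classes)

        # It is empirically important to initialize weights properly
        self.init_weights()

    # NOTE: the method bodies below are a sketch filled in along the lines of the
    # Karayev reference listed at the top; treat them as an assumption, not a final design.
    def init_weights(self):
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        self.fc.weight.data.uniform_(-initrange, initrange)

    def forward(self, x, y):
        """x: (B, Sx) and y: (B, Sy) token ids -> logits of shape (B, C, Sy)."""
        encoded_x = self.encode(x)          # (Sx, B, E)
        output = self.decode(y, encoded_x)  # (Sy, B, C)
        return output.permute(1, 2, 0)      # (B, C, Sy)

    def encode(self, x):
        x = x.permute(1, 0)                          # (Sx, B), sequence-first layout
        x = self.embedding(x) * math.sqrt(self.dim)  # (Sx, B, E)
        x = self.pos_encoder(x)                      # (Sx, B, E)
        return self.transformer_encoder(x)           # (Sx, B, E)

    def decode(self, y, encoded_x):
        y = y.permute(1, 0)                          # (Sy, B)
        y = self.embedding(y) * math.sqrt(self.dim)  # (Sy, B, E)
        y = self.pos_encoder(y)                      # (Sy, B, E)
        Sy = y.shape[0]
        y_mask = self.y_mask[:Sy, :Sy].type_as(encoded_x)        # causal mask, (Sy, Sy)
        output = self.transformer_decoder(y, encoded_x, y_mask)  # (Sy, B, E)
        return self.fc(output)                       # (Sy, B, C)

    def predict(self, x):
        """Greedy decoding: feed the tokens so far, keep the argmax of the last step."""
        encoded_x = self.encode(x)
        # Pre-fill with the end token (1), then set the start token (0)
        output_tokens = torch.ones((x.shape[0], self.max_output_length)).type_as(x).long()
        output_tokens[:, 0] = 0
        for Sy in range(1, self.max_output_length):
            y = output_tokens[:, :Sy]              # (B, Sy)
            output = self.decode(y, encoded_x)     # (Sy, B, C)
            output = torch.argmax(output, dim=-1)  # (Sy, B)
            output_tokens[:, Sy] = output[-1]      # keep only the newest token
        return output_tokens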


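# Fetch a full batch: the model expects batched (B, S) tensors, not a single example
x, y = next(iter(dataloader_train))

model = Transformer(num_classes=C, max_output_length=y.shape[1])
logits = model(x, y[:, :-1])  # the decoder sees every target token except the last
print(x.shape, y.shape, logits.shape)
print(x[0:1])
print(model.predict(x[0:1]))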
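The call `model(x, y[:, :-1])` feeds the decoder the target sequence without its last token. A common way to train with this setup (often called teacher forcing) is to compare the resulting logits against the target shifted by one position, `y[:, 1:]`. The sketch below illustrates only the shapes involved; the sizes `B`, `S_demo`, and `C_demo` are hypothetical, and random logits shaped `(batch, classes, length)` stand in for a real model output, on the assumption that the model returns logits in that layout.

```python
import torch
import torch.nn.functional as F

B, S_demo, C_demo = 2, 5, 8  # hypothetical batch size, target length, class count

# Toy targets: start token 0, random tokens in [2, C_demo - 1], end token 1
y_demo = torch.randint(2, C_demo, (B, S_demo))
y_demo[:, 0] = 0
y_demo[:, -1] = 1

decoder_input = y_demo[:, :-1]  # (B, S-1): start token ... second-to-last token
target = y_demo[:, 1:]          # (B, S-1): first real token ... end token

# Stand-in for model(x, decoder_input); assumed layout is (B, C, S-1)
logits_demo = torch.randn(B, C_demo, S_demo - 1)

# F.cross_entropy accepts (B, C, L) logits with (B, L) integer targets
loss = F.cross_entropy(logits_demo, target)
print(decoder_input.shape, target.shape, loss.item())
```

Position `t` of the decoder input is paired with position `t` of the target, so the model at each step is asked to predict the token that comes next.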