# Lab 03: Transformers and Paragraphs

### What You Will Learn

- The fundamental reasons why the Transformer is such
a powerful and popular architecture
- Core intuitions for the behavior of Transformer architectures
- How to use a convolutional encoder and a Transformer decoder to recognize
entire paragraphs of text

## Setup

In [1]:
# One-time environment setup. It is skipped on later runs once `bootstrap.run`
# has been set to False at the end of this branch.
if "bootstrap" not in locals() or bootstrap.run:
    # path management for Python: ensure the current directory "." is on PYTHONPATH
    pythonpath, = !echo $PYTHONPATH
    if "." not in pythonpath.split(":"):
        pythonpath = ".:" + pythonpath
        %env PYTHONPATH={pythonpath}
        !echo $PYTHONPATH

    # get both Colab and local notebooks into the same state
    # NOTE(review): this downloads a script over the network and then imports (executes) it
    !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py
    import bootstrap

    # change into the lab directory
    # NOTE(review): `lab_idx` is not defined in any visible cell, so this call stays disabled
    # bootstrap.change_to_lab_dir(lab_idx=lab_idx)

    # allow "hot-reloading" of modules
    %load_ext autoreload
    %autoreload 2
    # needed for inline plots in some contexts
    %matplotlib inline

    bootstrap.run = False  # set this to True to re-run the setup above
    
# sanity check: show the working directory and its contents
!pwd
%ls

^C
/home/amazingguni/git/fsdl-text-recognizer-2022-labs/notebooks
[0m[01;34m__pycache__[0m/  lab01_pytorch.ipynb     lab02b_cnn.ipynb
bootstrap.py  lab02a_lightning.ipynb  lab03_transformers.ipynb


In [None]:
from IPython import display

# All figures for this lab are served from the FSDL public-assets S3 bucket.
base_url = "https://fsdl-public-assets.s3.us-west-2.amazonaws.com"

# presumably Figure 1 of "Attention Is All You Need" (the Transformer architecture diagram)
display.Image(url=f"{base_url}/aiayn-figure-1.png")

In [None]:
from text_recognizer.models import ResnetTransformer


# `??` displays the source code of ResnetTransformer.forward
ResnetTransformer.forward??

In [None]:
from text_recognizer.lit_models import TransformerLitModel

# `??` displays the source code of TransformerLitModel.training_step
TransformerLitModel.training_step??

In [None]:
# `??` displays the source code of TransformerLitModel.teacher_forward
TransformerLitModel.teacher_forward??

### Intuition #1: Transformers are highly residual.

In [None]:
display.Image(url=f"{base_url}/transformer-residual-view.png")  # residual-stream view of a Transformer

### Intuition #2: Transformer heads learn low-rank transformations.

In [None]:
# Simplified attention formula: the 1/sqrt(d_k) scaling from the original Transformer
# paper is left out, which is enough for the low-rank discussion in this section.
display.Latex(r"$\text{softmax}(Q \cdot K^T) \cdot V$")

In [None]:
import matplotlib.pyplot as plt
import torch


# An outer product of two vectors has rank 1; a dense Gaussian matrix is (almost surely) full rank.
low_rank = torch.randn(100, 1) @ torch.randn(1, 100)
full_rank = torch.randn(100, 100)

plt.figure()
plt.title("rank 1/100 matrix")
plt.imshow(low_rank, cmap="Greys")
plt.axis("off")

plt.figure()
plt.title("rank 100/100 matrix")
plt.imshow(full_rank, cmap="Greys")
plt.axis("off");

### Residuality and low rank together make Transformers less like a sequence model and more like a computer (that we can take gradients through).

In [None]:
display.Image(url=f"{base_url}/transformer-layer-residual.png")  # each layer reads from / writes to the residual stream

In [None]:
display.Image(url=f"{base_url}/residual-stream-read-write.png")  # read/write view of the residual stream

In [None]:
display.Image(url=f"{base_url}/residual-token-to-token.png")  # token-to-token flow through the residual stream

### Implementation detail: Transformers are position-insensitive by default.

In [None]:
from text_recognizer.models import transformer_util


attention_mask = transformer_util.generate_square_subsequent_mask(100)

# The mask appears to be additive (0 = allowed, -inf = blocked, judging by the 0/1 colorbar below),
# so exp() turns it into a 1/0 "may attend" picture.
# NOTE(review): PyTorch attention masks are indexed [query, key]; after `.T` the rows index keys,
# so confirm the y/x labels below match generate_square_subsequent_mask's orientation.
ax = plt.matshow(torch.exp(attention_mask.T))
cb = plt.colorbar(ticks=[0, 1], fraction=0.05)
plt.ylabel("Can the embedding at this index")
plt.xlabel("attend to embeddings at this index?")
print(attention_mask[:10, :10].T)
cb.set_ticklabels([False, True]);

In [None]:
PositionalEncoder = transformer_util.PositionalEncoding(d_model=50, dropout=0.0, max_len=200)

# Transpose so the sequence dimension runs along the plot's "x-axis".
# (assumes .pe squeezes down to a 2-D (max_len, d_model) table — TODO confirm)
pe = PositionalEncoder.pe.squeeze().T

ax = plt.matshow(pe)
plt.colorbar(ticks=[-1, 0, 1], fraction=0.05)
plt.xlabel("sequence index")
plt.ylabel("embedding dimension")
plt.title("Positional Encoding", y=1.1)
print(pe[:4, :8])

In [None]:
# Random stand-ins for token embeddings, with the same shape as the positional-encoding table.
fake_embeddings = torch.randn_like(pe) * 0.5

ax = plt.matshow(fake_embeddings)
plt.colorbar(ticks=torch.arange(-2, 3), fraction=0.05)
plt.xlabel("sequence index")
plt.ylabel("embedding dimension")
plt.title("Embeddings Without Positional Encoding", y=1.1)

# Position information is injected by plain element-wise addition.
fake_embeddings_with_pe = fake_embeddings + pe

plt.matshow(fake_embeddings_with_pe)
plt.colorbar(ticks=torch.arange(-2, 3), fraction=0.05)
plt.xlabel("sequence index")
plt.ylabel("embedding dimension")
plt.title("Embeddings With Positional Encoding", y=1.1);

# Using Transformers to read paragraphs of text

In [None]:
import text_recognizer.data


emnist_lines = text_recognizer.data.EMNISTLines()
line_cnn = text_recognizer.models.LineCNNSimple(emnist_lines.config())

# for sliding, see the for loop over range(S)
line_cnn.forward??

In [None]:
# Build the IAM paragraphs datamodule, run its prepare/setup steps, and grab one
# validation batch: xs are the input images, ys the target token indices.
iam_paragraphs = text_recognizer.data.IAMParagraphs()

# presumably prepare_data() fetches/processes raw files and setup() builds the splits — TODO confirm
iam_paragraphs.prepare_data()
iam_paragraphs.setup()
xs, ys = next(iter(iam_paragraphs.val_dataloader()))

iam_paragraphs

In [None]:
import random

import numpy as np
import wandb


def show(y, mapping=None):
    """Decode a 1-D tensor of token indices into a string.

    Args:
        y: torch tensor of integer token indices; it may live on an accelerator.
        mapping: sequence mapping each index to its token string. Defaults to
            ``iam_paragraphs.mapping`` (the notebook's datamodule), so existing
            ``show(y)`` calls behave as before.

    Returns:
        The concatenated tokens with every "<P>" token (presumably padding) removed.
    """
    if mapping is None:
        mapping = iam_paragraphs.mapping
    # bring back from accelerator if it's being used, then hand numpy a real integer array
    indices = y.detach().cpu().numpy()
    return "".join(np.array(mapping)[indices]).replace("<P>", "")

# Pick a random example from the validation batch. randrange's upper bound is exclusive,
# so idx is always a valid index into xs (randint's bound is inclusive and could overflow).
idx = random.randrange(len(xs))

print(show(ys[idx]))
wandb.Image(xs[idx]).image

In [None]:
import text_recognizer.models


# ResNet encoder + Transformer decoder, configured from the IAM paragraphs datamodule's config
rnt = text_recognizer.models.ResnetTransformer(data_config=iam_paragraphs.config())

In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

rnt.to(device)
xs, ys = xs.to(device), ys.to(device)

In [None]:
# resnet is designed for RGB images, so we replicate the input across channels 3 times
# (assumes xs holds single-channel images, i.e. (batch, 1, H, W) — TODO confirm)
(resnet_embedding,) = rnt.resnet(xs[idx:idx + 1].repeat(1, 3, 1, 1))

In [None]:
# randrange excludes its upper bound, so resnet_idx is always a valid index
resnet_idx = random.randrange(len(resnet_embedding))  # re-execute to view a different channel
plt.matshow(resnet_embedding[resnet_idx].detach().cpu(), cmap="Greys_r")
plt.axis("off")
plt.colorbar(fraction=0.05);

In [None]:
# Inference only: no_grad skips building an autograd graph, saving memory and time.
# NOTE(review): rnt is never put in eval mode in the visible cells, so any dropout stays active.
with torch.no_grad():
    preds, = rnt(xs[idx:idx+1])  # can take up to two minutes on a CPU. Transformers ❤️ GPUs

In [None]:
print(show(preds))  # show() already detaches and moves the tensor to the CPU
wandb.Image(xs[idx]).image  # the input image, for comparison with the decoded prediction

In [None]:
import text_recognizer.lit_models

# Wrap the model in the training wrapper that provides training_step / teacher_forward
lit_rnt = text_recognizer.lit_models.TransformerLitModel(rnt)

In [None]:
# teacher_forward receives the ground-truth targets ys alongside the images
(forcing_outs,) = lit_rnt.teacher_forward(xs[idx:idx + 1], ys[idx:idx + 1])

In [None]:
# Highest-scoring class at each position.
# (assumes forcing_outs is (num_classes, seq_len) once the batch dim is unpacked — TODO confirm)
forcing_preds = forcing_outs.argmax(dim=0)

print(show(forcing_preds))  # show() already detaches and moves the tensor to the CPU
wandb.Image(xs[idx]).image

## Training the `ResnetTransformer`

In [None]:
import torch

# 1 when a CUDA device is available, 0 otherwise; interpolated as --gpus in the training cell
gpus = int(torch.cuda.is_available())

if gpus:
    # report the visible GPU(s)
    !nvidia-smi
else:
    print("watch out! working with this model on a typical CPU is not feasible")

In [None]:
%%time
# above %%magic times the cell, useful as a poor man's profiler

# Short smoke-test run: the --limit_*_batches flags cap how many train/test/val batches are used.
# {gpus} is interpolated from the Python variable set in the GPU-check cell.
%run training/run_experiment.py --data_class IAMParagraphs --model_class ResnetTransformer --loss transformer \
  --gpus={gpus} --batch_size 4 --precision 16 \
  --limit_train_batches 10 --limit_test_batches 1 --limit_val_batches 2